Monday, April 29, 2013

CART: predict email spam

# ---- CART example: load the spam data and build the train/test split ----
# NOTE: rm(list = ls()) wipes every object in the global environment. Run this
# script in a fresh R session rather than in a workspace you want to keep.
rm(list = ls())
library(ElemStatLearn)
data(spam)

# Original attribute names, kept for reference. Several of them are not
# syntactic R names (leading digits, punctuation), so the renamed set below
# is the one actually applied.
# names(spam) <- c("make", "address", "all", "3d", "our",
#                  "over", "remove", "internet", "order", "mail",
#                  "receive", "will", "people", "report", "addresses",
#                  "free", "business", "email", "you", "credit",
#                  "your", "font", "000", "money", "hp",
#                  "hpl", "george", "650", "lab", "labs",
#                  "telnet", "857", "data", "415", "85",
#                  "technology", "1999", "parts", "pm",
#                  "direct", "cs", "meeting", "original", "project",
#                  "re", "edu", "table", "conference", ";:",
#                  "(:", "[:", "!:", "$:", "#:",
#                  "CRave", "CRlong", "CRtotal", "spam")
names(spam) <- c("make", "address", "all", "x3d", "our",
                 "over", "remove", "internet", "order", "mail",
                 "receive", "will", "people", "report", "addresses",
                 "free", "business", "email", "you", "credit",
                 "your", "font", "x000", "money", "hp",
                 "hpl", "george", "x650", "lab", "labs",
                 "telnet", "x857", "data", "x415", "x85",
                 "technology", "x1999", "parts", "pm",
                 "direct", "cs", "meeting", "original", "project",
                 "re", "edu", "table", "conference", "p1",
                 "p2", "p3", "p4", "p5", "p6",
                 "CRave", "CRlong", "CRtotal", "spam")

summary(spam)

# Train/test indicator file: one value per line, read into a single column.
# Rows flagged 0 go to the training set and rows flagged 1 to the test set.
spam.test_indx <- read.delim("http://www-stat.stanford.edu/~tibs/ElemStatLearn/datasets/spam.traintest",
                             sep = "\n", header = FALSE)
# A logical row index of the wrong length would be recycled silently, so make
# sure the indicator has exactly one entry per observation before using it.
stopifnot(nrow(spam.test_indx) == nrow(spam))
#########################################################################
#########################################################################
# Recode the response as a 0/1 factor (1 = spam). The size is taken from the
# data instead of a hard-coded row count.
Y <- data.frame(spam = factor(as.numeric(spam$spam == "spam")))
# Replace the original response column with the recoded one; the column is
# selected by name rather than by a fixed position.
data <- cbind(spam[, setdiff(names(spam), "spam")], Y)
data.train <- data[spam.test_indx[[1]] == 0, ]
data.test <- data[spam.test_indx[[1]] == 1, ]
summary(data.train)
summary(data.test)
# Only data.train and data.test are needed from here on.
rm(Y, data, spam, spam.test_indx)
###############################################################################
###############################################################################
# ---- Classification tree grown with the information (cross-entropy) split ----
library(rpart)
library(pROC)
names(data.train)

# Grow a deliberately large tree (tiny cp, minsplit of 1) with 5-fold
# cross-validation so that it can be pruned back afterwards.
spam.tree <- rpart(
  spam ~ .,
  data = data.train,
  method = "class",
  xval = 5,
  cp = 0.00001,
  minsplit = 1,
  parms = list(split = "information"),
  na.action = na.exclude
)
printcp(spam.tree)
plotcp(spam.tree, upper = "size")

# First pruning attempt at cp = 0.0025: print the tree and draw it twice,
# once with default layout and once with a compact uniform layout.
spam.prune <- prune(spam.tree, cp = 0.0025)
print(spam.prune)
plot(spam.prune)
plot(spam.prune, compress = TRUE, uniform = TRUE, branch = 0.4, margin = 0.01)
text(spam.prune)

# Second pruning attempt at cp = 0.003; this is the tree that is evaluated.
spam.prune <- prune(spam.tree, cp = 0.003)
plot(spam.prune, compress = TRUE, uniform = TRUE, branch = 0.4, margin = 0.01)
text(spam.prune)
summary(spam.prune)
plotcp(spam.prune)

# Predicted probability of the second response level on the test set,
# followed by the ROC curve against the true test labels.
y.hat <- predict(spam.prune, data.test, type = "prob")[, 2]
head(y.hat)
roc(data.test$spam, y.hat, plot = TRUE)
############################################################################
# ---- Classification tree grown with the Gini split ----
# Same growing settings as the information-split tree; only the splitting
# criterion differs.
spam.tree <- rpart(
  spam ~ .,
  data = data.train,
  method = "class",
  xval = 5,
  cp = 0.00001,
  minsplit = 1,
  parms = list(split = "gini"),
  na.action = na.exclude
)
plotcp(spam.tree)

# Prune at cp = 0.002, then inspect the pruned tree.
spam.prune <- prune(spam.tree, cp = 0.002)
plotcp(spam.prune)
plot(spam.prune, compress = TRUE, uniform = TRUE, branch = 0.4, margin = 0.01)
text(spam.prune)
summary(spam.prune)
print(spam.prune)

# Test-set probability of the second response level and its ROC curve.
y.hat <- predict(spam.prune, data.test, type = "prob")[, 2]
roc(data.test$spam, y.hat, plot = TRUE)
# Remark: with the cross-entropy (information) split, smaller trees achieve a better ROC than with the Gini split.

GAM: predict email spam

# ---- GAM example: load the spam data and build the train/test split ----
# NOTE: rm(list = ls()) wipes every object in the global environment. Run this
# script in a fresh R session rather than in a workspace you want to keep.
rm(list = ls())
library(ElemStatLearn)
data(spam)

# Original attribute names, kept for reference. Several of them are not
# syntactic R names (leading digits, punctuation), so the renamed set below
# is the one actually applied.
# names(spam) <- c("make", "address", "all", "3d", "our",
#                  "over", "remove", "internet", "order", "mail",
#                  "receive", "will", "people", "report", "addresses",
#                  "free", "business", "email", "you", "credit",
#                  "your", "font", "000", "money", "hp",
#                  "hpl", "george", "650", "lab", "labs",
#                  "telnet", "857", "data", "415", "85",
#                  "technology", "1999", "parts", "pm",
#                  "direct", "cs", "meeting", "original", "project",
#                  "re", "edu", "table", "conference", ";:",
#                  "(:", "[:", "!:", "$:", "#:",
#                  "CRave", "CRlong", "CRtotal", "spam")
names(spam) <- c("make", "address", "all", "x3d", "our",
                 "over", "remove", "internet", "order", "mail",
                 "receive", "will", "people", "report", "addresses",
                 "free", "business", "email", "you", "credit",
                 "your", "font", "x000", "money", "hp",
                 "hpl", "george", "x650", "lab", "labs",
                 "telnet", "x857", "data", "x415", "x85",
                 "technology", "x1999", "parts", "pm",
                 "direct", "cs", "meeting", "original", "project",
                 "re", "edu", "table", "conference", "p1",
                 "p2", "p3", "p4", "p5", "p6",
                 "CRave", "CRlong", "CRtotal", "spam")

summary(spam)

# Train/test indicator file: one value per line, read into a single column.
# Rows flagged 0 go to the training set and rows flagged 1 to the test set.
spam.test_indx <- read.delim("http://www-stat.stanford.edu/~tibs/ElemStatLearn/datasets/spam.traintest",
                             sep = "\n", header = FALSE)
# A logical row index of the wrong length would be recycled silently, so make
# sure the indicator has exactly one entry per observation before using it.
stopifnot(nrow(spam.test_indx) == nrow(spam))
#########################################################################
#########################################################################
# Recode the response as a 0/1 factor (1 = spam). The size is taken from the
# data instead of a hard-coded row count.
Y <- data.frame(spam = factor(as.numeric(spam$spam == "spam")))
# Replace the original response column with the recoded one; the column is
# selected by name rather than by a fixed position.
data <- cbind(spam[, setdiff(names(spam), "spam")], Y)
data.train <- data[spam.test_indx[[1]] == 0, ]
data.test <- data[spam.test_indx[[1]] == 1, ]
summary(data.train)
summary(data.test)

###########################################################################
# ---- Additive logistic model: one smooth term per predictor ----
library(gam)
# The commented-out loop below prints one "+s(<name>, 4)" line per predictor
# (every column except the last one, the response). Its output was pasted
# into the formula that follows, which is why the formula is written out
# term by term.
# var.name = names(data)
# for (i in c(1:(length(var.name)-1))){
#   cat("+s(", var.name[i], ", 4)\n", append=TRUE, sep = "", collapse="")
# }
# Fit spam (0/1 factor) on all 57 predictors of data.train with a binomial
# family and logit link. Each predictor enters through s(<name>, 4); in the
# gam package the second argument of s() is the target degrees of freedom
# (see the package documentation for s()).
# The formula is kept as a literal because the fitted object stores the call.
spamgam = gam(spam~
                +s(make, 4)
              +s(address, 4)
              +s(all, 4)
              +s(x3d, 4)
              +s(our, 4)
              +s(over, 4)
              +s(remove, 4)
              +s(internet, 4)
              +s(order, 4)
              +s(mail, 4)
              +s(receive, 4)
              +s(will, 4)
              +s(people, 4)
              +s(report, 4)
              +s(addresses, 4)
              +s(free, 4)
              +s(business, 4)
              +s(email, 4)
              +s(you, 4)
              +s(credit, 4)
              +s(your, 4)
              +s(font, 4)
              +s(x000, 4)
              +s(money, 4)
              +s(hp, 4)
              +s(hpl, 4)
              +s(george, 4)
              +s(x650, 4)
              +s(lab, 4)
              +s(labs, 4)
              +s(telnet, 4)
              +s(x857, 4)
              +s(data, 4)
              +s(x415, 4)
              +s(x85, 4)
              +s(technology, 4)
              +s(x1999, 4)
              +s(parts, 4)
              +s(pm, 4)
              +s(direct, 4)
              +s(cs, 4)
              +s(meeting, 4)
              +s(original, 4)
              +s(project, 4)
              +s(re, 4)
              +s(edu, 4)
              +s(table, 4)
              +s(conference, 4)
              +s(p1, 4)
              +s(p2, 4)
              +s(p3, 4)
              +s(p4, 4)
              +s(p5, 4)
              +s(p6, 4)
              +s(CRave, 4)
              +s(CRlong, 4)
              +s(CRtotal, 4)
              ,data = data.train
              ,family = binomial(link = "logit")
)
# Diagnostics for the fitted additive model.
# `residuals` is spelled out in full: the abbreviated `residual=` only worked
# through partial argument matching, which is fragile and triggers a warning
# under options(warnPartialMatchArgs = TRUE).
plot(spamgam, residuals = TRUE, se = TRUE, pch = ".")
summary(spamgam)
library(pROC)
# Predicted probabilities on the response scale for the test set, then the
# ROC curve against the true test labels. TRUE is used instead of T because
# T is an ordinary variable that can be reassigned.
y.hat <- predict(spamgam, data.test, type = "response")
roc(data.test$spam, y.hat, plot = TRUE)
############################################################
# Helper: print one "+s(log(<name>+0.1))" line per predictor so the terms can
# be pasted into the gam() formula below.
# `var.name` used to be defined only inside commented-out code, so this step
# failed with "object 'var.name' not found"; it is now defined here from the
# training frame, whose last column is the response and is therefore skipped.
var.name <- names(data.train)
# One vectorised cat() replaces the index loop; the `append` and `collapse`
# arguments were dropped because they had no effect on this console output.
cat(paste0("+s(log(", var.name[-length(var.name)], "+0.1))\n"), sep = "")
# Refit the additive logistic model with every predictor transformed as
# log(x + 0.1) before smoothing. s() is called without a second argument, so
# the package default smoothing amount is used, and `family = binomial` uses
# the family's default link.
# NOTE(review): the 0.1 offset presumably keeps log() finite for predictors
# that are exactly zero; confirm against the data.
# The formula is kept as a literal because the fitted object stores the call.
spamgam = gam(spam~
                +s(log(make+0.1))
              +s(log(address+0.1))
              +s(log(all+0.1))
              +s(log(x3d+0.1))
              +s(log(our+0.1))
              +s(log(over+0.1))
              +s(log(remove+0.1))
              +s(log(internet+0.1))
              +s(log(order+0.1))
              +s(log(mail+0.1))
              +s(log(receive+0.1))
              +s(log(will+0.1))
              +s(log(people+0.1))
              +s(log(report+0.1))
              +s(log(addresses+0.1))
              +s(log(free+0.1))
              +s(log(business+0.1))
              +s(log(email+0.1))
              +s(log(you+0.1))
              +s(log(credit+0.1))
              +s(log(your+0.1))
              +s(log(font+0.1))
              +s(log(x000+0.1))
              +s(log(money+0.1))
              +s(log(hp+0.1))
              +s(log(hpl+0.1))
              +s(log(george+0.1))
              +s(log(x650+0.1))
              +s(log(lab+0.1))
              +s(log(labs+0.1))
              +s(log(telnet+0.1))
              +s(log(x857+0.1))
              +s(log(data+0.1))
              +s(log(x415+0.1))
              +s(log(x85+0.1))
              +s(log(technology+0.1))
              +s(log(x1999+0.1))
              +s(log(parts+0.1))
              +s(log(pm+0.1))
              +s(log(direct+0.1))
              +s(log(cs+0.1))
              +s(log(meeting+0.1))
              +s(log(original+0.1))
              +s(log(project+0.1))
              +s(log(re+0.1))
              +s(log(edu+0.1))
              +s(log(table+0.1))
              +s(log(conference+0.1))
              +s(log(p1+0.1))
              +s(log(p2+0.1))
              +s(log(p3+0.1))
              +s(log(p4+0.1))
              +s(log(p5+0.1))
              +s(log(p6+0.1))
              +s(log(CRave+0.1))
              +s(log(CRlong+0.1))
              +s(log(CRtotal+0.1))
              ,data = data.train
              ,family = binomial
)
# Diagnostics for the log-transformed additive model.
# `residuals` is spelled out in full: the abbreviated `residual=` only worked
# through partial argument matching, which is fragile and triggers a warning
# under options(warnPartialMatchArgs = TRUE).
plot(spamgam, residuals = TRUE, se = TRUE, pch = ".")
summary(spamgam)
# Predicted probabilities on the response scale for the test set.
y.hat <- predict(spamgam, data.test, type = "response")
# head(y.hat, 3)
# head(data.test$spam, 3)
library(pROC)
# ROC curve against the true test labels. TRUE is used instead of T because
# T is an ordinary variable that can be reassigned.
roc(data.test$spam, y.hat, plot = TRUE)

# A penalized GAM is available in mgcv, but it is extremely slow.