# This example builds on the statistical learning chapter of
# Modern Data Science with R (https://mdsr-book.github.io/).
library(mdsr)

# Download the UCI "Adult" census extract. The file has no header row, so
# column names are assigned below.
# - https is used directly rather than relying on an http -> https redirect.
# - stringsAsFactors = TRUE is explicit because the default flipped to FALSE
#   in R 4.0.0; the glimpse() output recorded below (<fctr> columns) assumes
#   factor columns.
# - strip.white is left at its default (FALSE), so level names may keep a
#   leading space from the ", " field separator (e.g. " >50K") -- match
#   levels with trimws() if comparing by name.
census <- read.csv(
  "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
  header = FALSE,
  stringsAsFactors = TRUE
)
names(census) <- c(
  "age", "workclass", "fnlwgt", "education",
  "education.num", "marital.status", "occupation", "relationship",
  "race", "sex", "capital.gain", "capital.loss", "hours.per.week",
  "native.country", "income"
)
glimpse(census)
## Observations: 32,561
## Variables: 15
## $ age <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30,...
## $ workclass <fctr> State-gov, Self-emp-not-inc, Private, Priv...
## $ fnlwgt <int> 77516, 83311, 215646, 234721, 338409, 284582, 1...
## $ education <fctr> Bachelors, Bachelors, HS-grad, 11th, Bach...
## $ education.num <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13,...
## $ marital.status <fctr> Never-married, Married-civ-spouse, Divorced...
## $ occupation <fctr> Adm-clerical, Exec-managerial, Handlers-cle...
## $ relationship <fctr> Not-in-family, Husband, Not-in-family, Hus...
## $ race <fctr> White, White, White, Black, Black, White...
## $ sex <fctr> Male, Male, Male, Male, Female, Female, ...
## $ capital.gain <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0...
## $ capital.loss <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ hours.per.week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40,...
## $ native.country <fctr> United-States, United-States, United-States...
## $ income <fctr> <=50K, <=50K, <=50K, <=50K, <=50K, <=50K...
# Hold out a random 20% of rows as a test set; the remaining 80% is training
# data.
# NOTE: sample.int()'s default algorithm changed in R 3.6.0, so this seed
# picks different rows on older vs. newer R. The row counts below are
# unaffected, but the percentages (and later fits) may differ slightly.
set.seed(364)
n <- nrow(census)
test_idx <- sample.int(n, size = round(0.2 * n))
train <- census[-test_idx, ]
nrow(train)
## [1] 26049
test <- census[test_idx, ]
nrow(test)
## [1] 6512

# Income class balance in the training set, in percent. The table is computed
# once. The ">50K" share is picked by label rather than by position; trimws()
# tolerates a leading space in the raw level names.
income_pct <- tally(~ income, data = train, format = "percent")
pi_bar <- income_pct[trimws(names(income_pct)) == ">50K"]
stopifnot(length(pi_bar) == 1)
pi_bar
## >50K
## 24.25045
income_pct
## income
## <=50K >50K
## 75.74955 24.25045
library(rpart)

# Single-predictor classification tree: how well does capital.gain alone
# separate the two income classes?
dtree <- rpart(income ~ capital.gain, data = train)

# Split threshold(s) chosen by rpart, read from the fitted split table.
split_val <- as.data.frame(dtree$splits)$index

# Per-node summary:
# - pct is the share of training rows reaching the node.
# - right is the fraction of those rows whose class matches the node's
#   prediction.
# For classification trees, frame$dev is the node loss, i.e. the "loss"
# column of the printout below.
dtree_frame <- dtree$frame %>%
  select(var, n, dev, yval) %>%
  mutate(pct = n / nrow(train), right = (n - dev) / n)

# Print the fitted tree (no need to refit an identical model just to show it).
dtree
## n= 26049
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 26049 6317 <=50K (0.75749549 0.24250451)
## 2) capital.gain< 5095.5 24784 5115 <=50K (0.79361685 0.20638315) *
## 3) capital.gain>=5095.5 1265 63 >50K (0.04980237 0.95019763) *
# Capital-gains threshold from the single-variable tree (its primary split,
# 5095.5 in the printed fit). It is read from the fit instead of hard-coded so
# the flag and the dashed line stay consistent with the tree actually
# estimated. This matters because the random train/test split, and hence the
# fit, differs across R versions.
split <- split_val[1]
train <- train %>% mutate(hi_cap_gains = capital.gain >= split)

# Jittered count plot of income against capital gains on a log10 axis,
# colored by which side of the split each row falls on. The dashed line marks
# the split. Rows with capital.gain == 0 map to -Inf on the log scale, which
# triggers the warnings below and drops those rows from the plot.
ggplot(data = train, aes(x = capital.gain, y = income)) +
  geom_count(
    aes(color = hi_cap_gains),
    position = position_jitter(width = 0, height = 0.1),
    alpha = 0.5
  ) +
  geom_vline(xintercept = split, color = "dodgerblue", lty = 2) +
  scale_x_log10(labels = scales::dollar)
## Warning: Transformation introduced infinite values in continuous x-axis
## Warning: Removed 23870 rows containing non-finite values (stat_sum).
# Classification tree on a broad set of predictors. fnlwgt, education.num,
# native.country and the derived hi_cap_gains flag are deliberately not in
# the model.
# A formula literal is used instead of as.formula() on a string. The model is
# the same (both carry the global environment), but the formula is ordinary
# R code rather than text that has to be parsed at run time.
form <- income ~ age + workclass + education + marital.status +
  occupation + relationship + race + sex + capital.gain + capital.loss +
  hours.per.week
mod_tree <- rpart(form, data = train)
mod_tree
## n= 26049
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 26049 6317 <=50K (0.75749549 0.24250451)
## 2) relationship= Not-in-family, Other-relative, Own-child, Unmarried 14196 947 <=50K (0.93329107 0.06670893)
## 4) capital.gain< 7073.5 13946 706 <=50K (0.94937617 0.05062383) *
## 5) capital.gain>=7073.5 250 9 >50K (0.03600000 0.96400000) *
## 3) relationship= Husband, Wife 11853 5370 <=50K (0.54695014 0.45304986)
## 6) education= 10th, 11th, 12th, 1st-4th, 5th-6th, 7th-8th, 9th, Assoc-acdm, Assoc-voc, HS-grad, Preschool, Some-college 8280 2769 <=50K (0.66557971 0.33442029)
## 12) capital.gain< 5095.5 7857 2355 <=50K (0.70026728 0.29973272) *
## 13) capital.gain>=5095.5 423 9 >50K (0.02127660 0.97872340) *
## 7) education= Bachelors, Doctorate, Masters, Prof-school 3573 972 >50K (0.27204030 0.72795970) *
# Two renderings of the same fitted tree.

# 1. rpart's own base-graphics plot/text methods. use.n and all request
#    per-node counts and labels on interior nodes too (see ?text.rpart).
plot(mod_tree)
text(mod_tree, cex = 0.7, all = TRUE, use.n = TRUE)

# 2. partykit's layout, after converting the rpart fit into a party object.
library(partykit)
## Loading required package: grid
mod_tree %>%
  as.party() %>%
  plot()