Introduction

This example builds on the statistical learning chapter from Modern Data Science with R: http://mdsr-book.github.io/.

Data ingestion and processing

library(mdsr)
# NOTE(review): the rest of this script relies on library(mdsr) to make
# dplyr (glimpse, mutate, select, %>%), ggplot2 (ggplot) and mosaic (tally)
# available; confirm the installed mdsr version still attaches them.

# Adult (census income) data from the UCI repository. The file has no header
# row, so the column names are supplied via col.names.
# stringsAsFactors = TRUE keeps the categorical columns as factors (as in the
# glimpse() output below); since R 4.0.0 read.csv() would otherwise return
# them as character vectors.
# Values keep the space that follows each comma in the raw file
# (e.g. " >50K"); pass strip.white = TRUE if exact string matching is needed.
census <- read.csv(
  "http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
  header = FALSE,
  stringsAsFactors = TRUE,
  col.names = c(
    "age", "workclass", "fnlwgt", "education",
    "education.num", "marital.status", "occupation", "relationship",
    "race", "sex", "capital.gain", "capital.loss", "hours.per.week",
    "native.country", "income"
  )
)
glimpse(census)
## Observations: 32,561
## Variables: 15
## $ age            <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30,...
## $ workclass      <fctr>  State-gov,  Self-emp-not-inc,  Private,  Priv...
## $ fnlwgt         <int> 77516, 83311, 215646, 234721, 338409, 284582, 1...
## $ education      <fctr>  Bachelors,  Bachelors,  HS-grad,  11th,  Bach...
## $ education.num  <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13,...
## $ marital.status <fctr>  Never-married,  Married-civ-spouse,  Divorced...
## $ occupation     <fctr>  Adm-clerical,  Exec-managerial,  Handlers-cle...
## $ relationship   <fctr>  Not-in-family,  Husband,  Not-in-family,  Hus...
## $ race           <fctr>  White,  White,  White,  Black,  Black,  White...
## $ sex            <fctr>  Male,  Male,  Male,  Male,  Female,  Female, ...
## $ capital.gain   <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0...
## $ capital.loss   <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
## $ hours.per.week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40,...
## $ native.country <fctr>  United-States,  United-States,  United-States...
## $ income         <fctr>  <=50K,  <=50K,  <=50K,  <=50K,  <=50K,  <=50K...
# Hold out a random 20% of the census rows as a test set; the fixed seed
# makes the split reproducible. The size of each part is printed.
set.seed(364)
n <- nrow(census)
test_idx <- sample.int(n = n, size = round(n * 0.2))
# Training rows are everything not drawn for the test set.
train <- census[-test_idx, ]
nrow(train)
## [1] 26049
test <- census[test_idx, ]
nrow(test)
## [1] 6512
# Class balance of the training set, in percent (values on a 0-100 scale).
income_pct <- tally(~ income, data = train, format = "percent")
# pi_bar: percentage of training rows in the high-income class. The class is
# selected by its label instead of by position [2], which silently depended
# on the factor level order. A substring match works whether or not the
# labels keep their leading space; "<=50K" does not contain ">50K".
pi_bar <- income_pct[grepl(">50K", names(income_pct), fixed = TRUE)]
pi_bar
##     >50K 
## 24.25045
income_pct
## income
##    <=50K     >50K 
## 75.74955 24.25045
library(rpart)
# One-predictor classification tree: how well does capital.gain alone
# separate the two income classes?
dtree <- rpart(income ~ capital.gain, data = train)
# Threshold of the root node's primary split. The first row of
# dtree$splits is that split; later rows (if any) hold competitor and
# surrogate splits and the splits of deeper nodes.
split_val <- dtree$splits[1L, "index"]
# Per-node summary: n = rows in the node, dev = the node's loss
# (misclassified rows, shown as "loss" in the printed tree), yval = fitted
# class; pct = share of all training rows in the node, right = share of the
# node's rows that are classified correctly.
dtree_frame <- dtree$frame %>%
  select(var, n, dev, yval) %>%
  mutate(pct = n / nrow(train), right = (n - dev) / n)
# Print the fitted tree instead of refitting an identical model.
dtree
## n= 26049 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 26049 6317  <=50K (0.75749549 0.24250451)  
##   2) capital.gain< 5095.5 24784 5115  <=50K (0.79361685 0.20638315) *
##   3) capital.gain>=5095.5 1265   63  >50K (0.04980237 0.95019763) *
# Threshold learned by the one-variable tree (5095.5 for this seed), taken
# from the fit rather than retyped so the plot below stays in sync with it.
# A non-function object named `split` does not hide base::split() in calls.
split <- split_val[1]
# Flag training rows on the high side of the capital-gains split.
train <- train %>% mutate(hi_cap_gains = capital.gain >= split)

# Income class against capital gains, on a log10 dollar axis, with the
# tree's split threshold drawn as a dashed line. Rows with zero capital
# gains cannot be placed on a log scale, so they are dropped explicitly
# here; otherwise ggplot discards them itself with warnings (23870 rows in
# the original run). Points are jittered vertically only, to separate
# overlapping counts within each income class.
ggplot(
  data = dplyr::filter(train, capital.gain > 0),
  aes(x = capital.gain, y = income)
) +
  geom_count(
    aes(color = hi_cap_gains),
    position = position_jitter(width = 0, height = 0.1),
    alpha = 0.5
  ) +
  geom_vline(xintercept = split, color = "dodgerblue", lty = 2) +
  scale_x_log10(labels = scales::dollar)

# Classification tree on most of the census predictors; fnlwgt,
# education.num and native.country are left out of the model.
# NOTE(review): presumably education.num is omitted because it recodes
# education numerically - confirm.
# Written as a formula literal rather than as.formula() on a string, which
# only re-parses text at run time; the resulting formula is the same.
form <- income ~ age + workclass + education + marital.status +
  occupation + relationship + race + sex + capital.gain + capital.loss +
  hours.per.week
mod_tree <- rpart(form, data = train)
mod_tree
## n= 26049 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 26049 6317  <=50K (0.75749549 0.24250451)  
##    2) relationship= Not-in-family, Other-relative, Own-child, Unmarried 14196  947  <=50K (0.93329107 0.06670893)  
##      4) capital.gain< 7073.5 13946  706  <=50K (0.94937617 0.05062383) *
##      5) capital.gain>=7073.5 250    9  >50K (0.03600000 0.96400000) *
##    3) relationship= Husband, Wife 11853 5370  <=50K (0.54695014 0.45304986)  
##      6) education= 10th, 11th, 12th, 1st-4th, 5th-6th, 7th-8th, 9th, Assoc-acdm, Assoc-voc, HS-grad, Preschool, Some-college 8280 2769  <=50K (0.66557971 0.33442029)  
##       12) capital.gain< 5095.5 7857 2355  <=50K (0.70026728 0.29973272) *
##       13) capital.gain>=5095.5 423    9  >50K (0.02127660 0.97872340) *
##      7) education= Bachelors, Doctorate, Masters, Prof-school 3573  972  >50K (0.27204030 0.72795970) *
# Base-graphics drawing of the fitted tree. margin = 0.1 leaves extra white
# space around the tree so that long split labels (such as the lists of
# education levels) are not cut off at the plot edges.
plot(mod_tree, margin = 0.1)
# Label every node (all = TRUE), not just the leaves, and add per-class
# observation counts (use.n = TRUE); cex = 0.7 shrinks the text to fit.
text(mod_tree, use.n = TRUE, all = TRUE, cex = 0.7)

library(partykit)
## Loading required package: grid
# Redraw the same tree through partykit: coerce the rpart fit to a "party"
# object and hand it to partykit's plot method.
mod_tree %>%
  as.party() %>%
  plot()