In this lab, we will work on applying recursive partitioning (tree) models, identifying and assessing when a model might be overfitting, and assessing and comparing different models.
We will start with simulated data.
# The rpart package will allow us to create tree models.
library(rpart)
set.seed(0)
The data we simulate are based on the piecewise model \[ y = \begin{cases} 4 & x < 50\\ 2 & x \ge 50 \end{cases} \]
n = 100
x = seq(n)
true_model = function(x){
    # Piecewise constant: 4 when x < 50, 2 otherwise
    out = rep(2, length.out = length(x))
    out[x < 50] = 4
    out
}
y = true_model(x) + rnorm(n)
d = data.frame(x = x, y = y)
Obviously the true model isn’t linear, but let’s see what happens when we create a linear regression model with these data:
fit_lm = lm(y ~ x, data = d)
plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_lm), col = "blue")
Now let’s try recursive partitioning. We should check the help file before we get started.
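If rpart is new to you, the help pages describe the fitting function and its tuning parameters:
# Open the help pages for the fitting function and its control (tuning) parameters
?rpart
?rpart.control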
# Use the defaults:
fit_rpart = rpart(y ~ x, data = d)
plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_rpart), col = "red")
plot(fit_rpart)
text(fit_rpart)
Did we capture the truth? It’s not terrible, but it’s not good either. To improve on this, rpart has “tuning parameters” that we can adjust so the model better fits the data. Two useful ones are minsplit, the minimum number of observations a node must contain before a split is attempted, and cp, the complexity parameter: any split that does not decrease the overall lack of fit by at least this factor is not attempted.
fit_overfit = rpart(y ~ x, data = d, minsplit = 4, cp = 0.0001)
plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_rpart), col = "red")
lines(x, predict(fit_overfit), col = "blue")
This is doing a really good job fitting the data, but it’s a terrible fit to the underlying model (and would therefore be a poor fit if we tried to use it to predict \(y\) for new values of \(x\))! This is a case of overfitting.
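We can also see this numerically. Comparing in-sample (training) errors on d, the overfit tree should show a much smaller error than the default tree, but only because it is chasing the noise; these numbers say nothing about performance on new data.
# In-sample MSE on the training data d
mean((predict(fit_rpart) - d$y)^2)
mean((predict(fit_overfit) - d$y)^2)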
We want our models to be parsimonious, which is sort of like “conservative”: we want the simplest model that reasonably explains the data, striking a good balance between simplicity and usefulness.
fit_parsimonious = rpart(y ~ x, data = d, minbucket = 30)
plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_rpart), col = "red")
lines(x, predict(fit_overfit), col = "blue")
lines(x, predict(fit_parsimonious), col = "green")
Next, we generate new \(y\) data from the true model (and add some random variation).
d2 = data.frame(x, truth = true_model(x))
d2$y = d2$truth + rnorm(n)
One way to examine how well each model predicts these new data is to compute its mean squared error (MSE).
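In symbols, with predicted values \(\hat{y}_i\) and observed values \(y_i\), \[ \mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n} \left(\hat{y}_i - y_i\right)^2. \] For the original rpart model,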
error = predict(fit_rpart, d2) - d2$y
mean(error^2)
## [1] 1.044938
Does the best model minimize or maximize MSE?
Next, we will find the MSE for all of the models: the regression, the original rpart model, the overfitted model, and the idealized parsimonious model.
models = list(fit_lm, fit_rpart, fit_overfit, fit_parsimonious)
names(models) = c("linear", "default", "overfit", "parsimonious")
# Mean squared prediction error of a fitted model on a data set
mse = function(fit, data = d2, out = data$y){
    error = predict(fit, data) - out
    mean(error^2)
}
mse(fit_rpart)
## [1] 1.044938
# Apply to all the models:
lapply(models, mse)
## $linear
## [1] 1.129974
##
## $default
## [1] 1.044938
##
## $overfit
## [1] 1.407317
##
## $parsimonious
## [1] 0.9540097
Unsurprisingly, the parsimonious model is the best!
In this case, we found new data and tested the performance of the models on that new data. But what happens if we can’t get new data? We can split a dataset into “training data” and “testing data”: we use some of the data (maybe 80%) to build the model and set the remaining 20% aside to test model performance.
# External data:
body = read.csv("https://raw.githubusercontent.com/clarkfitzg/21fall_stat128/main/Howell_us.csv")
head(body)
## age height_in weight_lb sex
## 1 63 59.75 105.43729 male
## 2 63 55.00 80.43734 female
## 3 65 53.75 70.24986 female
## 4 41 61.75 116.93727 male
## 5 51 57.25 90.99982 female
## 6 35 64.50 138.87472 male
# Randomly pick data to use for testing (validation)
n = nrow(body)
val = sample(n, size = round(0.2 * n)) # select 20% of the data
body_test = body[val,] # test data
body_train = body[-val, ] # training data
Next we will fit the model with the training data, using age to predict height_in.
# Sort data so lines function works later
body_train = body_train[order(body_train$age), ]
# Let's model height as a function of age.
fit1 = rpart(height_in ~ age, data = body_train)
with(body_train, plot(age, height_in))
lines(body_train$age, predict(fit1), col = "blue", lwd = 2)
We can include additional input variables in the same way as in an lm or glm command:
fit2 = rpart(height_in ~ age + sex, data = body_train)
male = body_train$sex == "male"
with(body_train[male, ], plot(age, height_in, col = "red"))
lines(body_train$age[male], predict(fit2)[male], col = "red", lwd = 2)
with(body_train[!male, ], points(age, height_in, col = "blue"))
lines(body_train$age[!male], predict(fit2)[!male], col = "blue", lwd = 2)
The next two models are regression models with interaction terms. (Interaction terms like these are not used in rpart formulas; trees capture interactions through successive splits.)
# age*sex gives all main effects plus the interaction: age + sex + age:sex
fit3 = lm(height_in ~ age*sex, data = body_train)
# age:sex gives exactly one interaction effect, with no main effects
fit4 = lm(height_in ~ age:sex, data = body_train)
Which model is best? Let’s examine how each performs using the test data that we held back from model creation.
models2 = list(fit1, fit2, fit3, fit4)
names(models2) = c("rpart age", "rpart age sex", "lm age*sex", "lm age:sex")
# Apply to all the models:
lapply(models2, mse, body_test, body_test$height_in)
## $`rpart age`
## [1] 11.75447
##
## $`rpart age sex`
## [1] 8.428809
##
## $`lm age*sex`
## [1] 59.9885
##
## $`lm age:sex`
## [1] 59.42185
Use the Titanic dataset (a slightly simplified version of the other Titanic data we worked with) to build several models. We want to predict survival. Which model is best?
To do: split the data into test and train datasets, fit your models, and compare their performance.
data(Titanic)
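One possible way to get started, building on the data(Titanic) call above (a sketch only, not the required solution; the built-in Titanic object is a contingency table, so it is expanded to one row per person here, and the names titanic, titanic_train, and fit_titanic are just placeholders):
# Expand the contingency table to one row per person
titanic = as.data.frame(Titanic)
titanic = titanic[rep(seq_len(nrow(titanic)), titanic$Freq),
                  c("Class", "Sex", "Age", "Survived")]
# Hold back 20% of the rows for testing, as we did earlier
n_t = nrow(titanic)
val_t = sample(n_t, size = round(0.2 * n_t))
titanic_test = titanic[val_t, ]
titanic_train = titanic[-val_t, ]
# A classification tree predicting survival from the other columns
fit_titanic = rpart(Survived ~ Class + Sex + Age, data = titanic_train, method = "class")
From there, fit alternative models on titanic_train and compare how well each one predicts survival on titanic_test.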
An approach called cross validation generalizes the idea of splitting the data into a test and a training dataset: we divide the data into \(s\) groups, let each group take a turn as the test set while the model is fit to the remaining groups, and then average the prediction errors over all \(s\) groups. This allows us to use more of our data for the training sample. For data sets that aren’t too large, we often let \(s=n\). This is called “leave-one-out” cross validation.
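To make the idea concrete, here is a minimal hand-rolled sketch of leave-one-out cross validation for the default tree on the simulated data d from earlier (an illustration only; rpart has this built in, as we use next):
# Refit the tree n times, each time holding out one observation,
# then average the squared prediction errors on the held-out points
loo_sq_err = sapply(seq_len(nrow(d)), function(i) {
    fit_i = rpart(y ~ x, data = d[-i, ])
    (predict(fit_i, d[i, ]) - d$y[i])^2
})
mean(loo_sq_err)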
Lucky for us, this is built into the rpart package! Using cross-validation, rpart will examine a variety of trees using age and sex to predict height_in.
fit5 <- rpart(height_in ~ age + sex, data = body,
              control = rpart.control(xval = nrow(body), minbucket = 2, cp = 0))
printcp(fit5)
##
## Regression tree:
## rpart(formula = height_in ~ age + sex, data = body, control = rpart.control(xval = nrow(body),
## minbucket = 2, cp = 0))
##
## Variables actually used in tree construction:
## [1] age sex
##
## Root node error: 64125/544 = 117.88
##
## n= 544
##
## CP nsplit rel error xerror xstd
## 1 7.7722e-01 0 1.000000 1.003687 0.0693006
## 2 7.3546e-02 1 0.222782 0.225099 0.0153530
## 3 4.3648e-02 2 0.149236 0.151115 0.0096373
## 4 2.4864e-02 3 0.105588 0.108602 0.0063549
## 5 1.1797e-02 4 0.080723 0.083725 0.0060437
## 6 1.0834e-02 5 0.068926 0.080456 0.0063328
## 7 2.8043e-03 6 0.058092 0.063315 0.0050906
## 8 1.8141e-03 7 0.055287 0.062338 0.0049869
## 9 1.1051e-03 8 0.053473 0.060418 0.0047727
## 10 1.0920e-03 9 0.052368 0.062983 0.0047775
## 11 9.2748e-04 10 0.051276 0.063577 0.0047207
## 12 8.7395e-04 11 0.050349 0.064454 0.0047415
## 13 7.3280e-04 12 0.049475 0.062832 0.0046879
## 14 6.9295e-04 13 0.048742 0.061265 0.0044770
## 15 5.1851e-04 14 0.048049 0.060592 0.0044667
## 16 4.8287e-04 15 0.047530 0.061750 0.0046693
## 17 4.4766e-04 16 0.047048 0.060989 0.0046260
## 18 3.4569e-04 17 0.046600 0.060715 0.0046423
## 19 3.3535e-04 19 0.045909 0.059946 0.0045127
## 20 2.8807e-04 21 0.045238 0.059518 0.0044270
## 21 2.6593e-04 22 0.044950 0.061288 0.0045550
## 22 2.6365e-04 23 0.044684 0.062011 0.0044994
## 23 2.3791e-04 26 0.043893 0.062136 0.0045163
## 24 2.3011e-04 27 0.043655 0.061602 0.0045158
## 25 2.1180e-04 28 0.043425 0.061930 0.0045091
## 26 2.0292e-04 32 0.042578 0.062281 0.0045272
## 27 2.0160e-04 33 0.042375 0.062110 0.0045407
## 28 1.5638e-04 34 0.042173 0.061218 0.0045345
## 29 1.4870e-04 35 0.042017 0.061657 0.0045515
## 30 1.4232e-04 37 0.041719 0.062186 0.0045653
## 31 1.2599e-04 38 0.041577 0.063198 0.0045779
## 32 1.1630e-04 44 0.040821 0.063602 0.0045950
## 33 1.1507e-04 46 0.040589 0.063828 0.0046170
## 34 1.0860e-04 51 0.039945 0.063560 0.0046194
## 35 1.0024e-04 53 0.039727 0.063911 0.0046288
## 36 9.2531e-05 55 0.039527 0.063532 0.0046141
## 37 9.1494e-05 56 0.039434 0.063608 0.0046228
## 38 8.7200e-05 57 0.039343 0.063609 0.0046225
## 39 7.8321e-05 58 0.039256 0.063432 0.0045895
## 40 6.8984e-05 59 0.039177 0.063393 0.0045909
## 41 6.2378e-05 63 0.038902 0.062620 0.0045545
## 42 5.8700e-05 64 0.038839 0.062565 0.0045596
## 43 5.1474e-05 66 0.038722 0.062489 0.0045574
## 44 4.7878e-05 67 0.038670 0.062799 0.0046450
## 45 3.6550e-05 68 0.038622 0.062623 0.0046454
## 46 3.3970e-05 69 0.038586 0.062946 0.0046422
## 47 3.0291e-05 70 0.038552 0.062781 0.0045819
## 48 2.8924e-05 71 0.038522 0.062817 0.0045806
## 49 2.8748e-05 72 0.038493 0.063042 0.0046023
## 50 2.5965e-05 74 0.038435 0.063378 0.0046609
## 51 2.5827e-05 75 0.038409 0.063422 0.0046599
## 52 2.3360e-05 76 0.038383 0.063472 0.0046588
## 53 2.1930e-05 77 0.038360 0.063552 0.0046579
## 54 2.1426e-05 78 0.038338 0.063731 0.0046602
## 55 2.1059e-05 79 0.038317 0.063719 0.0046595
## 56 2.0263e-05 80 0.038296 0.063731 0.0046592
## 57 1.9743e-05 81 0.038275 0.063821 0.0046582
## 58 1.6955e-05 83 0.038236 0.063758 0.0046559
## 59 1.6244e-05 85 0.038202 0.063647 0.0046548
## 60 1.6096e-05 86 0.038186 0.063676 0.0046541
## 61 1.5482e-05 87 0.038170 0.063702 0.0046550
## 62 1.4408e-05 88 0.038154 0.063651 0.0046557
## 63 1.4040e-05 89 0.038140 0.063643 0.0046559
## 64 1.3653e-05 90 0.038126 0.063647 0.0046575
## 65 1.2736e-05 91 0.038112 0.063711 0.0046582
## 66 1.2476e-05 92 0.038099 0.063783 0.0046583
## 67 1.1681e-05 93 0.038087 0.063826 0.0046580
## 68 1.1050e-05 94 0.038075 0.063947 0.0046630
## 69 9.8684e-06 95 0.038064 0.063988 0.0046634
## 70 7.3734e-06 96 0.038054 0.063990 0.0046578
## 71 7.2519e-06 97 0.038047 0.064086 0.0046576
## 72 6.3363e-06 98 0.038040 0.064088 0.0046579
## 73 5.3570e-06 99 0.038033 0.064172 0.0046570
## 74 4.3860e-06 101 0.038023 0.064264 0.0046614
## 75 3.7205e-06 102 0.038018 0.064293 0.0046613
## 76 3.3833e-06 103 0.038014 0.064272 0.0046601
## 77 2.8844e-06 104 0.038011 0.064286 0.0046599
## 78 2.7754e-06 105 0.038008 0.064285 0.0046598
## 79 2.6107e-06 106 0.038005 0.064292 0.0046597
## 80 6.4977e-07 107 0.038003 0.064298 0.0046619
## 81 5.4148e-07 108 0.038002 0.064315 0.0046607
## 82 3.2489e-07 109 0.038002 0.064335 0.0046609
## 83 2.6532e-07 110 0.038001 0.064328 0.0046610
## 84 1.6341e-07 111 0.038001 0.064328 0.0046609
## 85 1.6244e-07 112 0.038001 0.064326 0.0046609
## 86 1.6244e-07 113 0.038001 0.064322 0.0046610
## 87 7.0202e-08 114 0.038000 0.064326 0.0046610
## 88 3.7427e-08 115 0.038000 0.064320 0.0046611
## 89 2.0305e-08 116 0.038000 0.064316 0.0046609
## 90 1.1603e-08 117 0.038000 0.064316 0.0046609
## 91 0.0000e+00 118 0.038000 0.064316 0.0046609
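Before pruning, note one common way to choose a threshold from this table (shown only as an illustration, with placeholder variable names; the cp value of 0.02 used below is a round number picked by hand): take the CP value with the smallest cross-validated error (xerror).
# fit5$cptable holds the same table that printcp() displays
cp_table = fit5$cptable
best_cp = cp_table[which.min(cp_table[, "xerror"]), "CP"]
fit_min_xerror = prune(fit5, cp = best_cp)  # tree with the smallest cross-validated error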
# "prune" the tree so that splits (branches) must decrease overall lack of fit by 0.02
# in the original call, splits did not have to decrease lack of fit at all
fit6 <- prune(fit5, cp=0.02)
plot(fit6); text(fit6) # plot the tree
printcp(fit6)
##
## Regression tree:
## rpart(formula = height_in ~ age + sex, data = body, control = rpart.control(xval = nrow(body),
## minbucket = 2, cp = 0))
##
## Variables actually used in tree construction:
## [1] age sex
##
## Root node error: 64125/544 = 117.88
##
## n= 544
##
## CP nsplit rel error xerror xstd
## 1 0.777218 0 1.000000 1.003687 0.0693006
## 2 0.073546 1 0.222782 0.225099 0.0153530
## 3 0.043648 2 0.149236 0.151115 0.0096373
## 4 0.024864 3 0.105588 0.108602 0.0063549
## 5 0.020000 4 0.080723 0.083725 0.0060437
Let’s examine the MSE for all our models.
models3 = list(fit1, fit2, fit3, fit4, fit5, fit6)
names(models3) = c("rpart age", "rpart age sex", "lm age*sex", "lm age:sex", "cv.all", "cv.fin")
# Apply to all the models:
lapply(models3, mse, body_test, body_test$height_in)
## $`rpart age`
## [1] 11.75447
##
## $`rpart age sex`
## [1] 8.428809
##
## $`lm age*sex`
## [1] 59.9885
##
## $`lm age:sex`
## [1] 59.42185
##
## $cv.all
## [1] 4.336
##
## $cv.fin
## [1] 8.526544