Recursive Partitioning

In this lab, we will apply recursive partitioning (tree) models, identify and assess when a model might be overfitting, and compare the performance of different models.

We will start with simulated data.

# The rpart package will allow us to create tree models.
library(rpart)
set.seed(0)

The data we simulate come from the piecewise constant model \[ y = \begin{cases} 4 & x < 50\\ 2 & x \ge 50 \end{cases} \] plus standard normal noise.

n = 100
x = seq(n)
true_model = function(x){
    # Piecewise constant
    out = rep(2, length.out = length(x))
    out[x < 50] = 4
    out
}
y = true_model(x) + rnorm(n)

d = data.frame(x = x, y = y)

Building Models

Obviously the true model isn’t linear, but let’s see what happens when we create a linear regression model with these data:

fit_lm = lm(y ~ x, data = d)

plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_lm), col = "blue")

Now let’s try recursive partitioning. It is worth reading the help file (?rpart) before we get started.

# Use the defaults:
fit_rpart = rpart(y ~ x, data = d)

plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_rpart), col = "red")

plot(fit_rpart)
text(fit_rpart)

Did we capture the truth? It’s not terrible, but it’s not good either. To improve it, rpart has “tuning parameters” (such as minsplit, minbucket, and cp; see ?rpart.control) that control how large the tree is allowed to grow.
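As a quick illustration of where these tuning parameters live, the sketch below prints the defaults returned by rpart.control() and checks that passing minsplit and cp directly to rpart() builds the same tree as passing them through a control object (rpart() forwards such extra arguments to rpart.control()). It re-simulates data the same way as above; the names defaults, fit_dots, and fit_ctrl are only for this example.

```r
library(rpart)

# Default tuning parameters (documented in ?rpart.control)
defaults = rpart.control()
defaults$minsplit   # smallest node rpart will attempt to split
defaults$minbucket  # smallest allowed leaf; defaults to round(minsplit / 3)
defaults$cp         # splits that do not improve the overall fit by this factor are not attempted

# Re-simulate data in the same way as the lab
set.seed(0)
x = seq(100)
d = data.frame(x = x, y = ifelse(x < 50, 4, 2) + rnorm(100))

# Two equivalent ways to set the same tuning parameters
fit_dots = rpart(y ~ x, data = d, minsplit = 4, cp = 0.0001)
fit_ctrl = rpart(y ~ x, data = d,
                 control = rpart.control(minsplit = 4, cp = 0.0001))
all.equal(predict(fit_dots), predict(fit_ctrl))
```

The control-object form is convenient when you want to reuse the same settings across several fits.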

# Allow very small nodes to split (minsplit = 4) and keep almost any split (cp = 0.0001)
fit_overfit = rpart(y ~ x, data = d, minsplit = 4, cp = 0.0001)

plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_rpart), col = "red")
lines(x, predict(fit_overfit), col = "blue")

This is doing a really good job fitting the data, but it’s a terrible fit to the underlying model (and would therefore be a poor fit if we tried to use it to predict \(y\) for new values of \(x\))! This is a case of overfitting.
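One way to see overfitting in numbers is to compare how many leaves each tree has and how its error on the data used to grow it compares with its error on a fresh sample at the same x values. This is a sketch on re-simulated data (d_new, n_leaves, train_mse, and new_mse are illustrative names, not part of the lab); a large gap between the training error and the fresh-data error is the usual signature of overfitting.

```r
library(rpart)

# Re-simulate the setup: noisy data plus a fresh sample at the same x values
set.seed(0)
x = seq(100)
truth = ifelse(x < 50, 4, 2)
d = data.frame(x = x, y = truth + rnorm(100))
d_new = data.frame(x = x, y = truth + rnorm(100))

fits = list(
  default = rpart(y ~ x, data = d),
  overfit = rpart(y ~ x, data = d, minsplit = 4, cp = 0.0001)
)

# Number of leaves (terminal nodes) in each tree
n_leaves = sapply(fits, function(f) sum(f$frame$var == "<leaf>"))
# Error on the data used to grow each tree vs. error on the fresh sample
train_mse = sapply(fits, function(f) mean((predict(f, d) - d$y)^2))
new_mse = sapply(fits, function(f) mean((predict(f, d_new) - d_new$y)^2))
rbind(n_leaves, train_mse, new_mse)
```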

We want our models to be parsimonious, which is roughly like being “conservative”: we want the simplest model that reasonably explains the data, striking a balance between simplicity and usefulness.

# Require every leaf (terminal node) to contain at least 30 observations
fit_parsimonious = rpart(y ~ x, data = d, minbucket = 30)

plot(x, y)
lines(x, true_model(x), lty = 2)
lines(x, predict(fit_rpart), col = "red")
lines(x, predict(fit_overfit), col = "blue")
lines(x, predict(fit_parsimonious), col = "green")
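To check what minbucket = 30 does, we can look at the number of observations in each leaf, which is stored in the tree's frame component. The sketch below re-simulates the data the same way as above; leaf_sizes is an illustrative name. Printing the tree shows where it splits x, which you can compare with the true break at 50.

```r
library(rpart)

set.seed(0)
x = seq(100)
d = data.frame(x = x, y = ifelse(x < 50, 4, 2) + rnorm(100))

fit_parsimonious = rpart(y ~ x, data = d, minbucket = 30)

# Number of observations in each leaf; none may be smaller than minbucket = 30
leaf_sizes = fit_parsimonious$frame$n[fit_parsimonious$frame$var == "<leaf>"]
leaf_sizes

# The printed tree shows the split point(s) on x
print(fit_parsimonious)
```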

Evaluating Models

Next, we generate new \(y\) data from the true model (and add some random variation).

d2 = data.frame(x, truth = true_model(x))
d2$y = d2$truth + rnorm(n)

One way to measure how well each model predicts new data is to compute its mean squared error (MSE) on that data. For the original model,

error = predict(fit_rpart, d2) - d2$y
mean(error^2)
## [1] 1.044938
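For reference, the quantity computed above is the mean squared error. Writing \(\hat{y}_i\) for a model's prediction at \(x_i\) and \(y_i\) for the corresponding new observation, over \(n\) points:

\[ \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left(\hat{y}_i - y_i\right)^2 \]

Because the new \(y_i\) include standard normal noise (variance 1), even the true model has an expected MSE of 1 on these data, so values near 1 are about as good as any model can do.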

Does the best model minimize or maximize MSE?

Next, we will find the MSE for all of the models - the regression, the original rpart model, the overfitted model, and the idealized parsimonious model.

models = list(fit_lm, fit_rpart, fit_overfit, fit_parsimonious)
names(models) = c("linear", "default", "overfit", "parsimonious")

mse = function(fit, data = d2, out = data$y){
  error = predict(fit, data) - out
  mean(error^2)   
}

mse(fit_rpart)
## [1] 1.044938
# Apply to all the models:
lapply(models, mse)
## $linear
## [1] 1.129974
## 
## $default
## [1] 1.044938
## 
## $overfit
## [1] 1.407317
## 
## $parsimonious
## [1] 0.9540097

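Unsurprisingly, the parsimonious model has the lowest MSE on the new data, so it predicts best! The overfit tree does worst of all, even worse than the straight line.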

In this case, we generated new data and tested the performance of the models on it. But what if we can’t get new data? We can split a dataset into “training data” and “testing data”: we use some of the data (say 80%) to build the model and set the remaining 20% aside to test model performance.

# External data:
body = read.csv("https://raw.githubusercontent.com/clarkfitzg/21fall_stat128/main/Howell_us.csv")
head(body)
##   age height_in weight_lb    sex
## 1  63     59.75 105.43729   male
## 2  63     55.00  80.43734 female
## 3  65     53.75  70.24986 female
## 4  41     61.75 116.93727   male
## 5  51     57.25  90.99982 female
## 6  35     64.50 138.87472   male
# Randomly pick data to use for testing (validation)
n = nrow(body)
val = sample(n, size = round(0.2 * n)) # select 20% of the data
body_test = body[val,] # test data
body_train = body[-val, ] # training data
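If you need to split data more than once (for example, in the Titanic exercise), it may help to wrap these steps in a function. train_test_split below is a hypothetical helper written for this lab, not part of base R or rpart, and the toy data frame is made up purely to exercise it.

```r
# Hypothetical helper: randomly hold out a fraction of rows for testing
train_test_split = function(data, test_frac = 0.2) {
  n = nrow(data)
  # Logical index: TRUE for rows held out for testing
  is_test = seq_len(n) %in% sample(n, size = round(test_frac * n))
  list(train = data[!is_test, , drop = FALSE],
       test = data[is_test, , drop = FALSE])
}

# Usage on a small made-up data frame
set.seed(1)
toy = data.frame(id = 1:50, value = rnorm(50))
parts = train_test_split(toy)
nrow(parts$train)  # 40
nrow(parts$test)   # 10
```

Using a logical index rather than body[-val, ] avoids a subtle R pitfall: if the held-out index vector were ever empty, indexing with its negation (-integer(0)) would select no rows at all instead of all of them.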

Next we will fit the model with the training data, using age to predict height.

# Sort data so lines function works later
body_train = body_train[order(body_train$age), ]

# Let's model height as a function of age.
fit1 = rpart(height_in ~ age, data = body_train)

with(body_train, plot(age, height_in))
lines(body_train$age, predict(fit1), col = "blue", lwd = 2)

We can include additional input variables in the same way as in an lm or glm call:

fit2 = rpart(height_in ~ age + sex, data = body_train)

male = body_train$sex == "male"

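# Set up the axes using all the training data so no points are clipped
with(body_train, plot(age, height_in, type = "n"))
with(body_train[male, ], points(age, height_in, col = "red"))
lines(body_train$age[male], predict(fit2)[male], col = "red", lwd = 2)
with(body_train[!male, ], points(age, height_in, col = "blue"))
lines(body_train$age[!male], predict(fit2)[!male], col = "blue", lwd = 2)
legend("bottomright", legend = c("male", "female"), col = c("red", "blue"), lwd = 2)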

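The next two models are linear regression models with interaction terms, which let the effect of age on height differ between the sexes. Trees do not need interaction terms in the formula (we simply add predictors with +); they capture interactions automatically, because a split on one variable (such as sex) can be followed by different splits on another variable (such as age) within each branch.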

# age*sex expands to age + sex + age:sex (both main effects plus their interaction)
fit3 = lm(height_in ~ age*sex, data = body_train)

# age:sex includes only the interaction: a common intercept and a separate age slope for each sex
fit4 = lm(height_in ~ age:sex, data = body_train)
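To see exactly which terms each formula creates, we can look at the columns of the design matrix that lm builds. The tiny data frame below is made up purely for illustration; with a factor sex whose baseline level is female, age*sex gives separate intercepts and slopes for each sex, while age:sex keeps one intercept and gives each sex its own age slope.

```r
# A tiny made-up data set, only to show which columns each formula creates
toy = data.frame(age = c(5, 20, 40, 60),
                 sex = factor(c("female", "male", "female", "male")))

# Main effects plus interaction
colnames(model.matrix(~ age * sex, data = toy))
# "(Intercept)" "age" "sexmale" "age:sexmale"

# Interaction only: one age slope per sex, shared intercept
colnames(model.matrix(~ age:sex, data = toy))
# "(Intercept)" "age:sexfemale" "age:sexmale"
```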

Which model is best? Let’s examine how each performs using the test data that we held back from model creation.

models2 = list(fit1, fit2, fit3, fit4)
names(models2) = c("rpart age", "rpart age sex", "lm age*sex", "lm age:sex")

# Apply to all the models:
lapply(models2, mse, body_test, body_test$height_in)
## $`rpart age`
## [1] 11.75447
## 
## $`rpart age sex`
## [1] 8.428809
## 
## $`lm age*sex`
## [1] 59.9885
## 
## $`lm age:sex`
## [1] 59.42185

On Your Own

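Use the Titanic dataset (a slightly simplified version of the other Titanic data we worked with) to build several models that predict survival. Which model is best? Note that data(Titanic) loads a four-way table of counts (Class, Sex, Age, Survived) rather than one row per passenger.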

To do:

  • Create test and train datasets.
  • Create four models: two logistic regression models (see Lab 12) and two rpart models (experiment with the tuning parameters).
  • Use your test data to evaluate the models and compare their mean squared error.
data(Titanic)
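Because Titanic is a contingency table, one way to get a data frame with one row per passenger is to convert the table to a data frame of counts and repeat each row by its count. This is a sketch; the names titanic and survived01 are our own choices.

```r
data(Titanic)  # 4-way table: Class x Sex x Age x Survived, with counts

# One row per cell of the table; the count is in the Freq column
tt = as.data.frame(Titanic)
# Repeat each row Freq times to get one row per passenger
titanic = tt[rep(seq_len(nrow(tt)), times = tt$Freq),
             c("Class", "Sex", "Age", "Survived")]
rownames(titanic) = NULL
# 0/1 outcome, so squared errors can be computed on predicted probabilities
titanic$survived01 = as.numeric(titanic$Survived == "Yes")
nrow(titanic)
```

When computing MSE for a logistic regression, note that predict() on a glm returns log-odds by default; predict(fit, newdata, type = "response") returns probabilities, which are what you want to compare with survived01. For an rpart classification tree fit to the factor Survived, predict(fit, newdata, type = "prob")[, "Yes"] gives the predicted probability of survival.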

Cross Validation

An approach called cross validation generalizes the idea of splitting the data into a test and a training dataset:

  1. Divide the data at random into \(s\) groups, \(g_1\), \(g_2\), \(\dots\), \(g_s\), each of size about \(n/s\).
  2. For each \(i = 1, \dots, s\), fit a model on all of the data except the data in group \(g_i\).
  3. Determine the error for each group, using \(g_i\) as the testing data for the \(i\)th model.
  4. The model error is the mean of the \(s\) group errors.
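The four steps above can be sketched directly in R. cv_mse is a hypothetical helper written for this lab, not part of rpart; it returns the cross-validated MSE itself, whereas rpart reports cross-validated error relative to the root node error. The default control sets xval = 0 so that each inner fit skips rpart's own internal cross-validation.

```r
library(rpart)

# Hypothetical helper: s-fold cross-validation estimate of MSE for an rpart tree
cv_mse = function(formula, data, s = 10, control = rpart.control(xval = 0)) {
  n = nrow(data)
  response = all.vars(formula)[1]
  # Step 1: randomly assign each row to one of s groups of (nearly) equal size
  group = sample(rep(seq_len(s), length.out = n))
  group_error = numeric(s)
  for (i in seq_len(s)) {
    # Step 2: fit on everything except group i
    train = data[group != i, , drop = FALSE]
    test = data[group == i, , drop = FALSE]
    fit = rpart(formula, data = train, control = control)
    # Step 3: error on group i, which this fit never saw
    group_error[i] = mean((predict(fit, test) - test[[response]])^2)
  }
  # Step 4: average the group errors
  mean(group_error)
}

set.seed(0)
x = seq(100)
d = data.frame(x = x, y = ifelse(x < 50, 4, 2) + rnorm(100))
est_10 = cv_mse(y ~ x, d, s = 10)          # 10-fold
est_loo = cv_mse(y ~ x, d, s = nrow(d))    # leave-one-out (s = n)
c(ten_fold = est_10, leave_one_out = est_loo)
```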

This allows us to use more of our data for the training sample. For data sets that aren’t too large, we often let \(s=n\). This is called “leave-one-out” cross validation.

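Conveniently, cross-validation is built into the rpart package. In the call below, xval = nrow(body) requests leave-one-out cross-validation, and minbucket = 2 with cp = 0 lets rpart grow a very large tree using age and sex to predict height_in. printcp() then lists the sequence of subtrees with their training error (rel error) and cross-validated error (xerror), both relative to the root node error, along with xstd, an estimate of the standard error of xerror.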

fit5 <- rpart(height_in ~ age + sex, data = body, 
              control = rpart.control(xval = nrow(body), minbucket = 2, cp = 0))
printcp(fit5) 
## 
## Regression tree:
## rpart(formula = height_in ~ age + sex, data = body, control = rpart.control(xval = nrow(body), 
##     minbucket = 2, cp = 0))
## 
## Variables actually used in tree construction:
## [1] age sex
## 
## Root node error: 64125/544 = 117.88
## 
## n= 544 
## 
##            CP nsplit rel error   xerror      xstd
## 1  7.7722e-01      0  1.000000 1.003687 0.0693006
## 2  7.3546e-02      1  0.222782 0.225099 0.0153530
## 3  4.3648e-02      2  0.149236 0.151115 0.0096373
## 4  2.4864e-02      3  0.105588 0.108602 0.0063549
## 5  1.1797e-02      4  0.080723 0.083725 0.0060437
## 6  1.0834e-02      5  0.068926 0.080456 0.0063328
## 7  2.8043e-03      6  0.058092 0.063315 0.0050906
## 8  1.8141e-03      7  0.055287 0.062338 0.0049869
## 9  1.1051e-03      8  0.053473 0.060418 0.0047727
## 10 1.0920e-03      9  0.052368 0.062983 0.0047775
## 11 9.2748e-04     10  0.051276 0.063577 0.0047207
## 12 8.7395e-04     11  0.050349 0.064454 0.0047415
## 13 7.3280e-04     12  0.049475 0.062832 0.0046879
## 14 6.9295e-04     13  0.048742 0.061265 0.0044770
## 15 5.1851e-04     14  0.048049 0.060592 0.0044667
## 16 4.8287e-04     15  0.047530 0.061750 0.0046693
## 17 4.4766e-04     16  0.047048 0.060989 0.0046260
## 18 3.4569e-04     17  0.046600 0.060715 0.0046423
## 19 3.3535e-04     19  0.045909 0.059946 0.0045127
## 20 2.8807e-04     21  0.045238 0.059518 0.0044270
## 21 2.6593e-04     22  0.044950 0.061288 0.0045550
## 22 2.6365e-04     23  0.044684 0.062011 0.0044994
## 23 2.3791e-04     26  0.043893 0.062136 0.0045163
## 24 2.3011e-04     27  0.043655 0.061602 0.0045158
## 25 2.1180e-04     28  0.043425 0.061930 0.0045091
## 26 2.0292e-04     32  0.042578 0.062281 0.0045272
## 27 2.0160e-04     33  0.042375 0.062110 0.0045407
## 28 1.5638e-04     34  0.042173 0.061218 0.0045345
## 29 1.4870e-04     35  0.042017 0.061657 0.0045515
## 30 1.4232e-04     37  0.041719 0.062186 0.0045653
## 31 1.2599e-04     38  0.041577 0.063198 0.0045779
## 32 1.1630e-04     44  0.040821 0.063602 0.0045950
## 33 1.1507e-04     46  0.040589 0.063828 0.0046170
## 34 1.0860e-04     51  0.039945 0.063560 0.0046194
## 35 1.0024e-04     53  0.039727 0.063911 0.0046288
## 36 9.2531e-05     55  0.039527 0.063532 0.0046141
## 37 9.1494e-05     56  0.039434 0.063608 0.0046228
## 38 8.7200e-05     57  0.039343 0.063609 0.0046225
## 39 7.8321e-05     58  0.039256 0.063432 0.0045895
## 40 6.8984e-05     59  0.039177 0.063393 0.0045909
## 41 6.2378e-05     63  0.038902 0.062620 0.0045545
## 42 5.8700e-05     64  0.038839 0.062565 0.0045596
## 43 5.1474e-05     66  0.038722 0.062489 0.0045574
## 44 4.7878e-05     67  0.038670 0.062799 0.0046450
## 45 3.6550e-05     68  0.038622 0.062623 0.0046454
## 46 3.3970e-05     69  0.038586 0.062946 0.0046422
## 47 3.0291e-05     70  0.038552 0.062781 0.0045819
## 48 2.8924e-05     71  0.038522 0.062817 0.0045806
## 49 2.8748e-05     72  0.038493 0.063042 0.0046023
## 50 2.5965e-05     74  0.038435 0.063378 0.0046609
## 51 2.5827e-05     75  0.038409 0.063422 0.0046599
## 52 2.3360e-05     76  0.038383 0.063472 0.0046588
## 53 2.1930e-05     77  0.038360 0.063552 0.0046579
## 54 2.1426e-05     78  0.038338 0.063731 0.0046602
## 55 2.1059e-05     79  0.038317 0.063719 0.0046595
## 56 2.0263e-05     80  0.038296 0.063731 0.0046592
## 57 1.9743e-05     81  0.038275 0.063821 0.0046582
## 58 1.6955e-05     83  0.038236 0.063758 0.0046559
## 59 1.6244e-05     85  0.038202 0.063647 0.0046548
## 60 1.6096e-05     86  0.038186 0.063676 0.0046541
## 61 1.5482e-05     87  0.038170 0.063702 0.0046550
## 62 1.4408e-05     88  0.038154 0.063651 0.0046557
## 63 1.4040e-05     89  0.038140 0.063643 0.0046559
## 64 1.3653e-05     90  0.038126 0.063647 0.0046575
## 65 1.2736e-05     91  0.038112 0.063711 0.0046582
## 66 1.2476e-05     92  0.038099 0.063783 0.0046583
## 67 1.1681e-05     93  0.038087 0.063826 0.0046580
## 68 1.1050e-05     94  0.038075 0.063947 0.0046630
## 69 9.8684e-06     95  0.038064 0.063988 0.0046634
## 70 7.3734e-06     96  0.038054 0.063990 0.0046578
## 71 7.2519e-06     97  0.038047 0.064086 0.0046576
## 72 6.3363e-06     98  0.038040 0.064088 0.0046579
## 73 5.3570e-06     99  0.038033 0.064172 0.0046570
## 74 4.3860e-06    101  0.038023 0.064264 0.0046614
## 75 3.7205e-06    102  0.038018 0.064293 0.0046613
## 76 3.3833e-06    103  0.038014 0.064272 0.0046601
## 77 2.8844e-06    104  0.038011 0.064286 0.0046599
## 78 2.7754e-06    105  0.038008 0.064285 0.0046598
## 79 2.6107e-06    106  0.038005 0.064292 0.0046597
## 80 6.4977e-07    107  0.038003 0.064298 0.0046619
## 81 5.4148e-07    108  0.038002 0.064315 0.0046607
## 82 3.2489e-07    109  0.038002 0.064335 0.0046609
## 83 2.6532e-07    110  0.038001 0.064328 0.0046610
## 84 1.6341e-07    111  0.038001 0.064328 0.0046609
## 85 1.6244e-07    112  0.038001 0.064326 0.0046609
## 86 1.6244e-07    113  0.038001 0.064322 0.0046610
## 87 7.0202e-08    114  0.038000 0.064326 0.0046610
## 88 3.7427e-08    115  0.038000 0.064320 0.0046611
## 89 2.0305e-08    116  0.038000 0.064316 0.0046609
## 90 1.1603e-08    117  0.038000 0.064316 0.0046609
## 91 0.0000e+00    118  0.038000 0.064316 0.0046609
# "prune" the tree: keep only splits that improve the overall fit (R-squared) by at least cp = 0.02
# in the original call, cp = 0 meant splits did not have to improve the fit at all
fit6 <- prune(fit5, cp=0.02) 
plot(fit6); text(fit6) # plot the tree

printcp(fit6)
## 
## Regression tree:
## rpart(formula = height_in ~ age + sex, data = body, control = rpart.control(xval = nrow(body), 
##     minbucket = 2, cp = 0))
## 
## Variables actually used in tree construction:
## [1] age sex
## 
## Root node error: 64125/544 = 117.88
## 
## n= 544 
## 
##         CP nsplit rel error   xerror      xstd
## 1 0.777218      0  1.000000 1.003687 0.0693006
## 2 0.073546      1  0.222782 0.225099 0.0153530
## 3 0.043648      2  0.149236 0.151115 0.0096373
## 4 0.024864      3  0.105588 0.108602 0.0063549
## 5 0.020000      4  0.080723 0.083725 0.0060437
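The value cp = 0.02 above is a judgment call. A common alternative is to choose cp from the cross-validated error in the tree's cptable: either take the row with the smallest xerror, or use the "one standard error" rule and take the smallest tree whose xerror is within one xstd of that minimum. The sketch below applies both rules to simulated data (not to body), so its numbers will differ from the tables above; big, tab, cp_min, and cp_1se are illustrative names.

```r
library(rpart)

set.seed(0)
x = seq(100)
d = data.frame(x = x, y = ifelse(x < 50, 4, 2) + rnorm(100))

# Grow a large tree; 10-fold cross-validation results are stored in cptable
big = rpart(y ~ x, data = d, control = rpart.control(cp = 0, minbucket = 2, xval = 10))
tab = big$cptable  # columns: CP, nsplit, rel error, xerror, xstd

# Rule 1: the subtree with the smallest cross-validated error
best = which.min(tab[, "xerror"])
cp_min = tab[best, "CP"]

# Rule 2 (one-SE rule): the first, i.e. smallest, subtree whose xerror is
# within one standard error of the minimum
threshold = tab[best, "xerror"] + tab[best, "xstd"]
cp_1se = tab[which(tab[, "xerror"] <= threshold)[1], "CP"]

pruned_min = prune(big, cp = cp_min)
pruned_1se = prune(big, cp = cp_1se)
c(min_xerror = max(pruned_min$cptable[, "nsplit"]),
  one_se = max(pruned_1se$cptable[, "nsplit"]))
```

Because cptable rows are ordered from the smallest tree to the largest, the one-SE rule never picks a larger tree than the minimum-xerror rule, which is the point of using it as a more parsimonious choice.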

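Let’s examine the MSE for all our models. One caution: fit5 and fit6 were fit to the full body data, which includes the rows in body_test, so their test-set MSE is optimistic. This matters most for cv.all, a nearly fully grown tree (cp = 0, minbucket = 2) that can closely reproduce the test rows it was trained on. For a fair comparison, refit these trees using body_train.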

models3 = list(fit1, fit2, fit3, fit4, fit5, fit6)
names(models3) = c("rpart age", "rpart age sex", "lm age*sex", "lm age:sex", "cv.all", "cv.fin")

# Apply to all the models:
lapply(models3, mse, body_test, body_test$height_in)
## $`rpart age`
## [1] 11.75447
## 
## $`rpart age sex`
## [1] 8.428809
## 
## $`lm age*sex`
## [1] 59.9885
## 
## $`lm age:sex`
## [1] 59.42185
## 
## $cv.all
## [1] 4.336
## 
## $cv.fin
## [1] 8.526544