The data this week comes from the survivoR R package by way of Daniel Oehm.
castaways <- read_csv("castaways.csv") %>%
  select(season, castaway, age, state, personality_type, order,
         total_votes_received, immunity_idols_won) %>%
  drop_na()
| variable             | class     | description                               |
|----------------------|-----------|-------------------------------------------|
| season               | integer   | Season number                             |
| castaway             | character | Castaway’s first name                     |
| age                  | double    | Age                                       |
| state                | character | Origin state                              |
| personality_type     | character | Myers-Briggs personality type             |
| order                | integer   | Order in which the castaway was voted out |
| total_votes_received | double    | Total votes received                      |
| immunity_idols_won   | double    | Immunity idols won                        |
For the next few weeks, the goal of my Tidy Tuesday analyses will be to get used to working with tidymodels. This week, we’ll try to predict the order in which the individual was cast out (order, roughly which week) using a few of the variables. Note that total votes received is a strong predictor that would be silly to use in a real model (it’s essentially information about the outcome), but here the idea is just to practice using tidymodels.
Note after the fact: lots of learning today, and I’m starting to get the hang of the tidymodels framework. However, the dataset doesn’t have very many variables. I look forward to next week’s data in hopes that I’ll be able to model with more variables.
We’ll want to first split the data into test and training sets.
set.seed(47)
data_split <- initial_split(castaways, prop = 2/3)
cast_train <- training(data_split)
cast_test <- testing(data_split)
We’ll also specify the recipe()
we want to use, and we’ll use the same recipe for both the linear model and the random forest. Much of the power of the recipe()
function comes from feature engineering: there are lots of ways to build / combine / transform features for improved prediction. We won’t do that here (though a sketch of what such steps could look like appears after the recipe summary below).
cast_rec <-
  recipe(order ~ ., data = cast_train) %>%
  # keep season, castaway, and state around as identifiers, not predictors
  update_role(season, castaway, state, new_role = "ID") #%>%
  #step_dummy(all_nominal(), -all_outcomes())

summary(cast_rec)
## # A tibble: 8 x 4
## variable type role source
## <chr> <chr> <chr> <chr>
## 1 season numeric ID original
## 2 castaway nominal ID original
## 3 age numeric predictor original
## 4 state nominal ID original
## 5 personality_type nominal predictor original
## 6 total_votes_received numeric predictor original
## 7 immunity_idols_won numeric predictor original
## 8 order numeric outcome original
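For reference (and not used in any of the models below), here’s a sketch of what a couple of extra feature-engineering steps could look like. It assumes a recipes version with the all_numeric_predictors() / all_nominal_predictors() selectors, and cast_rec_extra is a made-up name for illustration.
# Sketch only: a recipe with extra preprocessing steps (not used below)
cast_rec_extra <-
  recipe(order ~ ., data = cast_train) %>%
  update_role(season, castaway, state, new_role = "ID") %>%
  step_normalize(all_numeric_predictors()) %>% # center and scale numeric predictors
  step_dummy(all_nominal_predictors())         # indicator columns for personality_type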
Following the tidymodels vignette (which I’ll be using heavily to work through today’s analysis), we specify a linear regression model and set its engine to "lm".
lm_mod <-
  linear_reg() %>%
  set_engine("lm")
After we’ve specified the engine, we can then fit()
the model:
cast_lm <-
  workflow() %>%
  add_model(lm_mod) %>%
  add_recipe(cast_rec) %>%
  fit(data = cast_train)
cast_lm %>% tidy()
## # A tibble: 19 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) 10.3 1.34 7.67 9.56e-14
## 2 age 0.0148 0.0223 0.663 5.08e- 1
## 3 personality_typeENFP -4.24 1.22 -3.49 5.26e- 4
## 4 personality_typeENTJ -2.52 1.43 -1.76 7.91e- 2
## 5 personality_typeENTP -3.79 1.30 -2.92 3.65e- 3
## 6 personality_typeESFJ -3.67 1.39 -2.63 8.74e- 3
## 7 personality_typeESFP -1.76 1.27 -1.39 1.66e- 1
## 8 personality_typeESTJ -4.51 1.26 -3.59 3.67e- 4
## 9 personality_typeESTP -3.07 1.25 -2.47 1.40e- 2
## 10 personality_typeINFJ -2.47 1.49 -1.66 9.82e- 2
## 11 personality_typeINFP -2.73 1.24 -2.20 2.85e- 2
## 12 personality_typeINTJ -2.85 1.54 -1.85 6.50e- 2
## 13 personality_typeINTP -2.52 1.31 -1.93 5.42e- 2
## 14 personality_typeISFJ -2.32 1.43 -1.62 1.05e- 1
## 15 personality_typeISFP -3.56 1.26 -2.83 4.90e- 3
## 16 personality_typeISTJ -4.82 1.26 -3.81 1.54e- 4
## 17 personality_typeISTP -3.23 1.34 -2.42 1.58e- 2
## 18 total_votes_received 0.133 0.0578 2.30 2.20e- 2
## 19 immunity_idols_won 2.81 0.223 12.6 9.92e-32
Let’s look at the coefficients:
library(dotwhisker) # dwplot() comes from the dotwhisker package

tidy(cast_lm) %>%
  dwplot() +
  geom_vline(xintercept = 0)
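Beyond the coefficients, we could also peek at the overall fit statistics on the training data. A quick sketch (the fit element of the parsnip object is the underlying lm fit, so broom’s glance() applies):
# Sketch: overall training-fit statistics (r.squared, sigma, AIC, ...)
cast_lm %>%
  pull_workflow_fit() %>% # the parsnip model_fit inside the workflow
  pluck("fit") %>%        # the underlying lm object
  glance()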
The recipe is already set, so the only thing that needs to change is the model specification (and its engine). Note, however, that a random forest has tuning parameters (mtry, trees, and min_n), which a linear model doesn’t, so we need to specify them.
rf_mod1 <-
  rand_forest(mtry = 3,
              trees = 500,
              min_n = 5) %>%
  set_engine("ranger", importance = "permutation") %>%
  set_mode("regression")
After we’ve specified the engine, we can then fit()
the model:
cast_rf1 <-
  workflow() %>%
  add_model(rf_mod1) %>%
  add_recipe(cast_rec) %>%
  fit(data = cast_train)
cast_rf1
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 0 Recipe Steps
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~3, x), num.trees = ~500, min.node.size = min_rows(~5, x), importance = ~"permutation", num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1))
##
## Type: Regression
## Number of trees: 500
## Sample size: 494
## Number of independent variables: 4
## Mtry: 3
## Target node size: 5
## Variable importance mode: permutation
## Splitrule: variance
## OOB prediction error (MSE): 22.14426
## R squared (OOB): 0.2499416
library(vip)
cast_rf1 %>%
  pull_workflow_fit() %>%
  vip()
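The vip() plot is built from the permutation importance that ranger computed; if we want the raw numbers, they are stored on the ranger object itself. A sketch:
# Sketch: pull the raw permutation importance values out of the ranger fit
cast_rf1 %>%
  pull_workflow_fit() %>%
  pluck("fit") %>%              # the underlying ranger object
  pluck("variable.importance")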
This time, instead of fixed values, we set mtry and min_n in the random forest to tune().
rf_mod2 <-
  rand_forest(mtry = tune(),
              trees = 500,
              min_n = tune()) %>%
  set_engine("ranger", importance = "permutation") %>%
  set_mode("regression")
As before, we can drop the model into a workflow and call fit(). Note, though, that this model still has tune() placeholders for mtry and min_n, so the fit below isn’t meaningful on its own (you can see that in the odd Mtry and node-size values and the negative OOB R squared); we still need to tune those parameters with cross-validation.
cast_rf2 <-
  workflow() %>%
  add_model(rf_mod2) %>%
  add_recipe(cast_rec) %>%
  fit(data = cast_train)
cast_rf2
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 0 Recipe Steps
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~tune(), x), num.trees = ~500, min.node.size = min_rows(~tune(), x), importance = ~"permutation", num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1))
##
## Type: Regression
## Number of trees: 500
## Sample size: 494
## Number of independent variables: 4
## Mtry: 4
## Target node size: 494
## Variable importance mode: permutation
## Splitrule: variance
## OOB prediction error (MSE): 29.5882
## R squared (OOB): -0.002195362
Now we’ll create cross-validation folds from the training data, along with a regular grid of candidate values for mtry and min_n, so that we can evaluate the random forest on different subsets of the training data.
set.seed(4747)
cast_folds <- vfold_cv(cast_train, v = 4)

cast_grid <- grid_regular(mtry(range = c(1, 3)),
                          min_n(range = c(5, 10)),
                          levels = 3)
cast_grid
## # A tibble: 9 x 2
## mtry min_n
## <int> <int>
## 1 1 5
## 2 2 5
## 3 3 5
## 4 1 7
## 5 2 7
## 6 3 7
## 7 1 10
## 8 2 10
## 9 3 10
Now we tune the model, this time using the grid of parameter values.
doParallel::registerDoParallel()

set.seed(470)
tune_result <- tune_grid(
  cast_rf2,
  resamples = cast_folds,
  grid = cast_grid
)
tune_result
## # Tuning results
## # 4-fold cross-validation
## # A tibble: 4 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [370/124]> Fold1 <tibble [18 × 6]> <tibble [0 × 1]>
## 2 <split [370/124]> Fold2 <tibble [18 × 6]> <tibble [0 × 1]>
## 3 <split [371/123]> Fold3 <tibble [18 × 6]> <tibble [0 × 1]>
## 4 <split [371/123]> Fold4 <tibble [18 × 6]> <tibble [0 × 1]>
tune_result %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(x = mtry, y = mean, color = min_n)) +
  geom_line(alpha = 0.5, size = 1.5) +
  geom_point() +
  labs(y = "RMSE")
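Although I don’t use it in the test-set comparison below, the natural next step after tuning would be to select the best parameter values and plug them back into the workflow. A minimal sketch:
# Sketch (not used below): finalize the workflow with the best RMSE parameters
best_rf <- select_best(tune_result, metric = "rmse")

cast_rf_tuned <-
  cast_rf2 %>%
  finalize_workflow(best_rf)
# cast_rf_tuned could then be fit to the split with last_fit(data_split), like the models below.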
First we look at how well the linear model predicts the test data.
cast_lm_final <-
  cast_lm %>%
  last_fit(data_split)

cast_lm_final %>%
  collect_metrics()
## # A tibble: 2 x 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 5.04 Preprocessor1_Model1
## 2 rsq standard 0.210 Preprocessor1_Model1
Similarly, we can look at the random forest fit. The random forest did slightly better on the test data.
cast_rf_final <-
  cast_rf1 %>%
  last_fit(data_split)

cast_rf_final %>%
  collect_metrics()
## # A tibble: 2 x 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 5.00 Preprocessor1_Model1
## 2 rsq standard 0.233 Preprocessor1_Model1
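If we wanted more than the summary metrics, we could compare the test-set predictions to the observed boot order for both final fits. A sketch:
# Sketch: predicted vs. observed order on the test set for both models
bind_rows(
  collect_predictions(cast_lm_final) %>% mutate(model = "linear model"),
  collect_predictions(cast_rf_final) %>% mutate(model = "random forest")
) %>%
  ggplot(aes(x = order, y = .pred, color = model)) +
  geom_abline(lty = 2) + # perfect-prediction line
  geom_point(alpha = 0.5) +
  labs(x = "observed order", y = "predicted order")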
praise()
## [1] "You are fabulous!"