Predicting pileups: Using ML to predict Chicago crash types

A comparative model analysis of the factors that lead to severe traffic collisions in Chicago, IL.

Published

December 18, 2021

Wouldn’t it be great to know the chances of being injured or having your car totaled in a car accident before ever being involved in the accident? If you live in a less densely populated area, this question might not be very interesting. But if you live in a big city, like Chicago, you might be a little more concerned about this. By identifying factors that predict these bad accidents, we might be able to develop low cost interventions or redesign the environment to reduce the frequency of these types of accidents, which can translate to lives and money saved. So in this post, we’ll use machine learning to predict which traffic crashes in Chicago, IL, result in injuries and/or the vehicle being towed based on situational features (e.g., posted speed limit, lighting conditions, road surface condition, etc.) that are likely known before the crash occurred.

Traffic crash data were imported from the Chicago Data Portal using the RSocrata package. The City of Chicago also has data sets related to the vehicles and people involved in these crashes, but to keep things simple for now, let’s focus on a few situational factors related to the crash, all of which can be found in the main data set.

These data show information about each traffic crash on city streets within the City of Chicago and under the jurisdiction of the Chicago Police Department. Many of the variables (e.g. street conditions, weather conditions, etc.) are recorded by the reporting officer and are based on the best available information at the time, but according to the Chicago Data Portal, these data may disagree with other posted information. As such, data are subject to change based on new information.

Import data

Code
set.seed(20211218)
library(RSocrata)
library(tidyverse)
library(lubridate)
library(janitor)
library(tidymodels)
library(vip)
library(skimr, include.only = "skim")

# data pulled at time of post; new cases likely added to data portal since then
crashes <- read.socrata(
  "https://data.cityofchicago.org/resource/85ca-t3if.csv", # url of data set
  app_token = Sys.getenv("rsocrata_token")                 # my personal creds
) %>%
  clean_names() %>%
  select(
    crash_type, crash_date, posted_speed_limit, traffic_control_device,
    device_condition, weather_condition, lighting_condition, first_crash_type,
    trafficway_type, alignment, roadway_surface_cond, road_defect, prim_contributory_cause
  ) %>%
  mutate(
    crash_date = ymd_hms(crash_date),
    date = as_date(crash_date)
  ) %>%
  select(!crash_date)

The goal here is to build a model that predicts whether a crash will result in someone being injured and/or the vehicle being towed using the crash_type variable in the crashes data set.

From the variables listed in the Traffic Crash data set, let’s select some that are likely to have been known before the crash occurred. For example, lighting_condition is a variable that records the light condition at the time of the crash. This situational factor is likely to be known (e.g., observed by the driver, reported by others prior to the crash, etc.) before an accident occurs. At the very least, this information has a higher chance of being known prior to a crash compared to something like injuries_total, a variable that records the total number of people who sustained an injury as a result of the accident. Since this is a consequence of a crash, this information can only be known after the crash occurs.

With this logic, let’s focus on the following variables from the initial Traffic Crash data set:

| Column name | Description | Type |
|---|---|---|
| crash_date | Date and time of crash as entered by the reporting officer | Date/Time |
| posted_speed_limit | Posted speed limit, as determined by reporting officer | Numeric |
| traffic_control_device | Traffic control device present at crash location, as determined by reporting officer | Factor |
| device_condition | Condition of traffic control device, as determined by reporting officer | Factor |
| weather_condition | Weather condition at time of crash, as determined by reporting officer | Factor |
| lighting_condition | Light condition at time of crash, as determined by reporting officer | Factor |
| first_crash_type | Type of first collision in crash | Factor |
| trafficway_type | Trafficway type, as determined by reporting officer | Factor |
| alignment | Street alignment at crash location, as determined by reporting officer | Factor |
| roadway_surface_cond | Road surface condition, as determined by reporting officer | Factor |
| road_defect | Road defects, as determined by reporting officer | Factor |
| prim_contributory_cause | The factor which was most significant in causing the crash, as determined by officer judgment | Factor |

Data exploration & cleaning

Now that we have some idea of what we’ll be looking at in our model, let’s get some impressions of the data and see if anything needs to be cleaned up before going further.

Code
glimpse(crashes)

The first thing to note is that there are 571,426 observations and 13 columns in the crashes data set. Other than some potential capitalization inconsistencies in the strings, there don’t seem to be any obvious issues with the way the data are formatted. We’ll come back to this, but for now, let’s check for missing data.

Code
colSums(is.na(crashes))

Great, no missing data! This will simplify data preparation later on. Next, we should examine the frequency counts of the variables we’ll use to predict the outcome class.

Code
var_freq <- function(df) {
  map(names(df), ~ count(df, .data[[.x]]))[1:12]
}

var_freq(crashes)

Looks like there are some issues that need to be addressed here! First, the responses that are coded for some variables don’t make sense. For example, posted_speed_limit has 6967 recorded observations for a posted speed limit of 0 mph. Clearly, this is not a legitimate posted speed limit within the City of Chicago (as much as it may feel like it on the Dan Ryan at 5pm). There are some other odd speed limits recorded for this variable as well.
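
To see how widespread these odd values are, a quick tally can help. This is only an exploratory sketch; the cutoffs (0 mph, anything above 70 mph, or values that aren’t multiples of 5) are my own assumption about what counts as a plausible posted limit, not something defined in the data set.

Code
# exploratory check: tally posted speed limits that look implausible
crashes %>%
  filter(
    posted_speed_limit == 0 |
      posted_speed_limit > 70 |
      posted_speed_limit %% 5 != 0
  ) %>%
  count(posted_speed_limit, sort = TRUE)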

Another issue these frequency counts reveal is that many of the levels within a variable could be grouped together. For example, prim_contributory_cause makes a distinction between disregarding road markings, stop signs, traffic signals, yield signs, and other traffic signs. Instead, these levels could be grouped into a single level called “disregarding signs/markings”.

So, let’s address each of these problems by cleaning up the levels for each variable. And while we’re at it, let’s clean up the format of all of the strings so there is a consistent style (i.e., lower case).

Code
crashes <- crashes %>%
  mutate(
    across(where(is.character), ~ str_to_lower(.)),
    traffic_control_device = case_when(
      traffic_control_device == "railroad crossing gate" |
       traffic_control_device == "other railroad crossing" |
        traffic_control_device == "rr crossing sign" ~ "rr crossing",
      traffic_control_device == "bicycle crossing sign" |
        traffic_control_device == "pedestrian crossing sign" |
        traffic_control_device == "school zone" ~ "pedestrian signs",
      traffic_control_device == "flashing control signal" ~ "stop sign/flasher",
      traffic_control_device == "no passing" ~ "other warning sign",
      TRUE ~ traffic_control_device
    ),
    device_condition = case_when(
      device_condition == "missing" ~ "no controls",
      device_condition == "worn reflective material" |
        device_condition == "not functioning" ~ "functioning improperly",
      TRUE ~ device_condition
    ),
    weather_condition = case_when(
      weather_condition == "blowing sand, soil, dirt" |
        weather_condition == "severe cross wind gate" |
        weather_condition == "blowing snow" ~ "blowing debris",
      weather_condition == "freezing rain/drizzle" |
        weather_condition == "sleet/hail" ~ "sleet/hail/freezing rain",
      TRUE ~ weather_condition
    ),
    prim_contributory_cause = case_when(
      prim_contributory_cause == "disregarding other traffic signs" |
        prim_contributory_cause == "disregarding road markings" |
        prim_contributory_cause == "disregarding stop sign" |
        prim_contributory_cause == "disregarding traffic signals" |
        prim_contributory_cause == "disregarding yield sign" |
        prim_contributory_cause == "passing stopped school bus" ~ "disregarding signs/markings",
      prim_contributory_cause == "distraction - from inside vehicle" |
        prim_contributory_cause == "distraction - from outside vehicle" |
        prim_contributory_cause == "distraction - other electronic device (navigation device, dvd player, etc.)" |
        prim_contributory_cause == "cell phone use other than texting" |
        prim_contributory_cause == "texting" ~ "distraction",
      prim_contributory_cause == "had been drinking (use when arrest is not made)" |
        prim_contributory_cause == "under the influence of alcohol/drugs (use when arrest is effected)" ~ "under the influence",
      prim_contributory_cause == "obstructed crosswalks" |
        prim_contributory_cause == "vision obscured (signs, tree limbs, buildings, etc.)" ~ "obstructions",
      prim_contributory_cause == "bicycle advancing legally on red light" |
        prim_contributory_cause == "motorcycle advancing legally on red light" ~ "bike/motorcycle advancing legally on red light",
      prim_contributory_cause == "animal" |
        prim_contributory_cause == "evasive action due to animal, object, nonmotorist" ~ "evasive action",
      prim_contributory_cause == "exceeding authorized speed limit" |
        prim_contributory_cause == "exceeding safe speed for conditions" ~ "speeding",
      TRUE ~ prim_contributory_cause
    ),
    across(where(is_character), ~ as_factor(.))
  )
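
As an aside, the same kind of recoding can be expressed more compactly with forcats::fct_collapse(), which folds several factor levels into one. The snippet below is only a sketch for one variable (device_condition), written as if it were applied to the raw, lower-cased values in place of the case_when() above (running it after the cleaning would just trigger warnings about unknown levels). The case_when() approach works just as well; this is simply an alternative style.

Code
# sketch: collapse device_condition levels with forcats instead of case_when
crashes %>%
  mutate(
    device_condition = fct_collapse(
      device_condition,
      "no controls" = c("no controls", "missing"),
      "functioning improperly" = c(
        "functioning improperly", "worn reflective material", "not functioning"
      )
    )
  ) %>%
  count(device_condition)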

Great! Now, the data are clean, and we can start thinking about how to set up the model. Let’s take a look at the current data summary.

Code
skim(crashes)

And, let’s take a closer look at the dependent variable, crash_type.

Code
crashes %>%
  count(crash_type) %>%
  mutate(prop = n / sum(n))

It looks like there might be a bit of imbalance in the data since the class proportions are skewed towards non-injury/drive-away (74.3%) crash types. Imbalance can sometimes lead to problems in an analysis, especially in severe cases of imbalance. Fortunately, there are a few approaches that try to mitigate this issue (e.g., themis). For now, let’s just analyze the data as is.
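
If the imbalance turned out to matter, one option would be to subsample within the preprocessing recipe. The sketch below shows how themis::step_downsample() could be used so the majority class is down-sampled to match the minority class during training. We won’t use this in the rest of the post (and we build the real recipe later, in the data preparation section); this standalone example just shows the idea and assumes the themis package is installed.

Code
# sketch (not used below): down-sample the majority crash_type class in a recipe
library(themis)

recipe(crash_type ~ ., data = crashes) %>%
  step_downsample(crash_type)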

At this point, it would be helpful to know which variables in the crashes data set are associated with the different levels of the dependent variable, crash_type. One quick way of doing this for the numeric predictors is by using box plots.

Code
crashes %>%
  ggplot(aes(x = crash_type, y = posted_speed_limit, fill = crash_type)) +
  geom_boxplot()

It looks like any differences between the crash_type levels are quite small for posted_speed_limit. So, maybe this variable won’t be so helpful in predicting injuries/towed crash types after all.

Next, let’s check the relationship between the categorical variables and crash_type using simple counts. We’ll also filter out levels that make up less than 1% of the observations within each crash type to get a better idea of the larger data patterns.

Code
print_counts <- function(.y_var) {
  # convert the column name (a string) to a symbol for tidy evaluation
  y_var <- sym(.y_var)

  crashes %>%
    count(crash_type, !!y_var) %>%
    group_by(crash_type) %>%
    mutate(percent = round_half_up(n / sum(n) * 100, 2))
}

y_var <- crashes %>%
  select(where(is.factor), -crash_type) %>%
  variable.names()

map(y_var, print_counts) %>%
  map(., ~ filter(., percent > 1))

It looks like any differences between the two crash_type classes are small for alignment and weather_condition as well. However, because the differences are spread across multiple levels of weather_condition, it’s tough to tell whether there really is a relationship there or not. Another way to look for differences between two categorical variables is to plot a heatmap of the frequency counts.

Code
crashes %>%
  ggplot(aes(crash_type, weather_condition)) +
  geom_bin2d()

Although there are some differences, these appear to be pretty small. So, perhaps weather isn’t important for this model either.
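
Another way to check is to put the within-class percentages side by side, which makes small differences easier to eyeball than the raw counts in the heatmap. This is just a quick exploratory sketch using the packages already loaded above.

Code
# within-class percentage of each weather condition, spread across crash types
crashes %>%
  count(crash_type, weather_condition) %>%
  group_by(crash_type) %>%
  mutate(percent = round_half_up(n / sum(n) * 100, 1)) %>%
  ungroup() %>%
  select(weather_condition, crash_type, percent) %>%
  pivot_wider(names_from = crash_type, values_from = percent)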

Data preparation

Next, we’ll do a bit of preprocessing before training the models. This is where we’ll handle feature selection, data splitting, feature engineering, feature scaling, and creating the validation set (i.e., resampling).

The first thing we’ll do here is drop the variables that did not seem to have much of a relationship with crash_type during data exploration.

Code
crashes <- select(crashes, -c(posted_speed_limit, weather_condition, alignment))

Next, let’s split the single data set into two: a training set and a testing set. A training data set is a data set of examples used during the learning process and is used to fit the models. A test data set is a data set that is independent of the training data set and is used to evaluate the performance of the final model. If a model fit to the training data set also fits the test data set well, we can be confident minimal overfitting has taken place. On the other hand, if the model seems to fit the training set better than the test set, we might have a case of overfitting.

For a data splitting strategy, let’s set aside 25% of the data for the test set. Since the outcome variable (crash_type) is somewhat imbalanced, we’ll also use a stratified random sample.

Code
crash_split <- initial_split(crashes, strata = crash_type)
crash_train <- training(crash_split)
crash_test <- testing(crash_split)

Next, let’s create a base recipe for all models. Note that the sequence of steps matters here:

+ recipe(): Any variable on the left-hand side of the tilde (~) is considered the model outcome (here, crash_type). The predictors appear on the right-hand side of the tilde; the dot (.) indicates that all other variables will be used as predictors. A recipe is also associated with the data set used to create the model, which will usually be the training set (crash_train here).
+ step_date(): Creates predictors for the year, month, and day of the week. Here, we’re selecting only the day of the week and month since there are limited observations for earlier years (e.g., 2013, 2014) in the data.
+ step_rm(): Removes variables; here we use it to remove the original date variable since we no longer want it in the model.
+ step_normalize(): Centers and scales numeric variables.
+ step_dummy(): Converts characters or factors (i.e., nominal variables) into one or more numeric binary model terms for the levels of the original data.
+ step_zv(): Removes indicator variables that only contain a single unique value (e.g., all zeros).

Code
crash_recipe <- recipe(crash_type ~ ., data = crash_train) %>%
  step_date(date, features = c("dow", "month")) %>%
  step_rm(date) %>%
  step_normalize(all_numeric_predictors(), -all_outcomes()) %>%
  step_dummy(all_nominal_predictors(), -all_outcomes()) %>%
  step_zv(all_predictors(), -all_outcomes())
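
Recipes are only specifications until they’re estimated, so it can be worth sanity-checking the preprocessing by prepping the recipe on the training set and baking it. This is purely a diagnostic sketch and isn’t required for the workflows below, which handle this step automatically.

Code
# sketch: estimate the recipe on the training data and inspect the processed columns
crash_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()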

Recall that we already partitioned our data set into a training set and test set. This lets us judge whether a given model will generalize well to new data. However, using only two partitions may be insufficient when doing many rounds of hyperparameter tuning. So, it’s usually a good idea to create a validation set as well. We’ll use k-fold cross validation to build a set of 5 validation folds with the function vfold_cv, and we’ll also use stratified sampling to maintain the outcome class proportions.

k-fold cross validation randomly allocates the observations in the training set (roughly 75% of the 571,426 crashes) to 5 groups of roughly equal size, called “folds”. For the first iteration of resampling, the first fold is held out for the purpose of measuring performance. The other 80% of the data are used to fit the model. This model, trained on the analysis set, is applied to the assessment set to generate predictions. Then, performance statistics are computed based on those predictions.

In this case, 5-fold cross validation iteratively moves through the folds, leaving a different 20% out each time for model assessment. At the end of this process, there are 5 sets of performance statistics, each computed on data that were not used to fit that particular model. While 5 models were created along the way, they are not used further; their only purpose is to calculate performance metrics. The final resampling estimates for the model are the averages of these performance statistics.

Code
crashes_vfold <- vfold_cv(crash_train, v = 5, strata = crash_type)
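
Each element of crashes_vfold holds one analysis/assessment split, and the rsample helpers analysis() and assessment() pull out the two pieces if you want to confirm the fold sizes and class proportions. A quick sketch:

Code
# sketch: peek at the first fold's analysis (fitting) and assessment (evaluation) sets
first_fold <- crashes_vfold$splits[[1]]

dim(analysis(first_fold))    # roughly 4/5 of the training data
dim(assessment(first_fold))  # the held-out fifth

assessment(first_fold) %>%
  count(crash_type) %>%
  mutate(prop = n / sum(n))  # stratification keeps class proportions similar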

We will come back to the validation folds once we’ve specified the models.

Model 1: Logistic regression

All available models are listed at https://www.tidymodels.org/find/parsnip/. Since the outcome variable (crash_type) is categorical, a logistic regression model is a good place to start. Let’s use a model that can perform feature selection during training. The glmnet R package fits a generalized linear model via penalized maximum likelihood. This method of estimating the logistic regression slope parameters uses a penalty on the process so that less relevant predictors are driven towards a value of zero. One of the glmnet penalization methods, called the lasso method, can set the predictor slopes to zero if a large enough penalty is used.

To specify a penalized logistic regression model that uses a feature selection penalty, let’s use the parsnip package with the glmnet engine.

Code
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

We’ll set the penalty argument to tune() as a placeholder for now. This is a model hyperparameter that we will tune to find the best value for making predictions with our data. Setting mixture to a value of 1 means the glmnet model will potentially remove irrelevant predictors and choose a simpler model (i.e., via least absolute shrinkage and selection operator).

Create the workflow

Now, let’s bundle the model and recipe into a single workflow() object to make management of the R objects easier.

Code
lr_workflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(crash_recipe)

Train and tune the model

Before we fit this model, we need to set up a grid of penalty values to tune. Since there is only one hyperparameter to tune here, we can set the grid up manually using a one-column tibble with 30 candidate values.

Code
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

Now we can use the validation set (crashes_vfold) to estimate the performance of our models by fitting the models on each of the folds and storing the results.

Let’s use tune_grid() to train these penalized logistic regression models. This will fit our model to each resample and evaluate it on the held-out set from each resample. We’ll also save the validation set predictions (using control_grid()) so that diagnostic information is available after the model fit. The area under the ROC curve will quantify how well the model performs across a continuum of event thresholds, while precision, recall, and the F1-Score summarize performance at the default classification threshold.

Code
lr_res <- lr_workflow %>%
  tune_grid(
    crashes_vfold,
    grid = lr_reg_grid,
    control = control_grid(save_pred = TRUE),
    metrics = metric_set(roc_auc, precision, recall, f_meas)
  )

Evaluate the model

Let’s take a look at the performance for every single fold.

Code
lr_res %>%
  collect_metrics()

This isn’t very helpful on its own. Let’s visualize the validation set metrics by plotting the area under the ROC curve against the range of penalty values.

Code
lr_res %>%
  collect_metrics() %>%
  filter(`.metric` == "roc_auc") %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())

This plot suggests model performance is generally better at the smaller penalty values, meaning the majority of the predictors are important to the model. There’s also a steep drop in the area under the ROC curve towards the highest penalty values. This happens because a large enough penalty will remove all predictors from the model. And when there are no predictors in the model, predictive accuracy takes a nose dive.

Our model performance seems to plateau at the smaller penalty values, so judging performance by the roc_auc metric alone could lead to multiple options for the “best” value for this hyperparameter.

Code
lr_res %>%
  show_best("roc_auc", n = 15) %>%
  arrange(penalty)

However, we may want to choose a penalty value further along the x-axis, closer to where we start to see the decline in model performance. For example, candidate model 12 with a penalty value of 0.00137 has basically the same performance as the numerically best model (model 1). However, model 12 might eliminate more predictors than model 1, and generally speaking, fewer irrelevant predictors is better. So if model performance is about the same, we should choose a model with a higher penalty value.
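
The tune package has a helper for exactly this kind of decision: select_by_one_std_err() picks the simplest candidate whose performance is within one standard error of the numerically best one. The sketch below shows how it could be applied here, treating larger penalties as simpler; we don’t use it for the final selection in this post.

Code
# sketch: the largest penalty whose ROC AUC is within one standard error of the best
lr_res %>%
  select_by_one_std_err(desc(penalty), metric = "roc_auc")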

But keep in mind, we also collected other performance metrics. So, let’s take a look at those:

Code
perf_metrics <- c("roc_auc", "precision", "recall", "f_meas")

get_metrics <- function(x) {
  lr_res %>%
    show_best(x, n = 15) %>%
    arrange(penalty)
}

map(perf_metrics, get_metrics)

Let’s select model 15 in this case:

Code
lr_best <- lr_res %>%
  select_best(metric = "f_meas")

Now we can use the predictions to create a confusion matrix with conf_mat().

Code
lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  conf_mat(crash_type, .pred_class)

The confusion matrix can also be visualized in different formats using autoplot(). I personally like the heatmap type, but there are others that can be used as well.

Code
lr_res %>%
  collect_predictions(parameters = lr_best)  %>%
  conf_mat(crash_type, .pred_class) %>%
  autoplot(type = "heatmap")
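
For reference, the other built-in option is the mosaic plot, which scales the tiles by the class frequencies:

Code
# the same confusion matrix drawn as a mosaic plot
lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  conf_mat(crash_type, .pred_class) %>%
  autoplot(type = "mosaic")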

Let’s visualize the validation set ROC curve:

Code
lr_auc <- lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  roc_curve(crash_type, `.pred_injury and / or tow due to crash`) %>%
  mutate(model = "Logistic Regression")

autoplot(lr_auc)

We can also make a ROC curve for each of the 5 folds. Since the category we are predicting is the injury/tow level in the crash_type factor, we provide roc_curve() with the relevant class probability, .pred_injury and / or tow due to crash:

Code
lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  group_by(id) %>%
  roc_curve(crash_type, `.pred_injury and / or tow due to crash`) %>%
  autoplot()

Finally, we can also look at the predicted probability distributions for our two classes:

Code
lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  ggplot() +
  geom_density(
    aes(x = `.pred_injury and / or tow due to crash`,
        fill = crash_type),
    alpha = 0.5
  )

The level of performance generated by this logistic regression model isn’t great, but it’s better than an educated guess. Based on the frequency of crashes that result in injuries or vehicles being towed in the entire data set, we would expect about 24.6% of crashes to have these outcomes. However, based on the features we’ve selected here, our model correctly predicted these crash types about 33% of the time. So, we’ve improved our predictions, but only by about 8 percentage points. Perhaps the linear nature of the prediction equation is limiting our model’s performance. As a next step, we might consider using a non-linear model, like a tree-based ensemble method.
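
Before moving on, here is a quick sketch of how that ~33% figure can be checked directly from the validation-set predictions using yardstick’s recall(). It assumes the injury/tow level is the first level of the crash_type factor; if it’s the second level in your data, switch to event_level = "second".

Code
# sketch: recall (sensitivity) for the injury/tow class from the validation predictions
lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  recall(crash_type, .pred_class, event_level = "first")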

Model 2: Random forest

An effective, low-maintenance, non-linear modeling approach is a random forest, which tends to be more flexible than logistic regression. A random forest is an ensemble model that often consists of thousands of decision trees. Each individual tree sees a slightly different version of the training set and learns a sequence of splitting rules to predict new data. Random forests require very little preprocessing and can handle many types of predictors (e.g., skewed, continuous, categorical, etc.). Although the default hyperparameters for random forests tend to give reasonable results, we’ll tune two hyperparameters that could improve performance. Tuning should also help compensate for limiting the forest to 20 trees, which we’ll do to speed up model fitting.

Code
rf_mod <- rand_forest(mtry = tune(), min_n = tune(), trees = 20) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")

For the hyperparameters in this model, we use tune() as a placeholder for the mtry and min_n argument values. The mtry hyperparameter sets the number of predictors that are randomly sampled as split candidates at each node of a tree. The min_n hyperparameter sets the minimum number of data points required in a node for it to be split further. We also added importance = "impurity" when setting the engine. This will provide variable importance scores for this model, which give some insight into which predictors drive model performance.

Create the workflow

Next, let’s bundle the recipe and model.

Code
rf_workflow <- workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(crash_recipe)

Train and tune the model

Since we have more than one hyperparameter to tune in this model, let’s use a space-filling design with 25 candidate models.

Code
rf_res <- rf_workflow %>%
  tune_grid(
    crashes_vfold,
    grid = 25,
    control = control_grid(save_pred = TRUE),
    metrics = metric_set(roc_auc, precision, recall, f_meas)
  )
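
Passing grid = 25 lets tune_grid() build the space-filling grid internally, including finalizing the data-dependent upper bound of mtry against the fully preprocessed predictors. If you wanted to build a similar grid by hand, a sketch might look like the following; it uses extract_parameter_set_dials() (older tidymodels releases used parameters() instead), and because it finalizes mtry against the raw predictor columns rather than the dummy-encoded ones, the range is only approximate.

Code
# sketch: construct a similar 25-point space-filling grid by hand
rf_params <- rf_workflow %>%
  extract_parameter_set_dials() %>%
  finalize(select(crash_train, -crash_type))  # sets the upper bound of mtry from the data

rf_grid <- grid_latin_hypercube(rf_params, size = 25)
rf_grid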

Evaluate the model

Out of the 25 candidates, here are the top 5 random forest models based on their F1-Scores:

Code
rf_res %>%
  show_best(metric = "f_meas")

Let’s select the best model according to the F1-Score. Our final tuning parameter values are:

Code
rf_best <- rf_res %>%
  select_best(metric = "f_meas")

rf_best

To calculate the data needed to plot the ROC curve, we use collect_predictions(). This is only possible after tuning with control_grid(save_pred = TRUE). Now, we can use the predictions to create a confusion matrix with conf_mat().

Code
rf_res %>%
  collect_predictions() %>%
  conf_mat(crash_type, .pred_class)

To filter the predictions for only our best random forest model, we can use the parameters argument and pass it our tibble with the best hyperparameter values from tuning, which we called rf_best.

Code
rf_auc <- rf_res %>%
  collect_predictions(parameters = rf_best) %>%
  roc_curve(crash_type, `.pred_injury and / or tow due to crash`) %>%
  mutate(model = "Random Forest")

autoplot(rf_auc)

Compare models

Now, it’s time to compare the models. The first thing we’ll do is extract the performance metrics from each of the models and combine them into a single data frame.

Code
lr_metrics <- lr_res %>%
  collect_metrics() %>%
  mutate(model = "Logistic Regression")

rf_metrics <- rf_res %>%
  collect_metrics() %>%
  mutate(model = "Random Forest")

compare_mod <- bind_rows(lr_metrics, rf_metrics)

First, let’s take a look at the average F1-Score for each model:

Code
compare_mod %>%
  filter(.metric == "f_meas") %>%
  group_by(model) %>%
  summarize(avg_f_meas = mean(mean)) %>%
  mutate(model = fct_reorder(model, avg_f_meas)) %>%
  ggplot(aes(model, avg_f_meas, fill = model)) +
  geom_col() +
  coord_flip() +
  scale_fill_brewer(palette = "Blues") +
  geom_text(
    size = 5,
    aes(label = round_half_up(avg_f_meas, 2), y = avg_f_meas - .8)
  )

Not much of a difference here. So, we may also want to check out the average ROC AUC for each model:

Code
compare_mod %>%
  filter(.metric == "roc_auc") %>%
  group_by(model) %>%
  summarize(avg_roc = mean(mean)) %>%
  mutate(model = fct_reorder(model, avg_roc)) %>%
  ggplot(aes(model, avg_roc, fill = model)) +
  geom_col() +
  coord_flip() +
  scale_fill_brewer(palette = "Blues") +
  geom_text(
    size = 5,
    aes(label = round_half_up(avg_roc, 2), y = avg_roc - .7)
  )

Looks like our random forest model did a bit better here, but still pretty close. Let’s plot the validation set ROC curves for the top penalized logistic regression model and random forest model:

Code
bind_rows(rf_auc, lr_auc) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d(option = "plasma", end = .6)

Overall, the model results are pretty similar, but the random forest model did seem to perform better than the logistic regression model. In this case, I highlighted the ROC AUC and F1-Score performance metrics, but the “best” performance metric will always depend on the question you are trying to answer with your model. For example, in some cases, you might be much more concerned about false negatives than false positives (e.g., when predicting severe storms). In other situations, you might only care about each of these to the extent they influence a model’s precision (e.g., when predicting profitable stocks).

To keep things simple, let’s stick with the ROC AUC metric in this case. AUC stands for area under the curve. What curve, you may ask? The ROC curve, specifically. The ROC curve plots the tradeoff between the true positive rate (sensitivity) and the false positive rate (1 - specificity). Ideally, we want to maximize the true positive rate while minimizing the false positive rate.
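
To make that tradeoff concrete, here is a sketch that computes the true positive and false positive rates at a few probability thresholds from the random forest’s validation-set predictions; the ROC curve is just this calculation swept over every possible threshold. The threshold values are arbitrary, and the injury/tow class label is taken from the prediction column used above.

Code
# sketch: TPR and FPR at a few hand-picked thresholds
rf_preds <- rf_res %>%
  collect_predictions(parameters = rf_best)

map_dfr(c(0.25, 0.50, 0.75), function(cutoff) {
  rf_preds %>%
    mutate(
      is_event   = crash_type == "injury and / or tow due to crash",
      pred_event = `.pred_injury and / or tow due to crash` >= cutoff
    ) %>%
    summarize(
      threshold             = cutoff,
      sensitivity           = sum(pred_event & is_event) / sum(is_event),
      one_minus_specificity = sum(pred_event & !is_event) / sum(!is_event)
    )
})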

Let’s find the maximum mean ROC AUC:

Code
compare_mod %>%
  filter(.metric == "roc_auc") %>%
  group_by(model) %>%
  summarize(avg_roc_auc = mean(mean)) %>%
  slice_max(avg_roc_auc)

Now, it’s time to fit the best model one last time to the full training set. Then, we can evaluate the resulting final model on the test set.

Last fit

Recall that our goal was to predict whether a traffic crash would result in an injury or a vehicle being towed based on a priori situational factors. Given the results, we determined the random forest model performed better than the penalized logistic regression model. We also learned the best model hyperparameters, which are stored in the rf_best object we created earlier. Now, we just need to fit the final model on all the rows of data not originally held out for testing (i.e., the full training set, including the validation folds) and evaluate the model performance one more time with the test set.

The tune package contains the function last_fit(), which fits a model to the whole training data and evaluates it on the test set. We just need to provide the workflow object for the best model and the data split object (not the training data).

Code
last_rf_mod <- rand_forest(
    mtry = rf_best$mtry,
    min_n = rf_best$min_n,
    trees = 20
  ) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")

last_rf_workflow <- rf_workflow %>%
  update_model(last_rf_mod)

last_rf_fit <- last_rf_workflow %>%
  last_fit(crash_split)

And these are the final performance metrics:

Code
last_rf_fit %>%
  collect_metrics()

Remember, if a model fit to the training data set also fits the test data set well, we can be reasonably confident that minimal overfitting has taken place.
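
One way to check that is to put the resampling estimate and the test-set estimate side by side. Here is a quick sketch for ROC AUC, using the .config column in rf_best to pull out the cross-validated metrics for the selected candidate:

Code
# sketch: line up the resampling estimate and the test-set estimate for ROC AUC
bind_rows(
  collect_metrics(rf_res) %>%
    semi_join(rf_best, by = ".config") %>%
    filter(.metric == "roc_auc") %>%
    transmute(source = "5-fold cross-validation", roc_auc = mean),
  collect_metrics(last_rf_fit) %>%
    filter(.metric == "roc_auc") %>%
    transmute(source = "test set", roc_auc = .estimate)
)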

To learn more about the model, we can look at the variable importance scores in the .workflow column. We pluck the first element from the column, and pull out the fit from the workflow object. Then, we can use the vip package to visualize the variable importance scores for the top features.

Code
last_rf_fit %>%
  pluck(".workflow", 1) %>%
  extract_fit_parsnip() %>%
  vip(num_features = 10)

By far, the most important factor in whether a crash results in injuries or the vehicle being towed is if the first collision in the crash involved a pedestrian or not.
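
If you want the underlying numbers rather than the plot, vip’s vi() function returns the importance scores as a tibble. This sketch assumes vi() handles the parsnip fit the same way vip() does above:

Code
# sketch: pull the raw variable importance scores as a tibble
last_rf_fit %>%
  pluck(".workflow", 1) %>%
  extract_fit_parsnip() %>%
  vi() %>%
  slice_max(Importance, n = 10)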

Let’s take a quick look at the confusion matrix:

Code
last_rf_fit %>%
  collect_predictions() %>%
  conf_mat(crash_type, .pred_class) %>%
  autoplot(type = "heatmap")

And, let’s create the final ROC curve:

Code
last_rf_fit %>%
  collect_predictions() %>%
  roc_curve(crash_type, `.pred_injury and / or tow due to crash`) %>%
  autoplot()

The results from the validation set and test set performance statistics are very close, so we can be reasonably confident the random forest model with the selected features and hyperparameters would perform well when predicting new data.

Special thanks to Drew Triplett for his helpful comments on an earlier draft of this post!

Citation

BibTeX citation:
@online{2021,
  author = {},
  title = {Predicting Pileups: {Using} {ML} to Predict {Chicago} Crash
    Types},
  date = {2021-12-18},
  url = {https://www.jrwinget.com/blog/2021-12-18-predicting-pileups/},
  langid = {en}
}
For attribution, please cite this work as:
“Predicting Pileups: Using ML to Predict Chicago Crash Types.” 2021. December 18, 2021. https://www.jrwinget.com/blog/2021-12-18-predicting-pileups/.