2 “Next-Level” Uses of Decision Trees

How to reduce predictor count and segment variables

Shaw Talebi
9 min read · Mar 11, 2023

This is the 3rd article in a series on decision trees. Thus far, we’ve discussed how to build predictive models using decision trees and tree ensembles. While this is the core use of these techniques, we can use them for more than just making predictions. In this post, I walk through 2 additional uses of decision trees that can take your analytics work to the next level.

Image from KnowYourMeme. Mutation by author.

Next-level?

In the 1st article of this series, I defined a model as something that allows us to make predictions. For example, predicting tomorrow’s weather or the probability of developing heart disease.

We saw that one way to get a model is to use decision trees. While predictions are a common application of decision trees, that is not their only use.

This brings up “next-level” data science methods. By this, I mean methods that go beyond the obvious task of building a predictive model toward using the model (or models) to do something more interesting (which is what data science in the real world is all about).

While data science may seem as easy and simple as training a model with a few lines of code using sklearn, the real (and fun) work is figuring out how to use a model to solve a problem and provide value.

Here I share two such uses of decision trees. Although both examples use medical data, these techniques can be applied to virtually any context (e.g. credit risk, sales, marketing, etc.). All Python code (as always) is freely available at the GitHub repo.

Use 1: Reduce Predictor Count

In the last blog of this series, we used decision tree ensembles to do breast cancer prediction. While this was an instructive example, it may be challenging to realize the value of these models.

One issue is that the tree ensemble methods we used are black boxes. We know what goes in and what comes out, but we aren’t too sure about what happens in between.

Sure, we could plot every decision tree in the ensemble, look at its splits, and compare relative model weights, but with 30 predictor variables and 100 constituent models, it might be difficult to pull out an intuitive understanding.

For these cases, we can call upon Occam’s Razor. This principle suggests that, given comparable performance, simpler models are better. In the context of a tree ensemble, we can translate this to: fewer predictor variables are better.

There are two main upsides to this. One, interpreting a model with 3 variables is easier than one with 30. And two, fewer predictor variables mean less possibility of failure (this comes from the engineering wisdom that the more moving parts a system has, the greater the opportunity for failure).

So, the question becomes: how do we pick which variables to keep?

One way is via our model’s feature importance ranking. The basic idea is we use the importance ranking to add variables to a model (one at a time) and compare model performances.

While this does not guarantee an optimal choice of predictor variables, it does provide us with a quick and easy heuristic for developing simpler models. Let’s see what this might look like for the breast cancer prediction problem from the previous blog.

We start with importing libraries.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_breast_cancer # toy dataset (sorry)

from imblearn.over_sampling import SMOTE

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

Next, we import our data and define our predictor and target variables.

# import data
df = load_breast_cancer(as_frame = True)['frame']

# define predictor and target variable names
X_var_names = df.columns[:df.shape[1]-1]
y_var_name = df.columns[df.shape[1]-1]

# create predictor and target arrays
X = df[X_var_names]
y = df[y_var_name]

# transform y such that 1=>malignant and 0=>benign
y = -1*y + 1

The last line of the above code block is doing a simple transformation of the target variable to convert every 0 → 1 and 1 → 0. This will help us interpret predictors as driving breast cancer risk as opposed to “safety.”

Since we have an imbalanced dataset, i.e. we have more benign than malignant tumor cases, we can make our data balanced using SMOTE.

# oversample minority class using smote
X_resampled, y_resampled = SMOTE(random_state=0).fit_resample(X, y)

# create train and test datasets from the balanced data
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=0)

Now to the fun part. We start by training a Random Forest model.

clf = RandomForestClassifier(random_state=0)
clf = clf.fit(X_train, y_train)

We can evaluate its performance via the area under the curve (AUC) of the ROC plot. I won’t get into the details of AUC here. I will just say an AUC close to 1 means a good model, while an AUC of 0.5 means the model does no better than random guessing.

rf_auc_val = roc_auc_score(y_test, clf.predict_proba(X_test)[:,1])
print(rf_auc_val)

# >>> 0.995266272189349

Now with our model, we can spit out a feature importance ranking.

feature_importances = pd.Series(clf.feature_importances_, index=X_var_names)
feature_importances_sorted = feature_importances.sort_values(ascending=False)
print(feature_importances_sorted)

The printout looks like this.

Feature importance printed in descending order. Image by author.

With the predictor importance ranking, we can go down the list one feature at a time and add it to a Logistic Regression model. The result will be 30 Logistic Regression models with increasing predictor counts, i.e. the 1st model will use only worst perimeter, the 2nd will use worst perimeter and mean concave points, and the 3rd will use worst perimeter, mean concave points, and worst radius.

The code to do this looks like this.

# initialize lists to store the models and their train/test performance
base_clf_list = []
train_auc_list = []
test_auc_list = []

# add features one at a time, in order of importance
for i in range(len(feature_importances_sorted)):

    # use the top i+1 most important features
    feature_names = list(feature_importances_sorted.index[:i+1])

    # train a logistic regression model on this feature subset
    base_clf = LogisticRegression(random_state=0, solver='newton-cg').fit(X_train[feature_names], y_train)

    # store the model and its train and test AUC
    base_clf_list.append(base_clf)
    train_auc_list.append(roc_auc_score(y_train, base_clf.predict_proba(X_train[feature_names])[:,1]))
    test_auc_list.append(roc_auc_score(y_test, base_clf.predict_proba(X_test[feature_names])[:,1]))

Then we can plot how performance evolves with the addition of each variable for both the training and testing datasets.
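Here is a minimal sketch of how such a plot could be generated with matplotlib (the styling of the original figure is my own assumption):

# plot train and test AUC against the number of predictor variables (styling assumed)
num_vars = range(1, len(feature_importances_sorted) + 1)
plt.plot(num_vars, train_auc_list, label='Training AUC')
plt.plot(num_vars, test_auc_list, label='Testing AUC')
plt.axhline(rf_auc_val, color='gray', linestyle='--', label='Random Forest (30 variables)')
plt.xlabel('Number of predictor variables')
plt.ylabel('AUC')
plt.legend()
plt.show()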

Model performance plotted against the number of variables. Image by author.

We can see from the plot above that the Logistic Regression model with 5 variables outperforms the Random Forest model with 30 variables!

Beyond the performance gain, Logistic Regression is a linear model. Therefore, we can interpret each predictor’s impact on the target by simply looking at its coefficient.

We can visualize the coefficients in a bar chart.
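Here is a minimal sketch of such a chart, assuming the 5-variable model is the 5th entry (index 4) of base_clf_list from the loop above:

# grab the 5-variable model and its feature names (index assumed from the loop above)
best_clf = base_clf_list[4]
best_features = list(feature_importances_sorted.index[:5])

# bar chart of the model coefficients
plt.barh(best_features, best_clf.coef_[0])
plt.xlabel('Coefficient (change in log odds per unit increase)')
plt.show()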

Coefficients for Logistic Regression model with 5 input variables. Image by author.

The way to interpret these coefficients is that a unit increase in the associated predictor variable translates to an increase in breast cancer risk equal to the coefficient value. For example, an increase of worst perimeter by 1 unit means breast cancer risk will increase by about 0.31. Here, “breast cancer risk” means the log odds of the tumor being malignant.
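To make this concrete, we can convert a log-odds coefficient into an odds ratio (the 0.31 value is taken from the example above):

# a coefficient of ~0.31 on the log-odds scale corresponds to an odds ratio of exp(0.31)
print(np.exp(0.31))
# >>> ~1.36, i.e. each unit increase in worst perimeter multiplies the odds of malignancy by ~1.36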

Technical note: we re-sampled our training and testing data, which biases the model’s intercept. Thus, we can’t (yet) translate the log odds of malignancy into an (unbiased) probability of malignancy. However, there is a simple fix for this, which is to adjust the intercept for the oversampling procedure.
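As a sketch of that fix, one common approach (prior correction) subtracts the log of the resampling odds from the fitted intercept. Here I assume best_clf is the 5-variable model from above, and approximate the true event rate from the original (unbalanced) data:

# adjust the intercept for oversampling (a sketch; tau would ideally be the population malignancy rate)
y_bar = y_train.mean() # event rate in the balanced training data, ~0.5 after SMOTE
tau = y.mean()         # approximate true event rate, taken from the original data
intercept_adj = best_clf.intercept_[0] - np.log((y_bar / (1 - y_bar)) * ((1 - tau) / tau))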

Use 2: Predictor Segmentation

The example from the 1st blog of this series used a single decision tree to do sepsis survival prediction. There we saw that age drove the bulk of the decision tree splits, indicating it was a key risk factor for sepsis survival.

In cases like this, where one predictor variable has an outsized impact on the target variable, it can make sense to segment data records using the high-impact predictor.

For example, if age is the main driver of sepsis survival, we can segment patients into different age buckets, then develop age-bucket-specific survival models. Two benefits of this are: 1) it gives lower-impact predictors a chance to reveal their predictive power, and 2) different age groups may have systematic differences that require separate modeling.

While we could develop these age buckets manually (or arbitrarily), another option is to let the data tell us the optimal age groups. In other words, we can learn age bucket definitions from our data.

It turns out we can use a decision tree to do this in a straightforward way. We simply train a decision tree model only using the high-impact predictor (e.g. age) to estimate the target (e.g. survival flag).

Let’s use the sepsis survival problem as an example. We again start by importing some Python libraries.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn import tree

Next, we import our data (the sepsis survival dataset from [3]; the exact filename below is an assumption, so see the GitHub repo) and group it by age. Additionally, we can use the sepsis survival flag to compute the percent_alive and percent_not_alive for each age value.

# load the sepsis survival dataset (filename assumed)
df = pd.read_csv('sepsis_survival_primary_cohort.csv')

# keep only age and the sepsis survival flag
df = df.iloc[:,[0,3]]

# group data by age and compute percent alive for each age value
df_byAge = (df.groupby(by=df.columns[0]).sum()/df.groupby(by=df.columns[0]).count())
df_byAge.columns = ['percent_alive']
df_byAge['percent_not_alive'] = 1 - df_byAge['percent_alive']

# move age from the index back into a column so it can be selected by name later
df_byAge = df_byAge.reset_index()

The resulting dataframe looks like this.

Pandas dataframe after the preparation step. Image by author.

Next, we can define separate series for our predictor and target variables. Here we will use age_years to predict percent_not_alive.

# define predictor and target variable names
X_var_name = df_byAge.columns[0]
y_var_name = df_byAge.columns[2]

# create predictor and target arrays
X = df_byAge[X_var_name]
y = df_byAge[y_var_name]

To get a sense of the relationship between age and sepsis risk, we can plot the two variables against each other.
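A minimal plotting sketch (styling assumed):

# plot sepsis death rate against age (styling assumed)
plt.scatter(X, y)
plt.xlabel('Age (years)')
plt.ylabel('Percent not alive')
plt.show()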

Sepsis death rate by age. Image by author.

As expected, the risk of death from sepsis increases with age. It is also interesting to note the uptick in risk around mid-life.

Next, we use a decision tree to segment this data. Again, the goal is to split the data into different age groups based on sepsis risk, which can be framed as a modeling problem.

First, we train a univariate decision tree to predict percent_not_alive.

num_bins = 4

# train model
clf = tree.DecisionTreeRegressor(random_state=42, max_leaf_nodes=num_bins)
clf = clf.fit(X.to_numpy().reshape(-1, 1), y)

Notice we constrained the growth of the decision tree using the max_leaf_nodes argument. This determines the number of age buckets we derive, i.e. 4 leaf nodes imply 4 age buckets.

Next, we can traverse the decision tree and extract each age split value.

# extract the tree structure
num_nodes = clf.tree_.node_count
left_child = clf.tree_.children_left
right_child = clf.tree_.children_right
threshold = clf.tree_.threshold

# list to store the bin edges, initialized with the min and max age
bin_edge_list = [X[0], X[len(X)-1]]

# loop through all the nodes
for i in range(num_nodes):
    # if a node's left and right children differ (i.e. are not both -1),
    # it is an internal node, so we append its split threshold
    if left_child[i] != right_child[i]:
        bin_edge_list.append(np.round(threshold[i], 1))

# sort the bin edges in increasing order
bin_edge_list.sort()

# create dictionary to store the age bin edges
bin_dict = {}

# store 2 consecutive bin edges under each dictionary key
for i in range(num_bins):
    bin_dict[str(i+1)] = [bin_edge_list[i], bin_edge_list[i+1]]

While this code snippet is a little technical, we can simply overlay the extracted bin edges on our plot from before.
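A minimal sketch of this overlay (styling assumed):

# overlay the learned age bin edges on the age vs. death rate plot (styling assumed)
plt.scatter(X, y)
for edge in bin_edge_list:
    plt.axvline(edge, color='red', linestyle='--')
plt.xlabel('Age (years)')
plt.ylabel('Percent not alive')
plt.show()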

Sepsis death rate by age with 4 decision tree-based age buckets. Image by author.

We can see the age buckets do a qualitatively good job of splitting age groups based on risk. These buckets can inform sepsis risk tiers, treatment strategies for different age groups, or age-bucket-specific risk modeling.

[1] Pedregosa, F. et al. (2011). Scikit-learn: Machine Learning in Python. JMLR, 12, pp. 2825–2830.

[2] Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. (CC BY 4.0)

[3] Chicco, D. and Jurman, G. (2020). Survival prediction of patients with sepsis from age, sex, and septic episode number alone. Scientific Reports, 10, 17156.
