10. Tree Methods (continued)#
but first…
10.1. Hyper-parameter Search and Validation#
Tree methods have many hyper-parameters compared to the algorithms we’ve encountered thus far. Here are some from the documentation:
max_depth - how many layers may the tree have
min_samples_split - how many samples must be in a node to allow a split
min_samples_leaf - how many samples must be in a leaf
max_features - the number of features to consider in a split
max_leaf_nodes - stop growing the tree at a specified number of leaves
and others
10.1.1. Grid Search#
How do we explore this space? Suppose I want to try trees with these options:
max_depth = [4, 6, 8, 10, 12]
min_samples_split = [10, 20, 40]
How many models will I be testing?
GridSearch does just this in an automated way, testing every combination from the parameters you’d like to test.
| max_depth | min_samples_split | Cartesian Product |
|---|---|---|
| 4 | 10 | (4, 10) |
| 4 | 20 | (4, 20) |
| 4 | 40 | (4, 40) |
| 6 | 10 | (6, 10) |
| 6 | 20 | (6, 20) |
| 6 | 40 | (6, 40) |
| 8 | 10 | (8, 10) |
| 8 | 20 | (8, 20) |
| 8 | 40 | (8, 40) |
| 10 | 10 | (10, 10) |
| 10 | 20 | (10, 20) |
| 10 | 40 | (10, 40) |
| 12 | 10 | (12, 10) |
| 12 | 20 | (12, 20) |
| 12 | 40 | (12, 40) |
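To check the count without writing the table out by hand, here is a minimal sketch using scikit-learn's `ParameterGrid`, which enumerates the same Cartesian product that a grid search iterates over. The variable names `param_grid` and `combos` are our own choices for illustration.

```python
from sklearn.model_selection import ParameterGrid

# The two candidate lists from the example above
param_grid = {
    'max_depth': [4, 6, 8, 10, 12],
    'min_samples_split': [10, 20, 40],
}

# ParameterGrid enumerates every combination (the Cartesian product)
combos = list(ParameterGrid(param_grid))

print(len(combos))   # 5 * 3 = 15 candidate models
for combo in combos[:3]:
    print(combo)     # each combination is a dict of hyper-parameter values
```

Keep in mind that each combination is fit once per fold when cross-validation is added, so 15 combinations with 5 folds means 75 model fits.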
10.1.2. Cross-Validation#
Validation is used to select from a set of candidate models (e.g. different learning algorithms, variations on the same algorithm with different hyperparameters). In the simplest form of validation, we split off a portion of the training data and compare models based on their performance on this validation set. But more commonly, we use K-fold Cross-Validation:
Split the training data into K “folds”
Set the first fold aside as a validation set and train on the remaining data.
Validate using that first fold as a validation set.
Repeat the process (K times in total), each time using a different fold as the validation set.
Average the performance across all the training-validation iterations.
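The steps above can be sketched as an explicit loop. This is only an illustration under stated assumptions: it uses the iris data as a stand-in for "the training data", a depth-3 tree as the model, and K = 5; none of these choices come from the example that follows.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier

# Stand-in data; any feature matrix X and label vector y would do
X, y = load_iris(return_X_y=True)

K = 5
kf = KFold(n_splits=K, shuffle=True, random_state=0)

scores = []
for train_idx, val_idx in kf.split(X):
    # Train on K-1 folds, validate on the held-out fold
    model = DecisionTreeClassifier(max_depth=3, random_state=0)
    model.fit(X[train_idx], y[train_idx])
    scores.append(model.score(X[val_idx], y[val_idx]))

# Average the K validation accuracies
mean_score = np.mean(scores)
print(f'fold accuracies: {np.round(scores, 3)}')
print(f'mean accuracy:   {mean_score:.3f}')
```

In practice, `cross_val_score` from `sklearn.model_selection` wraps this loop in a single call.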
10.1.3. Grid Search + Cross-Validation#
Grid Search and Cross-Validation are used in tandem so commonly that sklearn packages them together in some very convenient functions.
(some we’ve seen before)
10.2. Example: Palmer Penguins#
Let’s try to predict the species of penguins based on measured attributes.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, LabelEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.inspection import DecisionBoundaryDisplay
palmer = pd.read_csv('https://gist.githubusercontent.com/slopp/ce3b90b9168f2f921784de84fa445651/raw/4ecf3041f0ed4913e7c230758733948bc561f434/penguins.csv', index_col = 'rowid')
display(palmer.sample(10))
display(palmer.info())
display(palmer['species'].unique())
sns.pairplot(palmer, hue = 'species')
plt.show()
| rowid | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | year |
|---|---|---|---|---|---|---|---|---|
| 76 | Adelie | Torgersen | 42.8 | 18.5 | 195.0 | 4250.0 | male | 2008 |
| 239 | Gentoo | Biscoe | 43.4 | 14.4 | 218.0 | 4600.0 | female | 2009 |
| 276 | Gentoo | Biscoe | 49.9 | 16.1 | 213.0 | 5400.0 | male | 2009 |
| 324 | Chinstrap | Dream | 49.0 | 19.6 | 212.0 | 4300.0 | male | 2009 |
| 88 | Adelie | Dream | 36.9 | 18.6 | 189.0 | 3500.0 | female | 2008 |
| 115 | Adelie | Biscoe | 39.6 | 20.7 | 191.0 | 3900.0 | female | 2009 |
| 17 | Adelie | Torgersen | 38.7 | 19.0 | 195.0 | 3450.0 | female | 2007 |
| 98 | Adelie | Dream | 40.3 | 18.5 | 196.0 | 4350.0 | male | 2008 |
| 186 | Gentoo | Biscoe | 59.6 | 17.0 | 230.0 | 6050.0 | male | 2007 |
| 217 | Gentoo | Biscoe | 45.8 | 14.2 | 219.0 | 4700.0 | female | 2008 |
<class 'pandas.core.frame.DataFrame'>
Index: 344 entries, 1 to 344
Data columns (total 8 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 species 344 non-null object
1 island 344 non-null object
2 bill_length_mm 342 non-null float64
3 bill_depth_mm 342 non-null float64
4 flipper_length_mm 342 non-null float64
5 body_mass_g 342 non-null float64
6 sex 333 non-null object
7 year 344 non-null int64
dtypes: float64(4), int64(1), object(3)
memory usage: 24.2+ KB
None
array(['Adelie', 'Gentoo', 'Chinstrap'], dtype=object)
# Prepare the data
palmer.dropna(axis = 0, inplace=True)
palmer.reset_index(drop = True, inplace=True)
First, a quick visualization of how a decision tree slices up the feature space. For this visualization, we’ll select just two features.
Which two features would work well for predicting the species of the penguin?
features = ['bill_length_mm', 'flipper_length_mm']
target = ['species']
X = palmer[features]
y = palmer[target]
# Encode species names as integers; pass a 1-D Series to avoid a column-vector warning
y_label = LabelEncoder().fit_transform(palmer['species'])
X_train, X_test, y_train, y_test = train_test_split(X, y_label, test_size=0.2)
fig, axs = plt.subplots(1, 5, figsize = (20, 4), sharey = True)
cmap = ListedColormap(['red', 'green', 'blue'])  # one color per species
for ax, depth in zip(axs, [1, 2, 3, 5, 10]):
    tree_clf = DecisionTreeClassifier(max_depth=depth)
    tree_clf.fit(X_train, y_train)
    y_pred = tree_clf.predict(X_test)
    # Shade the region of feature space assigned to each class
    DecisionBoundaryDisplay.from_estimator(
        estimator = tree_clf,
        X = X,
        multiclass_colors=['r', 'g', 'b'],
        alpha = 0.25,
        ax = ax
    )
    # Training data
    ax.scatter(X_train.iloc[:,0], X_train.iloc[:,1],
               c = y_train, cmap = cmap, marker = 'x')
    # Testing data
    # ax.scatter(X_test.iloc[:,0], X_test.iloc[:,1],
    #            c = y_test, cmap = cmap, marker = '.')
    ax.set_title(f'Tree depth of {depth}')
feature_names = tree_clf.feature_names_in_
label_names = ['Adelie', 'Chinstrap', 'Gentoo']
fig, ax = plt.subplots(1,1, figsize = (10, 8))
plot_tree(tree_clf,
filled = True, fontsize = 9,
feature_names = feature_names, class_names = label_names)
plt.show()
Now let’s find the best-fit tree for these data using GridSearchCV. Our modeling pipeline:
Select features and targets
Encode categorical data (tree methods don’t require scaling)
Specify candidate values for hyper-parameters
Fit GridSearchCV
Evaluate best-fit model
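The remark that tree methods don't require scaling can be checked with a small experiment. The sketch below is an illustration on synthetic data from `make_classification` (not the penguins), with an arbitrary `max_depth=3`. It fits one tree on raw features and one on standardized features, then measures how often their predictions agree. We would expect them to agree on essentially every sample, because rescaling a feature shifts the split thresholds but does not change the ordering of the samples.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption for this demo)
X, y = make_classification(n_samples=400, n_features=5, n_informative=3,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Tree fit on the raw features
raw_tree = DecisionTreeClassifier(max_depth=3, random_state=0)
raw_tree.fit(X_train, y_train)

# Same tree settings, fit on standardized features
scaler = StandardScaler().fit(X_train)
scaled_tree = DecisionTreeClassifier(max_depth=3, random_state=0)
scaled_tree.fit(scaler.transform(X_train), y_train)

# Fraction of test samples on which the two trees make the same prediction
agreement = np.mean(
    raw_tree.predict(X_test) == scaled_tree.predict(scaler.transform(X_test))
)
print(f'agreement between raw and scaled trees: {agreement:.3f}')
```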
palmer.head()
| | species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | year |
|---|---|---|---|---|---|---|---|---|
| 0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | male | 2007 |
| 1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | female | 2007 |
| 2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | female | 2007 |
| 3 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | female | 2007 |
| 4 | Adelie | Torgersen | 39.3 | 20.6 | 190.0 | 3650.0 | male | 2007 |
y = palmer['species']
X = palmer.drop(columns = 'species')
oh_features = ['island', 'sex']
oh = OneHotEncoder(drop = 'if_binary')
coltrans = ColumnTransformer(
transformers = [
('oh', oh, oh_features)
],
remainder = 'passthrough',
verbose_feature_names_out = False
)
X_trans = coltrans.fit_transform(X)
feature_names = coltrans.get_feature_names_out()
X = pd.DataFrame(X_trans, columns = feature_names)
X
| | island_Biscoe | island_Dream | island_Torgersen | sex_male | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.0 | 1.0 | 1.0 | 39.1 | 18.7 | 181.0 | 3750.0 | 2007.0 |
| 1 | 0.0 | 0.0 | 1.0 | 0.0 | 39.5 | 17.4 | 186.0 | 3800.0 | 2007.0 |
| 2 | 0.0 | 0.0 | 1.0 | 0.0 | 40.3 | 18.0 | 195.0 | 3250.0 | 2007.0 |
| 3 | 0.0 | 0.0 | 1.0 | 0.0 | 36.7 | 19.3 | 193.0 | 3450.0 | 2007.0 |
| 4 | 0.0 | 0.0 | 1.0 | 1.0 | 39.3 | 20.6 | 190.0 | 3650.0 | 2007.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 328 | 0.0 | 1.0 | 0.0 | 1.0 | 55.8 | 19.8 | 207.0 | 4000.0 | 2009.0 |
| 329 | 0.0 | 1.0 | 0.0 | 0.0 | 43.5 | 18.1 | 202.0 | 3400.0 | 2009.0 |
| 330 | 0.0 | 1.0 | 0.0 | 1.0 | 49.6 | 18.2 | 193.0 | 3775.0 | 2009.0 |
| 331 | 0.0 | 1.0 | 0.0 | 1.0 | 50.8 | 19.0 | 210.0 | 4100.0 | 2009.0 |
| 332 | 0.0 | 1.0 | 0.0 | 0.0 | 50.2 | 18.7 | 198.0 | 3775.0 | 2009.0 |
333 rows × 9 columns
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)
We’ll pass three arguments to GridSearchCV:
a not-yet-fit instance of the desired model type (e.g. DecisionTreeClassifier, LogisticRegression, etc.)
a dictionary in which the keys are hyper-parameter names and each value is the list of values you’d like to try for that hyper-parameter
the number of folds to use in cross-validation (cv, which defaults to 5 if omitted)
tree_params = {
'max_depth':[2, 4, 6],
'min_samples_split':[5, 10, 20],
'min_impurity_decrease':[0.01, 0.03, 0.1, 0.3]
}
grid = GridSearchCV(DecisionTreeClassifier(), tree_params, cv = 5)
grid.fit(X_train, y_train)
GridSearchCV(cv=5, estimator=DecisionTreeClassifier(),
param_grid={'max_depth': [2, 4, 6],
'min_impurity_decrease': [0.01, 0.03, 0.1, 0.3],
'min_samples_split': [5, 10, 20]})
grid.__dict__
{'scoring': None,
'estimator': DecisionTreeClassifier(),
'n_jobs': None,
'refit': True,
'cv': 5,
'verbose': 0,
'pre_dispatch': '2*n_jobs',
'error_score': nan,
'return_train_score': False,
'param_grid': {'max_depth': [2, 4, 6],
'min_samples_split': [5, 10, 20],
'min_impurity_decrease': [0.01, 0.03, 0.1, 0.3]},
'multimetric_': False,
'best_index_': np.int64(13),
'best_score_': np.float64(0.9586303284416491),
'best_params_': {'max_depth': 4,
'min_impurity_decrease': 0.01,
'min_samples_split': 10},
'best_estimator_': DecisionTreeClassifier(max_depth=4, min_impurity_decrease=0.01,
min_samples_split=10),
'refit_time_': 0.0005388259887695312,
'feature_names_in_': array(['island_Biscoe', 'island_Dream', 'island_Torgersen', 'sex_male',
'bill_length_mm', 'bill_depth_mm', 'flipper_length_mm',
'body_mass_g', 'year'], dtype=object),
'scorer_': <class 'sklearn.tree._classes.DecisionTreeClassifier'>.score,
'cv_results_': {'mean_fit_time': array([0.00069599, 0.00058508, 0.00053301, 0.00061517, 0.0006206 ,
0.00061412, 0.00062509, 0.00062008, 0.00054741, 0.00062771,
0.00056887, 0.00058684, 0.00061417, 0.00055718, 0.00056195,
0.00054965, 0.00061855, 0.00054836, 0.00061703, 0.0005826 ,
0.00061593, 0.00059791, 0.00052962, 0.00051947, 0.0005558 ,
0.00054812, 0.00057416, 0.00055761, 0.00067716, 0.0006042 ,
0.00059004, 0.00061092, 0.00058112, 0.00057569, 0.00052319,
0.00051708]),
'std_fit_time': array([1.44370464e-04, 1.05063036e-04, 5.86063712e-06, 1.36101254e-04,
1.50566345e-04, 3.88033544e-05, 1.30833064e-04, 6.95205683e-05,
1.66087249e-05, 1.19374104e-04, 1.07258899e-04, 5.42444133e-05,
1.15201627e-04, 6.02590873e-06, 1.36678532e-05, 8.33524446e-06,
1.01982816e-04, 1.55977509e-06, 1.03448610e-04, 5.04408525e-05,
7.90650938e-05, 6.83714319e-05, 1.69345009e-05, 8.20437096e-06,
5.47304376e-06, 7.22844870e-06, 2.15946698e-05, 6.11654024e-06,
1.05355376e-04, 1.09496743e-04, 3.64312327e-05, 5.87742371e-05,
6.98128684e-05, 8.92340068e-05, 5.94155172e-06, 8.10679265e-06]),
'mean_score_time': array([0.00044875, 0.00036669, 0.00036221, 0.00039735, 0.00038676,
0.00040507, 0.00038714, 0.00048995, 0.00043421, 0.00037503,
0.00039773, 0.00038528, 0.00036783, 0.00035729, 0.00036221,
0.00035491, 0.00037999, 0.00035672, 0.00042777, 0.0003655 ,
0.00041585, 0.00042176, 0.00036054, 0.0003551 , 0.0003562 ,
0.00035596, 0.00037742, 0.00035753, 0.00040898, 0.00038366,
0.00037489, 0.00037956, 0.00037117, 0.00037074, 0.00035777,
0.00035443]),
'std_score_time': array([8.12994630e-05, 1.17431739e-05, 4.89867510e-06, 6.25318075e-05,
4.10759717e-05, 4.27361562e-05, 2.29731611e-05, 1.17726450e-04,
1.40134677e-04, 2.75193296e-05, 6.04776160e-05, 2.50935979e-05,
2.28893757e-05, 3.49297458e-06, 4.46651612e-06, 1.00701867e-06,
4.80760645e-05, 8.03580262e-07, 8.46589964e-05, 1.37267828e-05,
5.49895591e-05, 6.91524036e-05, 1.28192063e-05, 5.13259331e-06,
5.63394756e-06, 3.70585618e-06, 3.78165238e-05, 4.74590732e-06,
4.98771667e-05, 5.83211807e-05, 1.36125108e-05, 2.26720901e-05,
2.30057993e-05, 2.90091832e-05, 6.52169608e-06, 3.84615800e-06]),
'param_max_depth': masked_array(data=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False],
fill_value=999999),
'param_min_impurity_decrease': masked_array(data=[0.01, 0.01, 0.01, 0.03, 0.03, 0.03, 0.1, 0.1, 0.1, 0.3,
0.3, 0.3, 0.01, 0.01, 0.01, 0.03, 0.03, 0.03, 0.1, 0.1,
0.1, 0.3, 0.3, 0.3, 0.01, 0.01, 0.01, 0.03, 0.03, 0.03,
0.1, 0.1, 0.1, 0.3, 0.3, 0.3],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False],
fill_value=1e+20),
'param_min_samples_split': masked_array(data=[5, 10, 20, 5, 10, 20, 5, 10, 20, 5, 10, 20, 5, 10, 20,
5, 10, 20, 5, 10, 20, 5, 10, 20, 5, 10, 20, 5, 10, 20,
5, 10, 20, 5, 10, 20],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False],
fill_value=999999),
'params': [{'max_depth': 2,
'min_impurity_decrease': 0.01,
'min_samples_split': 5},
{'max_depth': 2, 'min_impurity_decrease': 0.01, 'min_samples_split': 10},
{'max_depth': 2, 'min_impurity_decrease': 0.01, 'min_samples_split': 20},
{'max_depth': 2, 'min_impurity_decrease': 0.03, 'min_samples_split': 5},
{'max_depth': 2, 'min_impurity_decrease': 0.03, 'min_samples_split': 10},
{'max_depth': 2, 'min_impurity_decrease': 0.03, 'min_samples_split': 20},
{'max_depth': 2, 'min_impurity_decrease': 0.1, 'min_samples_split': 5},
{'max_depth': 2, 'min_impurity_decrease': 0.1, 'min_samples_split': 10},
{'max_depth': 2, 'min_impurity_decrease': 0.1, 'min_samples_split': 20},
{'max_depth': 2, 'min_impurity_decrease': 0.3, 'min_samples_split': 5},
{'max_depth': 2, 'min_impurity_decrease': 0.3, 'min_samples_split': 10},
{'max_depth': 2, 'min_impurity_decrease': 0.3, 'min_samples_split': 20},
{'max_depth': 4, 'min_impurity_decrease': 0.01, 'min_samples_split': 5},
{'max_depth': 4, 'min_impurity_decrease': 0.01, 'min_samples_split': 10},
{'max_depth': 4, 'min_impurity_decrease': 0.01, 'min_samples_split': 20},
{'max_depth': 4, 'min_impurity_decrease': 0.03, 'min_samples_split': 5},
{'max_depth': 4, 'min_impurity_decrease': 0.03, 'min_samples_split': 10},
{'max_depth': 4, 'min_impurity_decrease': 0.03, 'min_samples_split': 20},
{'max_depth': 4, 'min_impurity_decrease': 0.1, 'min_samples_split': 5},
{'max_depth': 4, 'min_impurity_decrease': 0.1, 'min_samples_split': 10},
{'max_depth': 4, 'min_impurity_decrease': 0.1, 'min_samples_split': 20},
{'max_depth': 4, 'min_impurity_decrease': 0.3, 'min_samples_split': 5},
{'max_depth': 4, 'min_impurity_decrease': 0.3, 'min_samples_split': 10},
{'max_depth': 4, 'min_impurity_decrease': 0.3, 'min_samples_split': 20},
{'max_depth': 6, 'min_impurity_decrease': 0.01, 'min_samples_split': 5},
{'max_depth': 6, 'min_impurity_decrease': 0.01, 'min_samples_split': 10},
{'max_depth': 6, 'min_impurity_decrease': 0.01, 'min_samples_split': 20},
{'max_depth': 6, 'min_impurity_decrease': 0.03, 'min_samples_split': 5},
{'max_depth': 6, 'min_impurity_decrease': 0.03, 'min_samples_split': 10},
{'max_depth': 6, 'min_impurity_decrease': 0.03, 'min_samples_split': 20},
{'max_depth': 6, 'min_impurity_decrease': 0.1, 'min_samples_split': 5},
{'max_depth': 6, 'min_impurity_decrease': 0.1, 'min_samples_split': 10},
{'max_depth': 6, 'min_impurity_decrease': 0.1, 'min_samples_split': 20},
{'max_depth': 6, 'min_impurity_decrease': 0.3, 'min_samples_split': 5},
{'max_depth': 6, 'min_impurity_decrease': 0.3, 'min_samples_split': 10},
{'max_depth': 6, 'min_impurity_decrease': 0.3, 'min_samples_split': 20}],
'split0_test_score': array([0.96296296, 0.96296296, 0.96296296, 0.96296296, 0.96296296,
0.96296296, 0.94444444, 0.94444444, 0.94444444, 0.81481481,
0.81481481, 0.81481481, 0.96296296, 0.96296296, 0.96296296,
0.96296296, 0.96296296, 0.96296296, 0.94444444, 0.94444444,
0.94444444, 0.81481481, 0.81481481, 0.81481481, 0.96296296,
0.96296296, 0.96296296, 0.96296296, 0.96296296, 0.96296296,
0.94444444, 0.94444444, 0.94444444, 0.81481481, 0.81481481,
0.81481481]),
'split1_test_score': array([0.96226415, 0.96226415, 0.96226415, 0.96226415, 0.96226415,
0.96226415, 0.94339623, 0.94339623, 0.94339623, 0.81132075,
0.81132075, 0.81132075, 0.98113208, 0.98113208, 0.98113208,
0.96226415, 0.96226415, 0.96226415, 0.94339623, 0.94339623,
0.94339623, 0.81132075, 0.81132075, 0.81132075, 0.98113208,
0.98113208, 0.98113208, 0.96226415, 0.96226415, 0.96226415,
0.94339623, 0.94339623, 0.94339623, 0.81132075, 0.81132075,
0.81132075]),
'split2_test_score': array([0.94339623, 0.94339623, 0.94339623, 0.94339623, 0.94339623,
0.94339623, 0.94339623, 0.94339623, 0.94339623, 0.77358491,
0.77358491, 0.77358491, 0.94339623, 0.94339623, 0.94339623,
0.94339623, 0.94339623, 0.94339623, 0.94339623, 0.94339623,
0.94339623, 0.77358491, 0.77358491, 0.77358491, 0.94339623,
0.94339623, 0.94339623, 0.94339623, 0.94339623, 0.94339623,
0.94339623, 0.94339623, 0.94339623, 0.77358491, 0.77358491,
0.77358491]),
'split3_test_score': array([0.90566038, 0.90566038, 0.90566038, 0.90566038, 0.90566038,
0.90566038, 0.88679245, 0.88679245, 0.88679245, 0.79245283,
0.79245283, 0.79245283, 0.90566038, 0.90566038, 0.90566038,
0.90566038, 0.90566038, 0.90566038, 0.88679245, 0.88679245,
0.88679245, 0.79245283, 0.79245283, 0.79245283, 0.90566038,
0.90566038, 0.90566038, 0.90566038, 0.90566038, 0.90566038,
0.88679245, 0.88679245, 0.88679245, 0.79245283, 0.79245283,
0.79245283]),
'split4_test_score': array([1. , 1. , 1. , 1. , 1. ,
1. , 0.98113208, 0.98113208, 0.98113208, 0.81132075,
0.81132075, 0.81132075, 0.98113208, 1. , 1. ,
1. , 1. , 1. , 0.98113208, 0.98113208,
0.98113208, 0.81132075, 0.81132075, 0.81132075, 0.98113208,
1. , 1. , 1. , 1. , 1. ,
0.98113208, 0.98113208, 0.98113208, 0.81132075, 0.81132075,
0.81132075]),
'mean_test_score': array([0.95485674, 0.95485674, 0.95485674, 0.95485674, 0.95485674,
0.95485674, 0.93983229, 0.93983229, 0.93983229, 0.80069881,
0.80069881, 0.80069881, 0.95485674, 0.95863033, 0.95863033,
0.95485674, 0.95485674, 0.95485674, 0.93983229, 0.93983229,
0.93983229, 0.80069881, 0.80069881, 0.80069881, 0.95485674,
0.95863033, 0.95863033, 0.95485674, 0.95485674, 0.95485674,
0.93983229, 0.93983229, 0.93983229, 0.80069881, 0.80069881,
0.80069881]),
'std_test_score': array([0.03069241, 0.03069241, 0.03069241, 0.03069241, 0.03069241,
0.03069241, 0.03021778, 0.03021778, 0.03021778, 0.0156721 ,
0.0156721 , 0.0156721 , 0.02827763, 0.03247905, 0.03247905,
0.03069241, 0.03069241, 0.03069241, 0.03021778, 0.03021778,
0.03021778, 0.0156721 , 0.0156721 , 0.0156721 , 0.02827763,
0.03247905, 0.03247905, 0.03069241, 0.03069241, 0.03069241,
0.03021778, 0.03021778, 0.03021778, 0.0156721 , 0.0156721 ,
0.0156721 ]),
'rank_test_score': array([ 5, 5, 5, 5, 5, 5, 19, 19, 19, 28, 28, 28, 5, 1, 1, 5, 5,
5, 19, 19, 19, 28, 28, 28, 5, 1, 1, 5, 5, 5, 19, 19, 19, 28,
28, 28], dtype=int32)},
'n_splits_': 5}
The grid object stores the info for all fit models, and we may choose to compare them at some point, but for now, we just want to get the best of the models it tested.
We can extract the best model using .best_estimator_.
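If you do want to compare all of the candidates, one convenient approach is to load `cv_results_` into a DataFrame. The sketch below is self-contained and uses stand-in choices of our own (the iris data, a small two-parameter grid, and the names `demo_grid`, `results`, `summary`) rather than the penguin grid fit above.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Stand-in data and a small grid (assumptions for this demo)
X_demo, y_demo = load_iris(return_X_y=True)
demo_grid = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    {'max_depth': [1, 2, 3], 'min_samples_split': [2, 10]},
    cv=5,
)
demo_grid.fit(X_demo, y_demo)

# cv_results_ is a dict of equal-length arrays, so it converts directly to a DataFrame
results = pd.DataFrame(demo_grid.cv_results_)
cols = ['param_max_depth', 'param_min_samples_split',
        'mean_test_score', 'std_test_score', 'rank_test_score']
summary = results[cols].sort_values('rank_test_score')
print(summary)

# The winning combination and its mean cross-validated accuracy
print(demo_grid.best_params_, round(demo_grid.best_score_, 3))
```

A table like this makes ties easy to spot. In the penguin output above, `rank_test_score` is 1 at four positions (13, 14, 25 and 26), and `best_index_` reports the first of them.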
Let’s evaluate and visualize our model.
tree2 = grid.best_estimator_
y_pred = tree2.predict(X_test)
cfm = confusion_matrix(y_test, y_pred)
ConfusionMatrixDisplay(cfm, display_labels = label_names).plot()
plt.show()
Interpret the confusion matrix above (for the test data). Which species get mistaken for other species?
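As a reminder of how to read the matrix, here is a toy example with made-up labels (these are not results from the model above). In scikit-learn's `confusion_matrix`, rows correspond to the true class and columns to the predicted class, in the order given by `labels`.

```python
from sklearn.metrics import confusion_matrix

labels = ['Adelie', 'Chinstrap', 'Gentoo']

# Made-up labels for illustration only
y_true_demo = ['Adelie', 'Adelie', 'Adelie', 'Chinstrap', 'Chinstrap', 'Gentoo', 'Gentoo']
y_pred_demo = ['Adelie', 'Adelie', 'Chinstrap', 'Chinstrap', 'Adelie', 'Gentoo', 'Gentoo']

cm = confusion_matrix(y_true_demo, y_pred_demo, labels=labels)
print(cm)
# [[2 1 0]
#  [1 1 0]
#  [0 0 2]]

# Per-class recall: correct predictions (diagonal) divided by the row totals
recall = cm.diagonal() / cm.sum(axis=1)
for name, r in zip(labels, recall):
    print(f'{name:<10} recall = {r:.2f}')
```

Here one true Adelie was predicted as Chinstrap (row 1, column 2) and one true Chinstrap was predicted as Adelie (row 2, column 1), while both Gentoo samples were classified correctly.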
Next, let’s investigate which features were most important. Generally, we think the features used higher up on the tree are the most informative features, since decision trees are trained by a greedy algorithm. But more specifically, the most informative features are the ones that are responsible for the greatest decrease in impurity.
We can get the importance of each feature using .feature_importances_.
Below we plot the decision tree and then print a list of features ordered by importance.
Compare that list to the structure of the tree. Does it make sense? What does it mean if a feature has 0 importance?
fig, ax = plt.subplots(1,1, figsize = (15, 6))
plot_tree(tree2,
filled = True, fontsize = 14,
feature_names = feature_names, class_names = label_names,
)
plt.show()
# Get feature importances and sort in descending order
idx = np.argsort(tree2.feature_importances_)[::-1]
# Print each feature name alongside its importance
importances = tree2.feature_importances_[idx]
features = feature_names[idx]
for feature, imp in zip(features, importances):
    print(f'{feature:_<30}{imp:.3f}')
flipper_length_mm_____________0.575
bill_length_mm________________0.332
island_Biscoe_________________0.056
island_Dream__________________0.036
year__________________________0.000
body_mass_g___________________0.000
bill_depth_mm_________________0.000
sex_male______________________0.000
island_Torgersen______________0.000
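To connect the importances back to the tree structure, the sketch below checks two things on a small stand-in tree (iris data, `max_depth=2`, both assumptions for this demo): that the importances sum to 1, and that the features with nonzero importance are exactly the ones used in a split. It reads the split features from the fitted tree's `tree_.feature` array, where leaf nodes are marked with a negative value.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Stand-in data and a shallow tree (assumptions for this demo)
iris = load_iris()
demo_tree = DecisionTreeClassifier(max_depth=2, random_state=0)
demo_tree.fit(iris.data, iris.target)

demo_importances = demo_tree.feature_importances_

# Feature index used at each internal node; leaves carry a negative placeholder
node_features = demo_tree.tree_.feature
used = set(node_features[node_features >= 0].tolist())

for i, name in enumerate(iris.feature_names):
    print(f'{name:_<25}{demo_importances[i]:.3f}  used in a split: {i in used}')

print(f'sum of importances: {demo_importances.sum():.3f}')
```

An importance of 0 generally means the feature was never chosen for a split in this particular tree. That is not the same as the feature being uninformative: it may simply be redundant with a feature that was chosen.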
10.3. Decision/Classification Tree Recap#
A decision tree classifies samples using a hierarchy of cascading conditional statements (yes-no questions about the sample).
In fitting a decision tree, at each node, a conditional statement is chosen so as to maximize the purity of the resultant split sub-samples.
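Fitting a decision tree is, in principle, deterministic. That is, given the same training data and hyper-parameters, you will arrive at the same tree. (In scikit-learn, ties between equally good splits are broken at random, so fix random_state if you need exactly the same tree on every run.)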
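To make "maximize the purity of the split" concrete, here is a small worked example using Gini impurity, the default criterion for scikit-learn's `DecisionTreeClassifier`. The helper functions `gini` and `weighted_child_impurity` are written for this illustration and are not part of any library; the toy labels are made up.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def weighted_child_impurity(left, right):
    """Impurity after a split, weighting each child by its share of the samples."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

# Toy node with 4 samples of class 'A' and 4 of class 'B'
parent = np.array(['A'] * 4 + ['B'] * 4)

# Candidate split 1 separates the classes perfectly
perfect = weighted_child_impurity(parent[:4], parent[4:])

# Candidate split 2 leaves both children mixed (3 A + 1 B, and 1 A + 3 B)
mixed = weighted_child_impurity(
    np.array(['A', 'A', 'A', 'B']), np.array(['A', 'B', 'B', 'B'])
)

print(gini(parent))  # 0.5
print(perfect)       # 0.0
print(mixed)         # 0.375
```

The greedy fitting procedure would prefer the first candidate, since it reduces the impurity from 0.5 to 0.0, while the second only reduces it to 0.375.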
There are challenges that come with real data that often cause problems for a decision tree:
High-dimensional data (many features)
Noisy data
Imbalanced class representation
We can improve on an individual decision tree by fitting numerous decision trees—an ensemble—and averaging/polling their results. There are two main categories of ensemble methods for decision trees, Random Forests and Boosted Trees, each with their own benefits and weaknesses.