Commit 9e49083

Add cross-validation diagram to GridSearchCV notebook (#847)
1 parent 73e12fb commit 9e49083

python_scripts/parameter_tuning_grid_search.py

Lines changed: 60 additions & 45 deletions
@@ -116,77 +116,92 @@
 # %% [markdown]
 # ## Tuning using a grid-search
 #
-# In the previous exercise we used one `for` loop for each hyperparameter to
-# find the best combination over a fixed grid of values. `GridSearchCV` is a
-# scikit-learn class that implements a very similar logic with less repetitive
-# code.
+# In the previous exercise (M3.01) we used two nested `for` loops (one for each
+# hyperparameter) to test different combinations over a fixed grid of
+# hyperparameter values. In each iteration of the loop, we used
+# `cross_val_score` to compute the mean score (as averaged across
+# cross-validation splits), and compared those mean scores to select the best
+# combination. `GridSearchCV` is a scikit-learn class that implements a very
+# similar logic with less repetitive code. The suffix `CV` refers to the
+# cross-validation it runs internally (instead of the `cross_val_score` we
+# "hard" coded).
 #
-# Let's see how to use the `GridSearchCV` estimator for doing such search. Since
-# the grid-search is costly, we only explore the combination learning-rate and
-# the maximum number of nodes.
+# The `GridSearchCV` estimator takes a `param_grid` parameter which defines all
+# hyperparameters and their associated values. The grid-search is in charge of
+# creating all possible combinations and testing them.
+#
+# The number of combinations is equal to the product of the number of values to
+# explore for each parameter. Thus, adding new parameters with their associated
+# values to be explored rapidly becomes computationally expensive. Because of
+# that, here we only explore the combination of the learning-rate and the
+# maximum number of nodes, for a total of 4 x 3 = 12 combinations.
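For reference, the "by hand" search that this new text describes might look roughly like the sketch below. It is only an illustration (it is not part of the committed file) and it assumes the same `model` pipeline and the `data_train` / `target_train` variables defined earlier in the notebook:

```python
# Minimal sketch of the nested-loop search that GridSearchCV replaces.
from sklearn.model_selection import cross_val_score

best_mean_score, best_params = -float("inf"), None
for learning_rate in (0.01, 0.1, 1, 10):
    for max_leaf_nodes in (3, 10, 30):
        # Set the candidate hyperparameters on the pipeline step named "classifier".
        model.set_params(
            classifier__learning_rate=learning_rate,
            classifier__max_leaf_nodes=max_leaf_nodes,
        )
        # Mean score across the cross-validation splits for this combination.
        mean_score = cross_val_score(model, data_train, target_train, cv=2).mean()
        if mean_score > best_mean_score:
            best_mean_score = mean_score
            best_params = {
                "learning_rate": learning_rate,
                "max_leaf_nodes": max_leaf_nodes,
            }

print(f"Best combination found by hand: {best_params}")
```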
 
-# %%
 # %%time
 from sklearn.model_selection import GridSearchCV
 
 param_grid = {
-    "classifier__learning_rate": (0.01, 0.1, 1, 10),
-    "classifier__max_leaf_nodes": (3, 10, 30),
-}
+    "classifier__learning_rate": (0.01, 0.1, 1, 10),  # 4 possible values
+    "classifier__max_leaf_nodes": (3, 10, 30),  # 3 possible values
+}  # 12 unique combinations
 model_grid_search = GridSearchCV(model, param_grid=param_grid, n_jobs=2, cv=2)
 model_grid_search.fit(data_train, target_train)
 
 # %% [markdown]
-# Finally, we check the accuracy of our model using the test set.
+# You can access the best combination of hyperparameters found by the grid
+# search using the `best_params_` attribute.
 
 # %%
-accuracy = model_grid_search.score(data_test, target_test)
-print(
-    f"The test accuracy score of the grid-searched pipeline is: {accuracy:.2f}"
-)
-
-# %% [markdown]
-# ```{warning}
-# Be aware that the evaluation should normally be performed through
-# cross-validation by providing `model_grid_search` as a model to the
-# `cross_validate` function.
-#
-# Here, we used a single train-test split to evaluate `model_grid_search`. In
-# a future notebook will go into more detail about nested cross-validation, when
-# you use cross-validation both for hyperparameter tuning and model evaluation.
-# ```
+print(f"The best set of parameters is: {model_grid_search.best_params_}")
 
 # %% [markdown]
-# The `GridSearchCV` estimator takes a `param_grid` parameter which defines all
-# hyperparameters and their associated values. The grid-search is in charge
-# of creating all possible combinations and test them.
-#
-# The number of combinations are equal to the product of the number of values to
-# explore for each parameter (e.g. in our example 4 x 3 combinations). Thus,
-# adding new parameters with their associated values to be explored become
-# rapidly computationally expensive.
-#
-# Once the grid-search is fitted, it can be used as any other predictor by
-# calling `predict` and `predict_proba`. Internally, it uses the model with the
+# Once the grid-search is fitted, it can be used as any other estimator, i.e. it
+# has `predict` and `score` methods. Internally, it uses the model with the
 # best parameters found during `fit`.
 #
-# Get predictions for the 5 first samples using the estimator with the best
-# parameters.
+# Let's get the predictions for the first 5 samples using the estimator with
+# the best parameters:
 
 # %%
 model_grid_search.predict(data_test.iloc[0:5])
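Behind `predict`, the grid-search keeps the refitted pipeline with the best parameters in its `best_estimator_` attribute, so the two calls below should return the same predictions. This is a sketch for illustration only, reusing the notebook's `model_grid_search` and `data_test`:

```python
# Both calls go through the same refitted pipeline.
pred_from_search = model_grid_search.predict(data_test.iloc[0:5])
pred_from_best = model_grid_search.best_estimator_.predict(data_test.iloc[0:5])
assert (pred_from_search == pred_from_best).all()
```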
 
 # %% [markdown]
-# You can know about these parameters by looking at the `best_params_`
-# attribute.
+# Finally, we check the accuracy of our model using the test set.
 
 # %%
-print(f"The best set of parameters is: {model_grid_search.best_params_}")
+accuracy = model_grid_search.score(data_test, target_test)
+print(
+    f"The test accuracy score of the grid-search pipeline is: {accuracy:.2f}"
+)
 
 # %% [markdown]
-# The accuracy and the best parameters of the grid-searched pipeline are similar
+# The accuracy and the best parameters of the grid-search pipeline are similar
 # to the ones we found in the previous exercise, where we searched the best
-# parameters "by hand" through a double for loop.
+# parameters "by hand" through a double `for` loop.
+#
+# ## The need for a validation set
+#
+# In the previous section, the selection of the best hyperparameters was done
+# using the train set, coming from the initial train-test split. Then, we
+# evaluated the generalization performance of our tuned model on the left-out
+# test set. This can be shown schematically as follows:
+#
+# ![Cross-validation tuning
+# diagram](../figures/cross_validation_train_test_diagram.png)
+#
+# ```{note}
+# This figure shows the particular case of a **K-fold** cross-validation
+# strategy using `n_splits=5` to further split the train set coming from a
+# train-test split. For each cross-validation split, the procedure trains a
+# model on all the red samples and evaluates the score of a given set of
+# hyperparameters on the green samples. The best combination of hyperparameters
+# `best_params` is selected based on those intermediate scores.
+#
+# Then a final model is refitted using `best_params` on the concatenation of
+# the red and green samples and evaluated on the blue samples.
+#
+# The green samples are sometimes referred to as the **validation set** to
+# differentiate them from the final test set in blue.
+# ```
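To make the note above concrete, the splits shown in the diagram could be reproduced roughly as follows. This sketch is not part of the commit; it assumes `data` and `target` hold the full dataset loaded at the top of the notebook, and the colors refer to the figure:

```python
# Rough illustration of the splits depicted in the figure above.
from sklearn.model_selection import KFold, train_test_split

# Blue samples: the final test set, kept aside for the last evaluation.
data_train, data_test, target_train, target_test = train_test_split(
    data, target, random_state=42
)

# The train set is further split (here 5 times) to select hyperparameters.
cv = KFold(n_splits=5)
for train_indices, validation_indices in cv.split(data_train):
    # Red samples: data_train.iloc[train_indices] -> fit each candidate model.
    # Green samples: data_train.iloc[validation_indices] -> score the candidates.
    pass
```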
 #
 # In addition, we can inspect all results which are stored in the attribute
 # `cv_results_` of the grid-search. We filter some specific columns from these
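A short sketch of the kind of inspection this last sentence announces (not part of the commit): loading `cv_results_` into a pandas DataFrame and keeping a few columns of interest, using scikit-learn's `param_<name>` naming convention:

```python
import pandas as pd

# One row per tested combination of hyperparameters.
cv_results = pd.DataFrame(model_grid_search.cv_results_)
columns_of_interest = [
    "param_classifier__learning_rate",
    "param_classifier__max_leaf_nodes",
    "mean_test_score",
    "rank_test_score",
]
print(cv_results[columns_of_interest].sort_values("rank_test_score"))
```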
