2. Cross-validation: some gotchas
Cross-validation is the ubiquitous test of a machine learning model. Yet many things can go wrong.
2.1. The uncertainty of measured accuracy
The first thing to have in mind is that the results of a cross-validation are a noisy estimate of the real prediction accuracy.
Let us create a simple artificial dataset:
from sklearn import datasets, discriminant_analysis
import numpy as np

np.random.seed(0)
# Two overlapping Gaussian blobs: a simple binary classification problem
data, target = datasets.make_blobs(centers=[(0, 0), (0, 1)])
classifier = discriminant_analysis.LinearDiscriminantAnalysis()
A single cross-validation run gives spread-out measures:
from sklearn.model_selection import cross_val_score
print(cross_val_score(classifier, data, target))
Out:
[ 0.64705882 0.67647059 0.84375 ]
What if we try different random shuffles of the data?
from sklearn import utils

for _ in range(10):
    data, target = utils.shuffle(data, target)
    print(cross_val_score(classifier, data, target))
Out:
[ 0.76470588 0.70588235 0.65625 ]
[ 0.70588235 0.67647059 0.75 ]
[ 0.73529412 0.64705882 0.71875 ]
[ 0.70588235 0.58823529 0.8125 ]
[ 0.67647059 0.73529412 0.71875 ]
[ 0.70588235 0.64705882 0.75 ]
[ 0.67647059 0.67647059 0.71875 ]
[ 0.70588235 0.61764706 0.8125 ]
[ 0.76470588 0.76470588 0.59375 ]
[ 0.76470588 0.61764706 0.625 ]
This should not be surprising: if the classification rate is p, the observed number of correct classifications on a test set of size n follows a binomial distribution:
from scipy import stats

n = len(data)
# Number of correct classifications out of n trials, for a 70% success rate
distrib = stats.binom(n=n, p=.7)
We can plot it:
from matplotlib import pyplot as plt
plt.figure(figsize=(6, 3))
plt.plot(np.linspace(0, 1, n), distrib.pmf(np.arange(0, n)))

It is wide, because there are not that many samples to measure the error upon: our artificial dataset is small.
We can look at the interval in which 95% of the observed accuracies lie, for different sample sizes:
for n in [100, 1000, 10000, 100000]:
    distrib = stats.binom(n, .7)
    interval = (distrib.isf(.025) - distrib.isf(.975)) / n
    print("Size: {0: 7} | interval: {1}%".format(n, 100 * interval))
Out:
Size:     100 | interval: 18.0%
Size:    1000 | interval: 5.7%
Size:   10000 | interval: 1.8%
Size:  100000 | interval: 0.568%
At 100 000 samples, 5% of the observed accuracies still fall more than .28% away from the true rate.
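The width of this interval shrinks as 1 / sqrt(n): under the binomial model, the standard deviation of the measured accuracy is sqrt(p (1 - p) / n). As a quick sanity check, here is a sketch (not part of the original example) using the normal approximation of the binomial:

for n in [100, 1000, 10000, 100000]:
    std = np.sqrt(.7 * .3 / n)  # binomial std of the measured accuracy
    # A 95% interval is roughly 2 * 1.96 standard deviations wide
    print("Size: {0: 7} | approx interval: {1:.3f}%".format(n, 100 * 2 * 1.96 * std))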
Keep in mind that cross-validation is a noisy measure.
Importantly, the variance across folds is not a good measure of this error, as the different data folds are not independent. For instance, doing many random splits can reduce the variance arbitrarily, but it does not actually provide new data points.
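To illustrate this point, here is a minimal sketch (not part of the original example) using ShuffleSplit: as the number of random splits grows, the standard error of the mean score shrinks, yet the amount of underlying data stays the same, so the real uncertainty does not.

from sklearn.model_selection import ShuffleSplit

for n_splits in [10, 100, 1000]:
    cv = ShuffleSplit(n_splits=n_splits, test_size=.25, random_state=0)
    scores = cross_val_score(classifier, data, target, cv=cv)
    # The standard error of the mean shrinks as 1 / sqrt(n_splits),
    # but the folds keep reusing the same data points
    print("{0:4d} splits | mean: {1:.3f} | std of the mean: {2:.4f}".format(
        n_splits, scores.mean(), scores.std() / np.sqrt(n_splits)))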
2.2. Confounding effects and non-independence
2.3. Measuring baselines and chance
Because of class imbalances, or confounding effects, it is easy to get it wrong in terms of what constitutes chance. There are two approaches to measuring the performance of baselines or chance:
DummyClassifier: the dummy classifier, sklearn.dummy.DummyClassifier, with different strategies to provide simple baselines
from sklearn.dummy import DummyClassifier

# The "stratified" strategy predicts at random, respecting the class
# distribution of the training set
dummy = DummyClassifier(strategy="stratified")
print(cross_val_score(dummy, data, target))
Out:
[ 0.44117647 0.61764706 0.40625 ]
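Other strategies can serve as baselines too; for instance, always predicting the most frequent class (a small sketch, not part of the original example):

# Baseline that always predicts the most frequent class in the training set
dummy_frequent = DummyClassifier(strategy="most_frequent")
print(cross_val_score(dummy_frequent, data, target))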
Chance level: to measure actual chance, the most robust approach is to use permutations: sklearn.model_selection.permutation_test_score(), which is used like cross_val_score
from sklearn.model_selection import permutation_test_score
score, permuted_scores, p_value = permutation_test_score(classifier, data, target)
print("Classifier score: {0},\np value: {1}\nPermutation scores {2}"
.format(score, p_value, permuted_scores))
Out:
Classifier score: 0.669117647059,
p value: 0.00990099009901
Permutation scores [ 0.54963235 0.47120098 0.51041667 0.59926471 0.45036765 0.44852941
0.52941176 0.59865196 0.47855392 0.39031863 0.45955882 0.56066176
0.60110294 0.38112745 0.45159314 0.46017157 0.58026961 0.57904412
0.59191176 0.58026961 0.51041667 0.53921569 0.41176471 0.37806373
0.62193627 0.52022059 0.41789216 0.50980392 0.4497549 0.59987745
0.47855392 0.53921569 0.44056373 0.60784314 0.47120098 0.47916667
0.48958333 0.5814951 0.50857843 0.43014706 0.53002451 0.48039216
0.48835784 0.43872549 0.43872549 0.59987745 0.45894608 0.40931373
0.52022059 0.45955882 0.44914216 0.52022059 0.55147059 0.47120098
0.49080882 0.49203431 0.47794118 0.5379902 0.62990196 0.51041667
0.46997549 0.44056373 0.56127451 0.60968137 0.47120098 0.54963235
0.54718137 0.56066176 0.47977941 0.42953431 0.43872549 0.38051471
0.43872549 0.41115196 0.48897059 0.48039216 0.60968137 0.60906863
0.57169118 0.52757353 0.5122549 0.52022059 0.46017157 0.62009804
0.47058824 0.59987745 0.55085784 0.46813725 0.53737745 0.54105392
0.48897059 0.51102941 0.48039216 0.50122549 0.47058824 0.59926471
0.49080882 0.4185049 0.47058824 0.49019608]
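One way to read this result is to plot the null distribution of permuted scores next to the actual score; here is a sketch (not part of the original example):

# Null distribution of scores under permutation, vs the actual score
plt.figure(figsize=(6, 3))
plt.hist(permuted_scores, bins=20, label="Permutation scores")
plt.axvline(score, color="red", label="Classifier score")
plt.legend(loc="best")
plt.xlabel("Accuracy")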
Total running time of the script: ( 0 minutes 0.602 seconds)