2. Cross-validation: some gotchas

Cross-validation is the ubiquitous test of a machine learning model. Yet many things can go wrong.

2.1. The uncertainty of measured accuracy

The first thing to keep in mind is that the results of a cross-validation are a noisy estimate of the real prediction accuracy.

Let us create a simple artificial dataset:

from sklearn import datasets, discriminant_analysis
import numpy as np
np.random.seed(0)
data, target = datasets.make_blobs(centers=[(0, 0), (0, 1)])
classifier = discriminant_analysis.LinearDiscriminantAnalysis()

A single cross-validation run gives spread-out measures:

from sklearn.model_selection import cross_val_score
print(cross_val_score(classifier, data, target))

Out:

[ 0.64705882  0.67647059  0.84375   ]

What if we try different random shuffles of the data?

from sklearn import utils
for _ in range(10):
    data, target = utils.shuffle(data, target)
    print(cross_val_score(classifier, data, target))

Out:

[ 0.76470588  0.70588235  0.65625   ]
[ 0.70588235  0.67647059  0.75      ]
[ 0.73529412  0.64705882  0.71875   ]
[ 0.70588235  0.58823529  0.8125    ]
[ 0.67647059  0.73529412  0.71875   ]
[ 0.70588235  0.64705882  0.75      ]
[ 0.67647059  0.67647059  0.71875   ]
[ 0.70588235  0.61764706  0.8125    ]
[ 0.76470588  0.76470588  0.59375   ]
[ 0.76470588  0.61764706  0.625     ]

This should not be surprising: if the classification rate is p, the observed number of correct classifications on a test set of size n follows a binomial distribution.

from scipy import stats
n = len(data)
distrib = stats.binom(n=n, p=.7)

We can plot it:

from matplotlib import pyplot as plt
plt.figure(figsize=(6, 3))
plt.plot(np.linspace(0, 1, n), distrib.pmf(np.arange(0, n)))
[Figure: probability mass function of the binomial distribution of the observed accuracy]

It is wide, because there are not that many samples to measure the error upon: this simulated dataset has only 100 points.

We can look at the interval in which 95% of the observed accuracies lie, for different sample sizes:

for n in [100, 1000, 10000, 100000]:
    distrib = stats.binom(n, .7)
    # Width of the central 95% interval, as a fraction of the sample size
    interval = (distrib.isf(.025) - distrib.isf(.975)) / n
    print("Size: {0: 7} | interval: {1}%".format(n, 100 * interval))

Out:

Size:     100 | interval: 18.0%
Size:    1000 | interval: 5.7%
Size:   10000 | interval: 1.8%
Size:  100000 | interval: 0.568%

Even at 100 000 samples, 5% of the observed classification accuracies still fall more than 0.28% away from the true rate (half the width of the 0.568% interval).
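
As a sanity check, these exact binomial intervals can be compared with the usual normal approximation, whose 95% width is about 2 * 1.96 * sqrt(p * (1 - p) / n). A minimal sketch, reusing the imports above; it closely matches the exact values:

# Normal approximation to the width of the 95% interval of the
# observed accuracy
for n in [100, 1000, 10000, 100000]:
    print("Size: {0: 7} | approx interval: {1:.3}%".format(
        n, 100 * 2 * 1.96 * np.sqrt(.7 * .3 / n)))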

Keep in mind that cross-validation is a noisy measure.

Importantly, the variance across folds is not a good measure of this error, as the different data folds are not independent. For instance, doing many random splits can reduce the variance across folds arbitrarily, but it does not actually provide new data points, as the sketch below shows.
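
To make this concrete, here is a minimal sketch using sklearn.model_selection.ShuffleSplit on the data above: the standard error of the mean fold score shrinks as the number of splits grows, even though the amount of data stays the same.

from sklearn.model_selection import ShuffleSplit
for n_splits in [10, 100, 1000]:
    cv = ShuffleSplit(n_splits=n_splits, test_size=.25, random_state=0)
    scores = cross_val_score(classifier, data, target, cv=cv)
    # The standard error shrinks as 1/sqrt(n_splits), yet the splits
    # keep reusing the same 100 samples
    print("{0:5} splits | mean: {1:.3f} | standard error: {2:.4f}".format(
        n_splits, scores.mean(), scores.std() / np.sqrt(n_splits)))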

2.2. Confounding effects and non-independence

2.3. Measuring baselines and chance

Because of class imbalances or confounding effects, it is easy to get it wrong in terms of what constitutes chance. There are two approaches to measuring the performance of baselines or chance:

The dummy classifier: sklearn.dummy.DummyClassifier, with different strategies to provide simple baselines.

from sklearn.dummy import DummyClassifier
dummy = DummyClassifier(strategy="stratified")
print(cross_val_score(dummy, data, target))

Out:

[ 0.44117647  0.61764706  0.40625   ]
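
Other strategies exist, for instance "most_frequent", which always predicts the majority class and is a useful baseline when classes are imbalanced. A minimal sketch on the same data:

# Always predicts the most common class seen during fit
dummy_majority = DummyClassifier(strategy="most_frequent")
print(cross_val_score(dummy_majority, data, target))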

Chance level: to measure the actual chance level, the most robust approach is to use permutations, via sklearn.model_selection.permutation_test_score(), which is used like cross_val_score:

from sklearn.model_selection import permutation_test_score
score, permuted_scores, p_value = permutation_test_score(classifier, data, target)
print("Classifier score: {0},\np value: {1}\nPermutation scores {2}"
.format(score, p_value, permuted_scores))

Out:

Classifier score: 0.669117647059,
p value: 0.00990099009901
Permutation scores [ 0.54963235 0.47120098 0.51041667 0.59926471 0.45036765 0.44852941
0.52941176 0.59865196 0.47855392 0.39031863 0.45955882 0.56066176
0.60110294 0.38112745 0.45159314 0.46017157 0.58026961 0.57904412
0.59191176 0.58026961 0.51041667 0.53921569 0.41176471 0.37806373
0.62193627 0.52022059 0.41789216 0.50980392 0.4497549 0.59987745
0.47855392 0.53921569 0.44056373 0.60784314 0.47120098 0.47916667
0.48958333 0.5814951 0.50857843 0.43014706 0.53002451 0.48039216
0.48835784 0.43872549 0.43872549 0.59987745 0.45894608 0.40931373
0.52022059 0.45955882 0.44914216 0.52022059 0.55147059 0.47120098
0.49080882 0.49203431 0.47794118 0.5379902 0.62990196 0.51041667
0.46997549 0.44056373 0.56127451 0.60968137 0.47120098 0.54963235
0.54718137 0.56066176 0.47977941 0.42953431 0.43872549 0.38051471
0.43872549 0.41115196 0.48897059 0.48039216 0.60968137 0.60906863
0.57169118 0.52757353 0.5122549 0.52022059 0.46017157 0.62009804
0.47058824 0.59987745 0.55085784 0.46813725 0.53737745 0.54105392
0.48897059 0.51102941 0.48039216 0.50122549 0.47058824 0.59926471
0.49080882 0.4185049 0.47058824 0.49019608]
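
Note that the p value is 1/101 here: none of the 100 permutation scores (the largest is about 0.63) exceeds the classifier score. A minimal sketch visualizing this, reusing the variables above:

plt.figure(figsize=(6, 3))
# Distribution of scores obtained on permuted labels, compared with
# the score on the true labels
plt.hist(permuted_scores, bins=20, label="permutation scores")
plt.axvline(score, color="red", label="classifier score")
plt.legend(loc="upper left")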
