scikit-learn Model selection K-Fold Cross Validation


Example

K-fold cross-validation systematically repeats the train/test split procedure in order to reduce the variance associated with a single train/test split. The dataset is split into K folds of (approximately) equal size; each fold is used once for testing the model and K-1 times for training it.
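
As an illustration of this partitioning, here is a minimal pure-Python sketch. The helper name k_fold_indices is hypothetical (it is not part of scikit-learn), and the sketch assumes consecutive, unshuffled folds in which any leftover samples go to the first folds.

```python
def k_fold_indices(n_samples, k):
    """Yield (train, test) index lists for k consecutive, unshuffled folds."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    # When n_samples is not divisible by k, the first `extra` folds get one more sample
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        stop = start + base + (1 if fold < extra else 0)
        test = list(range(start, stop))
        train = list(range(0, start)) + list(range(stop, n_samples))
        yield train, test
        start = stop


splits = list(k_fold_indices(6, 3))
for train, test in splits:
    print("TRAIN:", train, "TEST:", test)
```

Every sample index lands in exactly one test fold and in the training set of the other K-1 splits.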

Multiple splitting techniques are available in scikit-learn. Which one to use depends on the characteristics of the input data. Some examples are:

K-Fold

KFold splits the dataset into K consecutive folds, without shuffling by default. Pass shuffle=True (optionally together with random_state) to shuffle the samples before splitting.

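import numpy as np
from sklearn.model_selection import KFold

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 1, 2])
# random_state only has an effect when shuffle=True (recent releases raise an error otherwise), so it is omitted here
cv = KFold(n_splits=3)

# Each iteration yields the row indices of one train/test split
for train_index, test_index in cv.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)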

TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1 3] TEST: [2]
TRAIN: [0 1 2] TEST: [3]
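
In practice the index arrays are rarely consumed by hand; the splitter object is usually passed to a helper that fits and scores a model on each fold. A minimal sketch, assuming the bundled iris dataset and a decision tree as stand-ins for real data and a real model (both are illustrative choices, not part of the original example):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# The iris samples are ordered by class, so shuffle before splitting;
# random_state makes the shuffle reproducible
cv = KFold(n_splits=5, shuffle=True, random_state=0)
model = DecisionTreeClassifier(random_state=0)

# One score (accuracy, the classifier default) per fold
scores = cross_val_score(model, X, y, cv=cv)
print("Accuracy per fold:", scores)
print("Mean accuracy:", scores.mean())
```

The mean over the folds is a less noisy estimate than the score from any single train/test split.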

StratifiedKFold is a variation of K-fold that returns stratified folds: each fold contains approximately the same percentage of samples of each target class as the complete set.
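
A short sketch of StratifiedKFold on the same toy data as above. Unlike KFold, the split call needs the labels y so that the class proportions can be preserved; n_splits=2 is chosen here only because each class has just two samples.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 1, 2])
cv = StratifiedKFold(n_splits=2)

test_labels = []
for train_index, test_index in cv.split(X, y):
    # Collect the labels of each test fold to check the class balance
    test_labels.append(sorted(y[test_index].tolist()))
    print("TRAIN:", train_index, "TEST:", test_index, "TEST LABELS:", y[test_index])
```

With two samples per class and two folds, each test fold holds one sample of class 1 and one of class 2.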

ShuffleSplit

ShuffleSplit generates a user-defined number of independent train/test splits. For each split, the samples are first shuffled and then divided into a train set and a test set.

import numpy as np
from sklearn.model_selection import ShuffleSplit

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 1, 2])
# test_size=.25 holds out one of the four samples in each split
cv = ShuffleSplit(n_splits=3, test_size=.25, random_state=0)

for train_index, test_index in cv.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)

TRAIN: [3 1 0] TEST: [2]
TRAIN: [2 1 3] TEST: [0]
TRAIN: [0 2 1] TEST: [3]

StratifiedShuffleSplit is a variation of ShuffleSplit that returns stratified splits, i.e. it creates each split so that the percentage of every target class is preserved from the complete set.
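
A short sketch of StratifiedShuffleSplit on the same toy data. The value test_size=.5 is an illustrative choice: with two classes, both the train set and the test set must be large enough to hold at least one sample of each class.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 1, 2])
cv = StratifiedShuffleSplit(n_splits=3, test_size=.5, random_state=0)

test_labels = []
for train_index, test_index in cv.split(X, y):
    # Each shuffled split keeps the 50/50 class ratio of the full data set
    test_labels.append(sorted(y[test_index].tolist()))
    print("TRAIN:", train_index, "TEST:", test_index)
```

Unlike StratifiedKFold, the test sets of different splits may overlap, because every split is drawn independently.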

Other splitting techniques, such as LeaveOneOut, LeavePOut, and TimeSeriesSplit (a variation of K-fold), are available in the scikit-learn model_selection module.
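
A brief sketch of two of these splitters on the same toy data, assuming default parameters apart from n_splits. LeaveOneOut holds out one sample per split, while TimeSeriesSplit only ever trains on samples that come before the test samples.

```python
import numpy as np
from sklearn.model_selection import LeaveOneOut, TimeSeriesSplit

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])

# One split per sample, each with a single-sample test set
loo_splits = list(LeaveOneOut().split(X))
for train_index, test_index in loo_splits:
    print("LOO TRAIN:", train_index, "TEST:", test_index)

# Training indices always precede test indices, so the model never sees the future
ts_splits = list(TimeSeriesSplit(n_splits=3).split(X))
for train_index, test_index in ts_splits:
    print("TS  TRAIN:", train_index, "TEST:", test_index)
```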