Create training and test sets
In this tutorial I show how to separate a dataset into training and test sets using scikit-learn in Python and the Iris dataset.
Get Started
The first thing we do is import the modules that we will use.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
Load the Iris Dataset
The Iris dataset is one of the toy datasets included with sklearn.datasets. We load it with the function load_iris().
iris = load_iris()
To learn more about the Iris dataset, we view its description. The full description is quite long, so we only display the first 500 characters. (You can use the command print(iris['DESCR']) to view the full description.)
print(iris['DESCR'][:500])
.. _iris_dataset:
Iris plants dataset
--------------------
**Data Set Characteristics:**
:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, predictive attributes and the class
:Attribute Information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
- class:
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica
The description tells us that the dataset contains 3 classes and each class has 50 samples for a total of 150 samples. The dataset has four variables or features. We can see how the data is stored by viewing the dataset’s keys.
iris.keys()
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename'])
The feature values are stored under data. The names of the classes are stored under target_names. The class of each sample is stored under target, where 0 represents class setosa, 1 represents class versicolor, and 2 represents class virginica.
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
iris.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
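To see how these keys fit together, the short sketch below checks the shape of data and uses the integer labels in target to index target_names. It relies only on NumPy fancy indexing; the sample indices 0, 50 and 100 are an illustrative choice (each is the first sample of its class, given the ordering shown above).

```python
from sklearn.datasets import load_iris

iris = load_iris()

# data is a 2-D array: one row per sample, one column per feature
print(iris.data.shape)  # (150, 4)
print(iris.feature_names)

# Each entry of target is an index into target_names, so fancy indexing
# converts integer labels into class names
first_of_each_class = [0, 50, 100]
print(iris.target[first_of_each_class])  # [0 1 2]
print(iris.target_names[iris.target[first_of_each_class]])
```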
Create Training and Test Sets
We use the function train_test_split() from sklearn.model_selection to pseudo-randomly assign samples to the training and test sets. More specifically, train_test_split() splits the feature values stored in iris.data into two sets, data_train and data_test, and splits the labels in iris.target the same way: the class labels for the samples in data_train are in target_train, and the class labels for data_test are in target_test.
data_train, data_test, target_train, target_test = train_test_split(iris.data, iris.target)
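train_test_split() applies the same shuffle to every array it receives, so row i of data_train still belongs with label i of target_train. One way to check this, sketched below, is to pass an array of row indices as a third array; the names indices, idx_train and idx_test are illustrative, and random_state=0 only makes the sketch repeatable.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
indices = np.arange(len(iris.target))  # 0, 1, ..., 149

# Split the row indices together with the features and labels
data_train, data_test, target_train, target_test, idx_train, idx_test = train_test_split(
    iris.data, iris.target, indices, random_state=0
)

# Each training row and label matches the original row it came from
assert np.array_equal(data_train, iris.data[idx_train])
assert np.array_equal(target_train, iris.target[idx_train])

# The two index sets are disjoint and together cover every sample
assert set(idx_train).isdisjoint(idx_test)
assert np.array_equal(np.sort(np.concatenate([idx_train, idx_test])), indices)
print(len(idx_train), len(idx_test))  # 112 38
```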
Choose the size of the training and test sets
By default, train_test_split() puts approximately 75% of the samples in the training set and approximately 25% in the test set.
print('Number of samples in the training set:', target_train.shape[0])
print('Number of samples in the test set:', target_test.shape[0])
Number of samples in the training set: 112
Number of samples in the test set: 38
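The 112/38 split comes from how scikit-learn rounds. As far as I can tell from its implementation, a fractional test_size is rounded up and the training set receives the remaining samples. The default test_size of 0.25 gives 0.25 × 150 = 37.5, which rounds up to 38, leaving 112 for training. The sketch below checks this assumed rule against the function itself.

```python
import math
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
n_samples = len(iris.target)

# Assumed rule: n_test = ceil(test_size * n_samples); training gets the rest
n_test = math.ceil(0.25 * n_samples)  # ceil(37.5) = 38
n_train = n_samples - n_test          # 150 - 38 = 112

data_train, data_test, target_train, target_test = train_test_split(iris.data, iris.target)
print(n_train, n_test)  # 112 38
print(len(target_train), len(target_test))
```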
We can specify the size of the training set with the parameter train_size. For example, if we want to use 100 samples for training, we set train_size equal to 100. The remaining 50 samples are placed in the test set.
data_train100, data_test100, target_train100, target_test100 = train_test_split(iris.data, iris.target, train_size=100)
print('Number of samples in the training set:', target_train100.shape[0])
print('Number of samples in the test set:', target_test100.shape[0])
Number of samples in the training set: 100
Number of samples in the test set: 50
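train_test_split() also accepts a test_size parameter, which fixes the size of the test set instead. Setting test_size=50 rather than train_size=100 should give the same 100/50 split, because the training set receives every sample the test set does not take. A minimal check:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# Ask for 50 test samples; the remaining 100 go to the training set
data_train, data_test, target_train, target_test = train_test_split(
    iris.data, iris.target, test_size=50
)
print('Training samples:', target_train.shape[0])  # 100
print('Test samples:', target_test.shape[0])       # 50
```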
Instead of giving the function the number of samples to use for training, we can give it the proportion of samples to use. To do this, we set train_size to a float between 0.0 and 1.0. For example, if we want our training set to contain 50% of the samples, we set train_size equal to 0.5. The other 50% of the samples are placed in the test set.
data_train50, data_test50, target_train50, target_test50 = train_test_split(iris.data, iris.target, train_size=0.5)
print('Number of samples in the training set:', target_train50.shape[0])
print('Number of samples in the test set:', target_test50.shape[0])
Number of samples in the training set: 75
Number of samples in the test set: 75
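When train_size and test_size are both given, they do not have to add up to the whole dataset; samples that fit in neither set are left out. The sketch below assumes the rounding used in scikit-learn's implementation (a fractional test size is rounded up, a fractional train size rounded down), so train_size=0.5 and test_size=0.25 should yield 75 training samples, 38 test samples, and 37 unused samples.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# 0.5 * 150 = 75 training samples; 0.25 * 150 = 37.5, assumed rounded up to 38
data_train, data_test, target_train, target_test = train_test_split(
    iris.data, iris.target, train_size=0.5, test_size=0.25, random_state=0
)
n_unused = len(iris.target) - len(target_train) - len(target_test)
print(len(target_train), len(target_test), n_unused)  # 75 38 37
```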
Stratify by class
We can view the number of samples of each class that were placed in the training and test sets.
# Count the samples with each class label (0, 1, 2) in each set
train_class_counts = [int((target_train == k).sum()) for k in range(3)]
test_class_counts = [int((target_test == k).sum()) for k in range(3)]
print('Number of samples of each class in the training set:', train_class_counts)
print('Number of samples of each class in the test set:', test_class_counts)
Number of samples of each class in the training set: [39, 35, 38]
Number of samples of each class in the test set: [11, 15, 12]
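np.bincount() counts how often each non-negative integer occurs in an array, so it gives the per-class counts in a single call instead of one comparison per class. The split in this sketch uses random_state=0, so its exact counts will generally differ from the ones printed above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
data_train, data_test, target_train, target_test = train_test_split(
    iris.data, iris.target, random_state=0
)

# Entry k is the number of samples with label k; minlength guarantees
# one entry per class even if a class is absent from a split
n_classes = len(iris.target_names)
train_class_counts = np.bincount(target_train, minlength=n_classes)
test_class_counts = np.bincount(target_test, minlength=n_classes)
print('Training set:', dict(zip(iris.target_names.tolist(), train_class_counts.tolist())))
print('Test set:', dict(zip(iris.target_names.tolist(), test_class_counts.tolist())))
```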
We can also plot the number of samples from each class in the training and test sets.
fig, axs = plt.subplots(1, 2, sharey=True)
axs[0].bar(iris.target_names, train_class_counts)
axs[0].set_title('Training Set')
axs[1].bar(iris.target_names, test_class_counts)
axs[1].set_title('Test Set')
for ax in axs.flat:
    ax.set(xlabel='class', ylabel='number of samples')
# Show labels only on the outer axes, which hides the y-axis labels of the right plot.
for ax in axs.flat:
    ax.label_outer()
The Iris dataset has 50 samples from each class, so each class makes up 1/3 of the dataset. With the default settings, train_test_split() does not preserve these proportions, so the classes can be unevenly represented in the training and test sets, as the counts above show. If we want the training and test sets to have the same class proportions as the original dataset, we can pass the class labels to the stratify parameter.
data_train_strat, data_test_strat, target_train_strat, target_test_strat = train_test_split(iris.data, iris.target, stratify=iris.target)
# Count the samples with each class label (0, 1, 2) in each stratified set
train_class_counts = [int((target_train_strat == k).sum()) for k in range(3)]
test_class_counts = [int((target_test_strat == k).sum()) for k in range(3)]
print('Number of samples of each class in the training set:', train_class_counts)
print('Number of samples of each class in the test set:', test_class_counts)
Number of samples of each class in the training set: [38, 37, 37]
Number of samples of each class in the test set: [12, 13, 13]
We can plot the number of samples from each class in our new training and test sets.
fig, axs = plt.subplots(1, 2, sharey=True)
axs[0].bar(iris.target_names, train_class_counts)
axs[0].set_title('Stratified Training Set')
axs[1].bar(iris.target_names, test_class_counts)
axs[1].set_title('Stratified Test Set')
for ax in axs.flat:
    ax.set(xlabel='class', ylabel='number of samples')
# Show labels only on the outer axes, which hides the y-axis labels of the right plot.
for ax in axs.flat:
    ax.label_outer()
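A numeric check of stratification compares each class's share of the training and test sets with its share of the full dataset (1/3 for Iris). With 112 training samples exact thirds are impossible, so in this sketch each class is expected to land within one sample of 112/3 ≈ 37.3 in the training set and of 38/3 ≈ 12.7 in the test set. The random_state=0 is only there to make the sketch repeatable.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
data_train, data_test, target_train, target_test = train_test_split(
    iris.data, iris.target, stratify=iris.target, random_state=0
)

# Fraction of each class in the full dataset and in each split
full_props = np.bincount(iris.target) / len(iris.target)
train_props = np.bincount(target_train) / len(target_train)
test_props = np.bincount(target_test) / len(target_test)

print('Full:    ', np.round(full_props, 3))
print('Training:', np.round(train_props, 3))
print('Test:    ', np.round(test_props, 3))
```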
Set the random state
Often, we want our results to be reproducible. The function train_test_split() uses a pseudo-random number generator to assign samples to the training and test sets. If we want the function to produce exactly the same training and test sets every time, we can set the random_state parameter to a non-negative integer of our choosing.
data_train, data_test, target_train, target_test = train_test_split(iris.data, iris.target, random_state=5)
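The sketch below checks reproducibility directly: two calls with the same random_state should return identical arrays, while a different value (6 here, an arbitrary choice) will almost certainly shuffle the samples differently.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# Two calls with the same random_state produce identical splits
split_a = train_test_split(iris.data, iris.target, random_state=5)
split_b = train_test_split(iris.data, iris.target, random_state=5)
same_seed = all(np.array_equal(a, b) for a, b in zip(split_a, split_b))

# A different random_state is expected to give a different test set
split_c = train_test_split(iris.data, iris.target, random_state=6)
diff_seed_same_test = np.array_equal(split_a[1], split_c[1])

print('Same random_state, same split:', same_seed)  # True
print('Different random_state, same test set:', diff_seed_same_test)
```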
Other parameters
For a full list of parameters that can be used with train_test_split() go to https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html#.
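One of those other parameters, shuffle, matters for Iris because its samples are stored sorted by class. With shuffle=False the function takes the rows in their original order, so the first 112 rows become the training set and the last 38 the test set, which should leave only virginica samples in the test set. According to the documentation, stratify must be None when shuffle=False.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# No shuffling: the split follows the original row order
data_train, data_test, target_train, target_test = train_test_split(
    iris.data, iris.target, shuffle=False
)
print('Training counts:', np.bincount(target_train, minlength=3).tolist())  # [50, 50, 12]
print('Test counts:', np.bincount(target_test, minlength=3).tolist())       # [0, 0, 38]
```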
References
I used the following sources when writing this tutorial and found them to be quite helpful:
- https://scikit-learn.org/stable/datasets/index.html
- https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html#
- https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/categorical_variables.html
- https://matplotlib.org/3.1.0/gallery/subplots_axes_and_figures/subplots_demo.html
- https://moonbooks.org/Articles/How-to-create-a-table-of-contents-in-a-jupyter-notebook-/