Handwritten Digit Recognition Using scikit-learn

[CG]Maxime

In this article, I'll show you how to use scikit-learn to do machine learning classification on the MNIST database of handwritten digits. We'll use and discuss the following methods:

  • K-Nearest Neighbors
  • Random Forest
  • Linear SVC

The MNIST dataset is a well-known dataset of 70,000 handwritten digits (60,000 training images and 10,000 test images), each stored as a 28x28 grayscale image. For each image, we know the corresponding digit (from 0 to 9). It is available here: http://yann.lecun.com/exdb/mnist/index.html

All the examples are runnable in the browser directly. However, if you want to run them on your own computer, you'll need to install some dependencies: pip3 install Pillow scikit-learn python-mnist.

Loading the Dataset

To load the dataset, we use the python-mnist package.

from mnist import MNIST
from PIL import Image
# Load the dataset (expects the MNIST files in ./data/)
mndata = MNIST('./data/')
images, labels = mndata.load_training()
# Pick the fifth image from the dataset (it's a 9)
i = 4
image, label = images[i], labels[i]
# Save the image as a 28x28 grayscale PNG
output = Image.new("L", (28, 28))
output.putdata(image)
output.save("output.png")
# Print label
print(label)
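Each image loaded above is a flat list of 784 pixel values (28 x 28, stored row by row), which is why putdata can consume it directly. A minimal sketch of how such a flat list maps back to rows, using a made-up 4x4 image rather than a real digit (the helper name to_rows is hypothetical):

```python
def to_rows(flat_pixels, width):
    """Split a flat, row-major pixel list into a list of rows."""
    return [flat_pixels[i:i + width] for i in range(0, len(flat_pixels), width)]

# Hypothetical 4x4 "image": 0 is background, 255 is full intensity.
flat = [0, 255, 255, 0,
        0, 0, 255, 0,
        0, 0, 255, 0,
        0, 255, 255, 255]

rows = to_rows(flat, 4)
for row in rows:
    # Draw bright pixels as '#' and dark ones as '.'
    print("".join("#" if value > 127 else "." for value in row))
```

Calling to_rows(image, 28) on an MNIST image would give 28 rows of 28 values each.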

K-Nearest Neighbors

The k-nearest neighbors algorithm is one of the simplest classification algorithms. Let's represent the training data as a set of points in the feature space (e.g. objects with 2 features are points in the 2D plane; an MNIST image, with its 784 pixels, is a point in a 784-dimensional space). To classify an object, we find the K points of the training set that are closest to it, and assign the class chosen by a majority vote among these K nearest neighbors.

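The training phase of the K-Nearest Neighbors algorithm is fast because it essentially just stores the training set. The classification phase, however, may be slow: finding the K nearest neighbors of each new object requires comparing it against the training points.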
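To make the majority vote concrete, here is a from-scratch sketch on a handful of 2D points. It is an illustration only, not how scikit-learn implements KNeighborsClassifier; the function name knn_predict and the toy data are made up for this example.

```python
from collections import Counter
import math

def knn_predict(train_points, train_labels, query, k=3):
    """Classify `query` by majority vote among its k nearest training points."""
    # "Training" is just keeping the data around. The work happens here,
    # at query time: one distance per training point.
    distances = [
        (math.dist(point, query), label)
        for point, label in zip(train_points, train_labels)
    ]
    distances.sort(key=lambda pair: pair[0])
    nearest_labels = [label for _, label in distances[:k]]
    return Counter(nearest_labels).most_common(1)[0][0]

# Hypothetical 2D toy data: two groups of points.
points = [(1.0, 1.0), (1.5, 2.0), (2.0, 1.0), (6.0, 6.0), (7.0, 5.5), (6.5, 7.0)]
labels = ["a", "a", "a", "b", "b", "b"]

print(knn_predict(points, labels, (1.2, 1.4)))  # a
print(knn_predict(points, labels, (6.4, 6.1)))  # b
```

This brute-force version computes a distance to every training point for each query, which is where the classification cost comes from. Note that KNeighborsClassifier uses 5 neighbors by default.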

from mnist import MNIST
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
print("Loading dataset...")
mndata = MNIST("./data/")
images, labels = mndata.load_training()
clf = KNeighborsClassifier()
# Train on the first 10000 images:
train_x = images[:10000]
train_y = labels[:10000]
print("Train model")
clf.fit(train_x, train_y)
# Test on the next 100 images:
test_x = images[10000:10100]
expected = labels[10000:10100].tolist()
print("Compute predictions")
predicted = clf.predict(test_x)
print("Accuracy: ", accuracy_score(expected, predicted))
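The accuracy printed above is the fraction of test images whose predicted digit matches the expected one. A plain-Python sketch of the same computation, on made-up labels (the function name accuracy is hypothetical):

```python
def accuracy(expected, predicted):
    """Fraction of positions where the prediction equals the true label."""
    matches = sum(1 for e, p in zip(expected, predicted) if e == p)
    return matches / len(expected)

# Hypothetical labels: 4 of the 5 predictions are correct.
expected = [9, 2, 1, 3, 1]
predicted = [9, 2, 7, 3, 1]
print(accuracy(expected, predicted))  # 0.8
```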

Random Forest

A random forest is an ensemble learning method that can be used for classification. It builds a multitude of decision trees and selects the class predicted by the largest number of trees.

A decision tree contains at each vertex a "question" and each descending edge is an "answer" to that question. The leaves of the tree are the possible outcomes. A decision tree can be built automatically from a training set.
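As an illustration of the question/answer structure, here is a tiny hand-written tree stored as nested dictionaries. The features pixel_count and has_loop, the thresholds, and the outcomes are all invented for this sketch; a real tree would be learned from the training set.

```python
# A hand-written (not learned) tree over hypothetical features.
# Internal vertices ask "is feature <= threshold?"; leaves hold the outcome.
tree = {
    "feature": "pixel_count",
    "threshold": 100,
    "yes": {"leaf": "1"},
    "no": {
        "feature": "has_loop",
        "threshold": 0,
        "yes": {"leaf": "7"},
        "no": {"leaf": "0"},
    },
}

def classify(node, sample):
    """Walk from the root to a leaf, answering one question per vertex."""
    while "leaf" not in node:
        answer = "yes" if sample[node["feature"]] <= node["threshold"] else "no"
        node = node[answer]
    return node["leaf"]

print(classify(tree, {"pixel_count": 80, "has_loop": 0}))   # 1
print(classify(tree, {"pixel_count": 150, "has_loop": 1}))  # 0
```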

Each tree of the forest is built from a random sample of the original training set, and considers only a random subset of the features when choosing each split (typically the square root of the number of features). The number of trees is a parameter (n_estimators in scikit-learn, set to 100 below) that can be tuned, for example by cross-validation.
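The two sources of randomness and the final vote can be sketched with the standard library alone. This is a simplified illustration under assumed sizes (10000 training images, 784 pixel features), not scikit-learn's implementation; the tree predictions at the end are made up.

```python
import math
import random
from collections import Counter

rng = random.Random(0)  # fixed seed so the sketch is repeatable

n_samples = 10000       # training images used in this article
n_features = 28 * 28    # one feature per pixel

# Bootstrap sample: draw n_samples row indices with replacement,
# so some images appear several times and others not at all.
bootstrap_rows = [rng.randrange(n_samples) for _ in range(n_samples)]

# Random feature subset of size sqrt(n_features), drawn without replacement.
subset_size = math.isqrt(n_features)
feature_subset = rng.sample(range(n_features), subset_size)
print(subset_size)  # 28

# The forest's answer is the class predicted by most trees.
tree_votes = [3, 8, 3, 3, 5]  # hypothetical predictions from five trees
forest_prediction = Counter(tree_votes).most_common(1)[0][0]
print(forest_prediction)  # 3
```

In scikit-learn's RandomForestClassifier the feature subset is redrawn at every split, and the trees are combined by averaging their predicted class probabilities rather than by counting hard votes; the sketch keeps the simpler picture described above.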

from mnist import MNIST
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
print("Loading dataset...")
mndata = MNIST("./data/")
images, labels = mndata.load_training()
clf = RandomForestClassifier(n_estimators=100)
# Train on the first 10000 images:
train_x = images[:10000]
train_y = labels[:10000]
print("Train model")
clf.fit(train_x, train_y)
# Test on the next 1000 images:
test_x = images[10000:11000]
expected = labels[10000:11000].tolist()
print("Compute predictions")
predicted = clf.predict(test_x)
print("Accuracy: ", accuracy_score(expected, predicted))

Linear Support Vector Classification

In a Support Vector Machine (SVM) model, the dataset is represented as points in space. The space is divided into regions by hyperplanes. Each hyperplane tries to maximize the margin between two classes (i.e. its distance to the closest points of each class is maximized).
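The notion of margin can be checked numerically. The sketch below measures the distance from a few 2D points to a candidate hyperplane (here a line) w.x + b = 0. The points, w, and b are invented for the illustration, and the code only evaluates a given hyperplane; it does not search for the best one as an SVM solver would.

```python
import math

def signed_distance(w, b, x):
    """Signed distance from point x to the hyperplane w.x + b = 0."""
    dot = sum(wi * xi for wi, xi in zip(w, x))
    return (dot + b) / math.sqrt(sum(wi * wi for wi in w))

# Hypothetical 2D data: two classes and a candidate line x0 + x1 - 5 = 0.
w, b = (1.0, 1.0), -5.0
class_a = [(1.0, 1.0), (2.0, 1.5)]
class_b = [(4.0, 4.0), (3.0, 3.5)]

# The side of the hyperplane (the sign) gives the predicted class...
sides_a = [signed_distance(w, b, p) < 0 for p in class_a]
sides_b = [signed_distance(w, b, p) > 0 for p in class_b]

# ...and the margin is the distance to the closest point of either class.
margin = min(abs(signed_distance(w, b, p)) for p in class_a + class_b)
print(all(sides_a), all(sides_b), round(margin, 3))  # True True 1.061
```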

Scikit-learn provides multiple Support Vector Machine classifier implementations. SVC supports multiple kernel functions (used to separate classes that are not linearly separable), but its training time grows at least quadratically with the number of samples. Multiclass classification is done with a one-vs-one scheme. On the other hand, LinearSVC only supports a linear kernel, but its training time scales roughly linearly with the number of samples. Its multiclass classification is done with a one-vs-rest scheme.
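The two multiclass schemes differ in how many binary problems they create. A small sketch that only enumerates those problems for the 10 digit classes (no classifier is trained; the variable names are made up for this example):

```python
from itertools import combinations

classes = list(range(10))  # digits 0 to 9

# One-vs-one: one binary classifier per unordered pair of classes.
ovo_pairs = list(combinations(classes, 2))

# One-vs-rest: one binary classifier per class (that class against all others).
ovr_tasks = [(c, [other for other in classes if other != c]) for c in classes]

print(len(ovo_pairs))  # 45
print(len(ovr_tasks))  # 10
```

With 10 digits, one-vs-one needs 45 binary classifiers, each trained on only two classes, while one-vs-rest needs 10, each trained on the whole training set.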

We'll use LinearSVC here because it is fast enough. Feel free to experiment with the standard SVC on a smaller training set.

from mnist import MNIST
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
print("Loading dataset...")
mndata = MNIST("./data/")
images, labels = mndata.load_training()
clf = LinearSVC()
# Train on the first 10000 images:
train_x = images[:10000]
train_y = labels[:10000]
print("Train model")
clf.fit(train_x, train_y)
# Test on the next 1000 images:
test_x = images[10000:11000]
expected = labels[10000:11000].tolist()
print("Compute predictions")
predicted = clf.predict(test_x)
print("Accuracy: ", accuracy_score(expected, predicted))

Conclusion

In this article, we saw 3 different approaches to supervised machine learning classification. It's not possible to say which one is the best for classifying the MNIST dataset, because that depends on many criteria (accuracy, training time, prediction time) and each model can be fine-tuned to improve its performance (which I didn't do here). The K-nearest neighbors algorithm is fast to train but slow to classify new data. The Random Forest, on the other hand, is faster at classifying data. The results obtained with LinearSVC were not as good, but they could probably be improved with better parameters.
