MNIST Dataset#

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import fetch_openml

Load the MNIST dataset#

The MNIST dataset is used here to visualization. MNIST is an old but popular handwritten dataset that is used for digit classification. The original dataset was published by Yann LeCun which contains 70,000 images of handwritten single-digit numbers between 0 and 9. The images are preprocessed such that the digits are centered, scaled to 28x28 (28*28 = 784) resolution, and a single-channel grayscale image. The original dataset flattens the 28x28 2D-image into a 1D-vector.

# download the full mnist dataset
x_mnist, y_mnist = fetch_openml('mnist_784', version=1, as_frame=False, return_X_y=True)
# print the sizes
print(f'{x_mnist.shape = }')
print(f'{y_mnist.shape = }')
x_mnist.shape = (70000, 784)
y_mnist.shape = (70000,)
  • x_mnist contains the flattened 1-D vector containing a single sample of data

  • y_mnist contains the ground-truth label for the associated data

# show the first image
plt.figure(figsize=(16,8))
plt.imshow(x_mnist[0].reshape(28, 28))
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f2da9d44c50>
../../_images/d81799c29524683202a8f83c63d52b7e1db5a44074caa433fa2c2bee8c097285.png
N = 8   # columns
M = 10  # rows

# array to hold results
image_samples = None
image_labels = np.array([])

# generate a random permutation to sample dataset
permuted_array = np.random.permutation(y_mnist.shape[0])

# loop along rows
for m in range(M):
    # loop along columns to generate a single row
    for n in range(N):
        # get a random permutation index
        permute_idx = permuted_array[m*M + n]

        # get image at the permuted index
        image = x_mnist[permute_idx].reshape(28, 28)
        label = y_mnist[permute_idx]

        # generate a single row
        image_samples_row = image if n == 0 else np.hstack((image_samples_row, image))
        # image_labels.append(label)
        image_labels = np.append(image_labels, label)

    # append rows
    image_samples = image_samples_row if image_samples is None else np.vstack((image_samples, image_samples_row))

# show grid of images
plt.figure(figsize=(16,8))
plt.imshow(image_samples)
plt.colorbar()

# show grid of ground truth data
display(pd.DataFrame(image_labels.reshape(M,N)))
0 1 2 3 4 5 6 7
0 0 8 4 9 9 4 3 5
1 7 7 9 7 5 4 8 7
2 7 0 0 1 4 6 5 0
3 4 2 8 6 2 6 9 5
4 6 1 9 9 6 3 7 2
5 4 8 3 4 4 4 1 4
6 4 1 1 9 1 1 2 5
7 4 0 6 1 4 3 6 4
8 0 4 5 0 4 7 7 4
9 0 3 8 5 4 2 4 7
../../_images/b19e036859d58fad40ad11e99e3412ddebd3a655e1a952934f2c08fcff906ccf.png

Viewing a Single Image#

x_mnist = x_mnist.astype(np.float64)    # convert sample data into floating point for higher dynamic range manipulation
y_mnist = y_mnist.astype(np.int8)       # convert labels from strings into int

###### print the label
print(f'{y_mnist[permute_idx] = }')
print(f'{x_mnist[permute_idx] = }')
y_mnist[permute_idx] = np.int8(7)
x_mnist[permute_idx] = array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,  63., 110., 233., 253., 253., 255., 128.,  31.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,  42.,  73.,  73., 176., 237., 253.,
       252., 252., 252., 253., 252., 195.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,  11., 140.,
       221., 253., 252., 252., 252., 253., 241., 215., 236., 253., 252.,
       195.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   1.,  73., 252., 252., 253., 252., 252., 252.,
       191.,  77.,  21., 206., 253., 252.,  71.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1., 170., 252.,
       252., 252., 253., 158.,  41.,   0.,   0.,  42., 221., 252., 253.,
        98.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,  42., 252., 252., 252., 210., 119.,   5.,   0.,
         0.,   0.,  63., 242., 252., 222.,  45.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,  11., 140., 221., 252.,
       252., 226.,  92.,   0.,   0.,   0.,   0.,  37., 181., 252., 231.,
        41.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,  73., 252., 252., 252., 210.,  31.,   0.,   0.,   0.,
         0.,   0., 170., 252., 252., 108.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 135., 253., 253.,
       253., 119.,   0.,   0.,   0.,   0.,   0.,  63., 255., 253., 253.,
       108.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,  94., 247., 252., 252., 118.,   5.,   0.,   0.,   0.,
         0.,  11., 175., 253., 252.,  96.,  15.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 109., 252., 246.,
       132.,   0.,   0.,   0.,   0.,   0.,   0.,  58., 252., 253., 179.,
        20.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0., 109., 252., 132.,   0.,   0.,   0.,   0.,   0.,
         0.,  42., 221., 252., 191.,  15.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0., 218., 253., 253., 145.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0., 135., 247., 252., 210.,  20.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0., 105., 253., 252., 226.,  31.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
       105., 206., 253., 210.,  92.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,  32., 212., 253., 255., 119.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,  32.,
       207., 252., 252., 222.,  25.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,  84., 211., 252., 252., 148.,  41.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,
       169., 252., 231.,  46.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.])

Intuitively, one can see that the majority of pixels are white, or 0. Very few pixels of the whole 28x28 grid are actually non-zero.

plt.figure(figsize=(16,8))

# randomly sample 10 images and plot them
for idx in range(10):
    permute_idx = permuted_array[idx]
    plt.plot(x_mnist[permute_idx], '*', label=y_mnist[permute_idx])
plt.legend()
plt.grid()
../../_images/f9b602a2b0fe04962da30a3cdd19b117807e9f0bbaf7a2b993f57f6a7baa4761.png

Binary Classification#

To begin with a simple application, we’ll look at a binary classification problem by reducing the data to either labels of 0 or 1.

# grab all images that are 0 or 1
samples_zeros = x_mnist[y_mnist == 0]
samples_ones = x_mnist[y_mnist == 1]

print(f'{samples_zeros.shape = }')
print(f'{samples_ones.shape = }')
samples_zeros.shape = (6903, 784)
samples_ones.shape = (7877, 784)

We can take a look a which pixels are zero throughout the whole dataset. In particular, these pixels, being non-zero do no provide any information that would contribute to a model’s prediction. As a measure of information, the variance of each pixel along each sample can be taken to determine if there is any substantive changes from sample to sample.

# append data
samples = np.vstack((samples_zeros, samples_ones))

# get the covariance of the samples
samples_cov = np.cov(samples.T)
# plt.imshow(samples_cov)
# plot the image map of covariances of all the data
plt.figure(figsize=(16,8))
plt.imshow(np.sum(samples_cov, axis=1).reshape(28,28))
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f2daab4f350>
../../_images/0a3948b7a5bc61bcbd192178d166a44a6529bf5496707d936882bb4582a56c9c.png