Posted on July 26, 2017
There are lots of MNIST loaders out there. This one…
Install:
pip install mnist_web
Usage:
train_images, train_labels, test_images, test_labels = mnist(path=None)
Options:
If you leave path as None, it defaults to data/mnist/ inside your home directory, e.g. /home/USER/data/mnist/ on Linux or C:\Users\USER\data\mnist\ on Windows. The directory is created if it does not exist.
Any of the four MNIST files missing from path will be downloaded to path, and a message is printed for each file downloaded.
Labels are one-hot row vectors (dtype uint8), each of length 10.
Images are flattened row vectors, each of length 784, with float32 pixel values scaled to [0, 1].
Speed:
path.
"""Load from /home/USER/data/mnist or elsewhere; download if missing."""
import gzip
import os
from urllib.request import urlretrieve
import numpy as np
def _images(path):
    """Return the images stored in the gzipped IDX file at *path*.

    Returns:
        float32 matrix with one example per row and 784 pixel columns;
        raw byte values 0..255 are scaled into [0, 1].
    """
    with gzip.open(path) as f:
        # First 16 bytes are magic_number, n_imgs, n_rows, n_cols
        pixels = np.frombuffer(f.read(), 'B', offset=16)
    return pixels.reshape(-1, 784).astype('float32') / 255


def _onehot(integer_labels, n_classes=10):
    """Return a uint8 matrix whose rows are one-hot encodings of integers.

    The width is at least *n_classes*, so a label set that happens not to
    contain the highest class (or is empty) still produces rows of the
    documented length. Larger labels widen the matrix instead of raising.
    """
    n_rows = len(integer_labels)
    n_cols = n_classes
    if n_rows:
        n_cols = max(n_classes, int(integer_labels.max()) + 1)
    onehot = np.zeros((n_rows, n_cols), dtype='uint8')
    onehot[np.arange(n_rows), integer_labels] = 1
    return onehot


def _labels(path):
    """Return the labels in the gzipped IDX file at *path*, one-hot encoded."""
    with gzip.open(path) as f:
        # First 8 bytes are magic_number, n_labels
        integer_labels = np.frombuffer(f.read(), 'B', offset=8)
    return _onehot(integer_labels)


def _download(source, target):
    """Fetch *source* into *target* without ever exposing a partial file.

    Data is written to a sibling ``.part`` file and renamed into place only
    after the transfer finishes, so an interrupted download cannot leave a
    truncated file that a later call would mistake for a complete one.
    """
    partial = target + '.part'
    urlretrieve(source, partial)
    os.replace(partial, target)


def mnist(path=None, *, url='http://yann.lecun.com/exdb/mnist/'):
    r"""Return (train_images, train_labels, test_images, test_labels).

    Args:
        path (str): Directory containing MNIST. Default is
            ~/data/mnist, e.g. /home/USER/data/mnist or
            C:\Users\USER\data\mnist. Created if nonexistent. Any missing
            files are downloaded into it, printing a message per file.
        url (str): Base URL the four files are downloaded from; each file
            name is appended to it. Pass a mirror serving the same file
            names if the default host is unavailable.

    Returns:
        Tuple of (train_images, train_labels, test_images, test_labels), each
        a matrix. Rows are examples. Columns of images are float32 pixel
        values in [0, 1]. Columns of labels are a uint8 one-hot encoding of
        the correct class (at least 10 columns).
    """
    files = ['train-images-idx3-ubyte.gz',
             'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz',
             't10k-labels-idx1-ubyte.gz']

    if path is None:
        # Set path to /home/USER/data/mnist or C:\Users\USER\data\mnist
        path = os.path.join(os.path.expanduser('~'), 'data', 'mnist')

    # Create path if it doesn't exist
    os.makedirs(path, exist_ok=True)

    # Download any missing files
    for name in files:
        target = os.path.join(path, name)
        if not os.path.isfile(target):
            _download(url + name, target)
            print("Downloaded %s to %s" % (name, path))

    train_images = _images(os.path.join(path, files[0]))
    train_labels = _labels(os.path.join(path, files[1]))
    test_images = _images(os.path.join(path, files[2]))
    test_labels = _labels(os.path.join(path, files[3]))
    return train_images, train_labels, test_images, test_labels