Posted on July 26, 2017
There are lots of CIFAR-10 loaders out there. This one…
Install:
pip install cifar10_web
Usage:
from cifar10_web import cifar10
train_images, train_labels, test_images, test_labels = cifar10(path=None)
Options:
If you leave `path` as `None`, it defaults to `/home/USER/data/cifar10/` or the Windows equivalent, which I believe is `C:\Users\USER\data\cifar10\`.
If the CIFAR-10 tar file is missing from `path`, it will be downloaded to `path`, and you'll be told when that happens.
Labels are one-hot row vectors, each of length 10.
Images are flattened row vectors, each of length 3072 (32 × 32 pixels × 3 color channels), with pixel values scaled to [0, 1].
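Speed: loading is quickest once the tar file is already in `path`, because nothing has to be downloaded and the data is read directly from the local archive.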
"""Load CIFAR-10 from /home/USER/data/cifar10 or elsewhere; download if missing."""

import os
import tarfile
from urllib.request import urlretrieve

import numpy as np

# A CIFAR-10 binary record is 1 label byte followed by 32 * 32 * 3 = 3,072
# pixel bytes (1,024 red, then 1,024 green, then 1,024 blue).
_IMAGE_BYTES = 32 * 32 * 3
_RECORD_BYTES = 1 + _IMAGE_BYTES
_N_CLASSES = 10


def cifar10(path=None):
    r"""Return (train_images, train_labels, test_images, test_labels).

    Args:
        path (str): Directory containing CIFAR-10. Default is
            /home/USER/data/cifar10 or C:\Users\USER\data\cifar10.
            Create if nonexistent. Download CIFAR-10 if missing.

    Returns:
        Tuple of (train_images, train_labels, test_images, test_labels), each
            a matrix. Rows are examples. Columns of images are float32 pixel
            values scaled to [0, 1], in the order (red -> green -> blue).
            Columns of labels are a uint8 onehot encoding of the correct
            class, always 10 columns wide.

    Raises:
        ValueError: If the tar file is missing a batch file, or a batch file
            is not a whole number of records (e.g. a corrupt archive; delete
            it to force a fresh download).
    """
    url = 'https://www.cs.toronto.edu/~kriz/'
    tar = 'cifar-10-binary.tar.gz'
    # Explicit load order: training batches 1-5, then the test batch
    train_files = ['cifar-10-batches-bin/data_batch_%d.bin' % i
                   for i in range(1, 6)]
    test_file = 'cifar-10-batches-bin/test_batch.bin'

    if path is None:
        # Set path to /home/USER/data/cifar10 or C:\Users\USER\data\cifar10
        path = os.path.join(os.path.expanduser('~'), 'data', 'cifar10')

    # Create path if it doesn't exist
    os.makedirs(path, exist_ok=True)

    # Download tarfile if missing
    tar_path = os.path.join(path, tar)
    if not os.path.isfile(tar_path):
        _download(url + tar, tar_path)
        print("Downloaded %s to %s" % (tar, path))

    # Load data from tarfile. Members are looked up by name, so the README's
    # and other extraneous files in the archive are never read.
    with tarfile.open(tar_path) as tar_object:
        train = [_read_batch(tar_object, name) for name in train_files]
        test_pixels, test_labels = _read_batch(tar_object, test_file)

    train_pixels = np.concatenate([pixels for pixels, _ in train])
    train_labels = np.concatenate([labels for _, labels in train])

    return (train_pixels.astype('float32') / 255, _onehot(train_labels),
            test_pixels.astype('float32') / 255, _onehot(test_labels))


def _download(url, dest):
    """Download url to dest without ever leaving a partial file at dest."""
    # Write to a side file first: a truncated file at dest would pass the
    # existence check in cifar10() and never be re-downloaded.
    partial = dest + '.part'
    urlretrieve(url, partial)
    os.replace(partial, dest)


def _read_batch(tar_object, name):
    """Return (pixels, labels) as uint8 arrays parsed from one batch file."""
    try:
        member = tar_object.getmember(name)
    except KeyError:
        raise ValueError("%s not found in %s"
                         % (name, tar_object.name)) from None
    with tar_object.extractfile(member) as f:
        raw = np.frombuffer(f.read(), dtype='uint8')
    if raw.size % _RECORD_BYTES:
        raise ValueError("%s is not a whole number of %d-byte records"
                         % (name, _RECORD_BYTES))
    # One row per example: column 0 is the label, the rest are its pixels
    records = raw.reshape(-1, _RECORD_BYTES)
    return records[:, 1:], records[:, 0]


def _onehot(integer_labels, n_classes=_N_CLASSES):
    """Return matrix whose rows are onehot encodings of integers."""
    # Fixed width (not labels.max() + 1) so every split gets 10 columns even
    # if some class happens to be absent from it.
    n_rows = len(integer_labels)
    onehot = np.zeros((n_rows, n_classes), dtype='uint8')
    onehot[np.arange(n_rows), integer_labels] = 1
    return onehot