"""Fashion-MNIST dataset. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import gzip import os from ..utils.data_utils import get_file import numpy as np def load_data(): """Loads the Fashion-MNIST dataset. # Returns Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ dirname = os.path.join('datasets', 'fashion-mnist') base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'] paths = [] for fname in files: paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname)) with gzip.open(paths[0], 'rb') as lbpath: y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) with gzip.open(paths[1], 'rb') as imgpath: x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) with gzip.open(paths[2], 'rb') as lbpath: y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) with gzip.open(paths[3], 'rb') as imgpath: x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28) return (x_train, y_train), (x_test, y_test)