#!/usr/bin/env python
"""Extracts mnist image data from the Caffe data files and stores them in numpy arrays.

Usage:
    python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path

Saves the first 10 images extracted as input10.npy, the first 100 images as input100.npy, and the
corresponding labels to labels100.npy. Note that, despite its extension, labels100.npy is a
plain-text file containing one label per line, not a numpy array file.

Tested with Caffe 1.0 on Python 2.7
"""
from __future__ import print_function

import argparse
import os
import struct
from array import array

import numpy as np


# How many leading samples of the data set go into each output file.
_SMALL_COUNT = 10
_LARGE_COUNT = 100

# Magic numbers that begin IDX files of unsigned bytes: 0x00000803 marks a
# 3-dimensional array (images), 0x00000801 a 1-dimensional array (labels).
_IDX_IMAGES_MAGIC = 0x00000803
_IDX_LABELS_MAGIC = 0x00000801


def _read_idx_header(idx_file, fmt, expected_magic, path):
    """Read the big-endian IDX header and return its fields after the magic number.

    Raises:
        ValueError: if the header is truncated or the magic number is not
            `expected_magic`.
    """
    header_size = struct.calcsize(fmt)
    header = idx_file.read(header_size)
    if len(header) != header_size:
        raise ValueError('{0}: truncated IDX header'.format(path))
    fields = struct.unpack(fmt, header)
    if fields[0] != expected_magic:
        raise ValueError('{0}: unexpected IDX magic number {1} (expected {2})'.format(
            path, fields[0], expected_magic))
    return fields[1:]


def _read_idx_images(path, count):
    """Return the first `count` images of an IDX image file.

    The result is a uint8 array of shape (count, rows, cols), where rows and
    cols are taken from the file header rather than assumed.

    Raises:
        ValueError: if the file is not an IDX image file or holds fewer than
            `count` images.
    """
    with open(path, 'rb') as images_file:
        _declared_count, rows, cols = _read_idx_header(
            images_file, '>IIII', _IDX_IMAGES_MAGIC, path)
        needed = count * rows * cols
        # Only the bytes that are actually used are read, not the whole file.
        pixels = array('B', images_file.read(needed))
    if len(pixels) < needed:
        raise ValueError('{0}: fewer than {1} images available'.format(path, count))
    return np.array(pixels, dtype=np.uint8).reshape((count, rows, cols))


def _read_idx_labels(path, count):
    """Return the first `count` labels of an IDX label file as a list of ints.

    Raises:
        ValueError: if the file is not an IDX label file or holds fewer than
            `count` labels.
    """
    with open(path, 'rb') as labels_file:
        _read_idx_header(labels_file, '>II', _IDX_LABELS_MAGIC, path)
        # Labels are stored as unsigned bytes.
        labels = array('B', labels_file.read(count))
    if len(labels) < count:
        raise ValueError('{0}: fewer than {1} labels available'.format(path, count))
    return labels.tolist()


def main(argv=None):
    """Extract the first 100 MNIST test images and labels from a Caffe data directory.

    Reads mnist/t10k-images-idx3-ubyte and mnist/t10k-labels-idx1-ubyte below
    the data directory and writes three files into the output directory:

    * input10.npy   - first 10 images, float32, shape (10, 1, rows, cols)
    * input100.npy  - first 100 images, float32, shape (100, 1, rows, cols)
    * labels100.npy - the 100 labels as plain text, one per line (the file
      keeps its .npy name even though it is not a numpy array file)

    Args:
        argv: command-line arguments to parse; None means sys.argv[1:].

    Raises:
        ValueError: if either input file is malformed or too short.
    """
    # Parse arguments
    parser = argparse.ArgumentParser('Extract Caffe mnist image data')
    parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
    parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
    args = parser.parse_args(argv)

    images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte')
    labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte')

    images = _read_idx_images(images_filename, _LARGE_COUNT)
    labels = _read_idx_labels(labels_filename, _LARGE_COUNT)

    input10_path = os.path.join(args.outDir, 'input10.npy')
    input100_path = os.path.join(args.outDir, 'input100.npy')
    labels100_path = os.path.join(args.outDir, 'labels100.npy')

    # Dividing by 256 maps the byte values 0..255 into [0, 1); every quotient
    # is exactly representable, so the float32 conversion loses nothing.
    scaled = (images / 256.0).astype(np.float32)
    # Insert a singleton channel axis so each image is (1, rows, cols).
    outputs_100 = scaled[:, np.newaxis, :, :]
    outputs_10 = outputs_100[:_SMALL_COUNT]

    np.save(input10_path, outputs_10)
    print("Wrote", input10_path)

    with open(labels100_path, 'w') as labels_output:
        labels_output.writelines(str(label) + '\n' for label in labels)
    print("Wrote", labels100_path)

    np.save(input100_path, outputs_100)
    print("Wrote", input100_path)


if __name__ == "__main__":
    main()