#!/usr/bin/env python
"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
Usage
    python caffe_data_extractor.py -m path_to_caffe_model_file -n path_to_caffe_netlist

Saves each variable to a {variable_name}.npy binary file, where the variable
name is the layer name followed by a blob suffix:

    {layer}_w   blob 0 (weights)
    {layer}_b   blob 1 (bias)
    {layer}_{i} any further blob i (i >= 2)

Path separators in layer names are replaced by underscores so that every
variable is written to the current directory.

Tested with Caffe 1.0 on Python 2.7
"""
import argparse
import os

import numpy as np


def blob_suffix(index):
    """Return the file-name suffix for the ``index``-th parameter blob of a layer.

    Blob 0 holds the weights (``_w``) and blob 1 the bias (``_b``). Any further
    blobs are suffixed with their index so each one gets its own file instead
    of overwriting an earlier one.
    """
    if index == 0:
        return "_w"
    if index == 1:
        return "_b"
    return "_{0}".format(index)


def sanitize_name(name):
    """Return ``name`` with path separators replaced by underscores.

    A separator left in the name would make ``np.save`` treat part of it as a
    directory. ``os.path.altsep`` is also replaced where the platform defines
    one (it is ``None`` on POSIX).
    """
    for sep in (os.path.sep, os.path.altsep):
        if sep:
            name = name.replace(sep, "_")
    return name


def main():
    """Parse arguments, load the Caffe net and dump every parameter blob to .npy."""
    # Imported here so the helpers above can be imported without pycaffe installed.
    import caffe

    # Parse arguments
    parser = argparse.ArgumentParser(description='Extract Caffe net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
    args = parser.parse_args()

    # Create Caffe Net (caffe.TEST is the phase value 1 the script always used)
    net = caffe.Net(args.netFile, caffe.TEST, weights=args.modelFile)

    # Read and dump blobs; .items() works on Python 2 and 3, unlike .iteritems()
    for name, blobs in net.params.items():
        print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
        for index, blob in enumerate(blobs):
            outname = name + blob_suffix(index)
            varname = sanitize_name(outname)
            if varname != outname:
                print("Renaming variable {0} to {1}".format(outname, varname))
            print("Saving variable {0} with shape {1} ...".format(varname, blob.data.shape))
            # Dump as binary; np.save appends the .npy extension to the name
            np.save(varname, blob.data)


if __name__ == "__main__":
    main()