blob: 47d24b265f71cb29c8e19ab01db246b46f7c5023 [file] [log] [blame]
steniu01bee466b2017-06-21 16:45:41 +01001#!/usr/bin/env python
SiCong Li86b53332017-08-23 11:02:43 +01002"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
Usage:
    python caffe_data_extractor.py -m path_to_caffe_model_file -n path_to_caffe_netlist

Saves each variable to a {variable_name}.npy binary file, where {variable_name} is the
layer name (with path separators replaced by underscores) suffixed with _w for weights
or _b for bias.
7
8Tested with Caffe 1.0 on Python 2.7
9"""
10import argparse
steniu01bee466b2017-06-21 16:45:41 +010011import caffe
SiCong Li86b53332017-08-23 11:02:43 +010012import os
steniu01bee466b2017-06-21 16:45:41 +010013import numpy as np
steniu01bee466b2017-06-21 16:45:41 +010014
15
def sanitize_variable_name(name):
    """Return *name* with every filesystem path separator replaced by '_'.

    Caffe layer names may contain '/' (e.g. "inception_3a/1x1"). np.save()
    treats separators in its file argument as directory boundaries, so they
    must be removed before the name is used as an output file name. Both
    os.path.sep and os.path.altsep are handled so that '/' is also replaced
    on platforms (Windows) where it is only the alternate separator.
    """
    separators = {os.path.sep}
    if os.path.altsep:
        separators.add(os.path.altsep)
    for sep in separators:
        name = name.replace(sep, '_')
    return name


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser('Extract Caffe net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
    args = parser.parse_args()

    # Create Caffe Net in the TEST phase (caffe.TEST is the value 1 the
    # script has always passed here) and load the trained weights into it.
    net = caffe.Net(args.netFile, caffe.TEST, weights=args.modelFile)

    # Read and dump blobs. .items() works on Python 2 and 3, whereas
    # .iteritems() does not exist on Python 3.
    for name, blobs in net.params.items():
        print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
        # Only the first blob (weights, "_w") and the second (bias, "_b") are
        # exported; zip() stops after them, so any further blobs a layer
        # carries are intentionally skipped.
        for suffix, blob in zip(('_w', '_b'), blobs):
            outname = name + suffix
            varname = sanitize_variable_name(outname)
            if varname != outname:
                print("Renaming variable {0} to {1}".format(outname, varname))
            print("Saving variable {0} with shape {1} ...".format(varname, blob.data.shape))
            # Dump as binary; np.save appends the ".npy" extension itself.
            np.save(varname, blob.data)