blob: 09ea0b86b0a684d7f717bee22ad209b9096df2e5 [file] [log] [blame]
steniu01bee466b2017-06-21 16:45:41 +01001#!/usr/bin/env python
2import argparse
3
4import caffe
5import numpy as np
6import scipy.io
7
8
def blob_output_name(layer_name, index):
    """Return the output file stem for parameter blob *index* of *layer_name*.

    Blob 0 holds the layer's weights (``<layer>_w``) and blob 1 its bias
    (``<layer>_b``). Any further blobs get a distinct ``<layer>_<index>``
    stem. Previously they silently reused the bias stem and overwrote the
    bias file.

    Path separators in the layer name (Caffe names such as
    ``inception_3a/1x1`` are common) are replaced by ``_``. This keeps every
    dump in the current directory instead of pointing into a subdirectory
    that normally does not exist.
    """
    stem = layer_name.replace("/", "_")
    if index == 0:
        return stem + "_w"
    if index == 1:
        return stem + "_b"
    return "{0}_{1}".format(stem, index)


def main():
    """Load a Caffe net with trained weights and dump every parameter blob.

    Each blob is written with ``ndarray.tofile`` to ``<stem>.dat``, where
    ``<stem>`` comes from ``blob_output_name``. ``tofile`` writes raw,
    header-less C-order data, so shape and dtype are not stored in the file.
    The shape is printed to stdout instead. The dtype is presumably float32
    (pycaffe's default); confirm this for your build.
    """
    parser = argparse.ArgumentParser(description='Extract CNN hyper-parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist')
    args = parser.parse_args()

    # caffe.TEST is the named constant for phase 1 (inference).
    net = caffe.Net(args.netFile, caffe.TEST, weights=args.modelFile)

    # .items() and single-argument print() work on both Python 2 and 3.
    for name, blobs in net.params.items():
        print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
        for index, blob in enumerate(blobs):
            outname = blob_output_name(name, index)
            print("%s : %s" % (outname, blob.data.shape))
            # Dump as raw binary
            blob.data.tofile(outname + ".dat")


if __name__ == "__main__":
    main()