#!/usr/bin/env python
import argparse
import os

import caffe
import numpy as np
import scipy.io


| 9 | if __name__ == "__main__": |
| 10 | # Parse arguments |
| 11 | parser = argparse.ArgumentParser('Extract CNN hyper-parameters') |
| 12 | parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file') |
| 13 | parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist') |
| 14 | args = parser.parse_args() |
| 15 | |
| 16 | # Create Caffe Net |
| 17 | net = caffe.Net(args.netFile, 1, weights=args.modelFile) |
| 18 | |
| 19 | # Read and dump blobs |
| 20 | for name, blobs in net.params.iteritems(): |
| 21 | print 'Name: {0}, Blobs: {1}'.format(name, len(blobs)) |
| 22 | for i in range(len(blobs)): |
| 23 | # Weights |
| 24 | if i == 0: |
| 25 | outname = name + "_w" |
| 26 | # Bias |
| 27 | elif i == 1: |
| 28 | outname = name + "_b" |
| 29 | else: |
| 30 | pass |
| 31 | |
| 32 | print("%s : %s" % (outname, blobs[i].data.shape)) |
| 33 | # Dump as binary |
| 34 | blobs[i].data.tofile(outname + ".dat") |