Georgios Pinitas | 588ebc5 | 2018-12-21 13:39:07 +0000 | [diff] [blame] | 1 | #!/usr/bin/env python |
| 2 | """ Extracts trainable parameters from a frozen model and stores them in numpy arrays. |
| 3 | Usage: |
| 4 | python tf_frozen_model_extractor -m path_to_frozen_model -d path_to_store_the_parameters |
| 5 | |
| 6 | Saves each variable to a {variable_name}.npy binary file. |
| 7 | |
| 8 | Note that the script permutes the trainable parameters to NCHW format. This is a largely manual step and is therefore not thoroughly tested. |
| 9 | """ |
| 10 | import argparse |
| 11 | import os |
| 12 | import numpy as np |
| 13 | import tensorflow as tf |
| 14 | from tensorflow.python.platform import gfile |
| 15 | |
# Substrings stripped from a tensor name when it does not end in READ_SUFFIX
# (legacy fallback, kept so unusual names are handled as before).
strings_to_remove = ["read", "/:0"]

# Suffix carried by the output tensor of a variable's "read" op.
READ_SUFFIX = "/read:0"

# Axis permutation applied to a tensor, keyed by its rank. The rank-4 entry
# moves the last two axes to the front (presumably TensorFlow's HWIO
# convolution-weight layout -> OIHW; confirm against the consumer).
permutations = {1: [0], 2: [1, 0], 3: [2, 1, 0], 4: [3, 2, 0, 1]}


def permute_to_nchw(tensor):
    """Return *tensor* with its axes permuted according to ``permutations``.

    Tensors whose rank has no entry in ``permutations`` (scalars and
    rank > 4) are returned unpermuted instead of raising ``KeyError``.
    Permuted results are copied into C-contiguous memory.
    """
    tensor = np.asarray(tensor)
    axes = permutations.get(tensor.ndim)
    if axes is None:
        # No known layout for this rank: leave the data untouched.
        return tensor
    return np.ascontiguousarray(tensor.transpose(axes))


def strip_read_suffix(tensor_name):
    """Return the variable name for the tensor name *tensor_name*.

    Only the trailing ``/read:0`` is dropped, so scopes that merely contain
    the letters "read" (e.g. ``spread/weights``) are left intact. Names that
    do not end with that suffix fall back to removing every occurrence of
    the entries in ``strings_to_remove``.
    """
    if tensor_name.endswith(READ_SUFFIX):
        return tensor_name[:-len(READ_SUFFIX)]
    name = tensor_name
    for marker in strings_to_remove:
        name = name.replace(marker, "")
    return name


def sanitize_variable_name(tensor_name):
    """Return a flat, file-system friendly name for *tensor_name*.

    TensorFlow always separates scopes with ``/`` regardless of platform, so
    that character (not ``os.path.sep``) is replaced by ``_``.
    """
    return strip_read_suffix(tensor_name).replace("/", "_")


def main():
    """Parse the command line and dump every "read" tensor of the frozen graph.

    Each selected tensor is evaluated, permuted with ``permute_to_nchw`` and,
    unless ``--nostore`` is given, saved as ``{variable_name}.npy`` inside the
    dump directory.
    """
    # Parse arguments
    parser = argparse.ArgumentParser('Extract TensorFlow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to TensorFlow frozen graph file (.pb)')
    parser.add_argument('-d', dest='dumpPath', type=str, required=False, default='./', help='Path to store the resulting files.')
    parser.add_argument('--nostore', dest='storeRes', action='store_false', help='Specify if files should not be stored. Used for debugging.')
    parser.set_defaults(storeRes=True)
    args = parser.parse_args()

    # Create directory if not present. The try/except tolerates another
    # process creating it concurrently, while still failing loudly when the
    # path exists but is not a directory.
    if not os.path.isdir(args.dumpPath):
        try:
            os.makedirs(args.dumpPath)
        except OSError:
            if not os.path.isdir(args.dumpPath):
                raise

    # Extract parameters
    with tf.Graph().as_default() as graph:
        with tf.Session() as sess:
            print("Loading model.")
            with gfile.GFile(args.modelFile, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())

            # `graph` is already the default graph here, so the import lands
            # in the graph the session was created for.
            tf.import_graph_def(graph_def, name="")

            for op in graph.get_operations():
                for op_val in op.values():
                    # Skip non-const values
                    if "read" not in op_val.name:
                        continue

                    t = permute_to_nchw(sess.run(op_val))

                    base_name = strip_read_suffix(op_val.name)
                    varname = base_name.replace("/", "_")
                    if varname != base_name:
                        print("Renaming variable {0} to {1}".format(op_val.name, varname))

                    # Store files
                    if args.storeRes:
                        print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
                        np.save(os.path.join(args.dumpPath, varname), t)


if __name__ == "__main__":
    main()