#!/usr/bin/env python
"""Extracts trainable parameters from TensorFlow models and stores them in NumPy arrays.

Usage:
    python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file

Saves each variable to a {variable_name}.npy binary file (any path separator in the variable name is replaced by '_').

Note that since TensorFlow version 0.11, the binary checkpoint file that contains the values of the parameters has the
format:
    {model_name}.data-{shard}-of-{num_shards}
instead of:
    {model_name}.ckpt
When dealing with checkpoint files from version >= 0.11, pass only {model_name} to the -m option;
when dealing with checkpoint files from version < 0.11, pass the whole file name {model_name}.ckpt to the -m option.

Also note that this script relies on the parameters to be extracted being in the 'trainable_variables' tensor
collection. By default, all variables are automatically added to this collection unless the user specifies otherwise.
Thus, should a user alter this default behavior and/or want to extract parameters from other collections,
tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.

Tested with TensorFlow 1.2 and 1.3 on Python 2.7.6 and Python 3.4.3.
"""
import argparse
import os

import numpy as np
import tensorflow as tf


def variable_name_to_filename(name):
    """Return ``name`` with every path separator replaced by ``_``.

    TensorFlow joins variable scopes with ``/`` regardless of the host OS
    (e.g. ``conv1/weights:0``), so ``/`` is always replaced. The host's own
    separators (``os.path.sep`` and ``os.path.altsep`` when defined) are
    replaced as well. Without this, ``np.save`` would treat the name as a path
    into a (usually non-existent) sub-directory.

    NOTE(review): the ``:<index>`` suffix (e.g. ``:0``) is deliberately kept so
    file names stay identical to what this script has always produced; ':' is
    reserved in Windows file names -- confirm before running there.

    Args:
        name: A TensorFlow variable name, such as ``t.name``.

    Returns:
        The flattened name, usable as a single file name.
    """
    separators = {'/', os.path.sep}
    if os.path.altsep:
        separators.add(os.path.altsep)
    for sep in separators:
        name = name.replace(sep, '_')
    return name


def main():
    """Restore a TensorFlow model and dump each trainable variable to a ``.npy`` file.

    Command-line options:
        -m  checkpoint path (model name for TF >= 0.11, ``.ckpt`` file otherwise)
        -n  MetaGraph file path
        -o  optional output directory (default: current directory; created if missing)
    """
    # Parse arguments. The title must go to ``description``: the first
    # positional argument of ArgumentParser is ``prog`` (the name shown in usage).
    parser = argparse.ArgumentParser(description='Extract Tensorflow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True,
                        help='Path to Tensorflow checkpoint binary file. For Tensorflow version >= 0.11, '
                             'only include model name; for Tensorflow version < 0.11, include model name '
                             'with ".ckpt" extension')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file')
    parser.add_argument('-o', dest='outputDir', type=str, default='.',
                        help='Directory in which to write the .npy files (created if missing); '
                             'defaults to the current directory')
    args = parser.parse_args()

    # Load Tensorflow Net into the default graph.
    saver = tf.train.import_meta_graph(args.netFile)
    if saver is None:
        # import_meta_graph returns None when the MetaGraph has no variables to restore.
        parser.error('No variables to restore in MetaGraph file {0}'.format(args.netFile))

    # os.makedirs has no exist_ok on Python 2.7, so check first.
    if args.outputDir and not os.path.isdir(args.outputDir):
        os.makedirs(args.outputDir)

    with tf.Session() as sess:
        # Restore session
        saver.restore(sess, args.modelFile)
        print('Model restored.')

        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if not variables:
            print('No trainable variables found; nothing was saved.')

        # Save trainable variables to numpy arrays
        for t in variables:
            varname = variable_name_to_filename(t.name)
            if varname != t.name:
                print("Renaming variable {0} to {1}".format(t.name, varname))
            print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
            # Dump as binary; np.save appends the ".npy" extension.
            np.save(os.path.join(args.outputDir, varname), sess.run(t))


if __name__ == "__main__":
    main()