SiCong Li | 86b5333 | 2017-08-23 11:02:43 +0100 | [diff] [blame] | 1 | #!/usr/bin/env python |
| 2 | """Extracts trainable parameters from Tensorflow models and stores them in numpy arrays. |
| 3 | Usage |
| 4 | python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file |
| 5 | |
| 6 | Saves each variable to a {variable_name}.npy binary file in the current working directory (path separators in the variable name are replaced with '_'). |
| 7 | |
| 8 | Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of: |
| 9 | {model_name}.data-{step}-of-{max_step} |
| 10 | instead of: |
| 11 | {model_name}.ckpt |
| 12 | When dealing with binary files with version >= 0.11, only pass {model_name} to -m option; |
| 13 | when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option. |
| 14 | |
| 15 | Also note that this script relies on the parameters to be extracted being in the |
| 16 | 'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless |
| 17 | specified otherwise by the user. Thus, if a user alters this default behavior and/or wants to extract parameters from other |
| 18 | collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly. |
| 19 | |
| 20 | Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3. |
| 21 | """ |
| 22 | import argparse |
| 23 | import numpy as np |
| 24 | import os |
| 25 | import tensorflow as tf |
| 26 | |
| 27 | |
def _tf_v1_api():
    """Return the namespace that provides the TensorFlow 1.x graph/session API.

    Newer TensorFlow releases (including 2.x) expose ``Graph``, ``Session``,
    ``train.import_meta_graph``, ``get_collection`` and ``GraphKeys`` under
    ``tf.compat.v1``; older 1.x releases expose them directly on ``tf``.
    """
    compat = getattr(tf, 'compat', None)
    return getattr(compat, 'v1', tf)


def _build_arg_parser():
    """Build the command-line parser (``-m`` checkpoint path, ``-n`` MetaGraph path)."""
    # The first positional argument of ArgumentParser is `prog` (the program name
    # shown in the usage line), so the human-readable text belongs in `description`.
    parser = argparse.ArgumentParser(description='Extract Tensorflow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True,
                        help='Path to Tensorflow checkpoint binary file. '
                             'For Tensorflow version >= 0.11, only include model name; '
                             'for Tensorflow version < 0.11, include model name '
                             'with ".ckpt" extension')
    parser.add_argument('-n', dest='netFile', type=str, required=True,
                        help='Path to Tensorflow MetaGraph file')
    return parser


def _variable_file_name(name):
    """Return a flat file name for the TensorFlow variable ``name``.

    TensorFlow separates name scopes with '/' on every OS, so '/' is always
    replaced with '_'. The native path separator(s) are replaced too, so each
    array is written into the current directory rather than a sub-directory.
    """
    separators = ['/', os.path.sep]
    if os.path.altsep:
        separators.append(os.path.altsep)
    for sep in separators:
        name = name.replace(sep, '_')
    return name


def main(argv=None):
    """Restore a checkpoint and dump every trainable variable to ``<name>.npy``.

    :param argv: command-line arguments without the program name; ``None``
        means ``sys.argv[1:]`` (argparse default).
    :raises SystemExit: on invalid arguments, or when the MetaGraph file
        contains no variables to restore.
    """
    args = _build_arg_parser().parse_args(argv)
    tf1 = _tf_v1_api()

    # Build inside an explicit graph context: import_meta_graph and Session need
    # graph mode, which is not the default under TensorFlow 2.x eager execution.
    with tf1.Graph().as_default():
        saver = tf1.train.import_meta_graph(args.netFile)
        # import_meta_graph returns None when the MetaGraph has no variables;
        # fail with a clear message instead of an AttributeError on restore().
        if saver is None:
            raise SystemExit('No variables found in MetaGraph file {0}'.format(args.netFile))
        with tf1.Session() as sess:
            # Restore session
            saver.restore(sess, args.modelFile)
            print('Model restored.')
            # Save trainable variables to numpy arrays
            for t in tf1.get_collection(tf1.GraphKeys.TRAINABLE_VARIABLES):
                varname = _variable_file_name(t.name)
                if varname != t.name:
                    print("Renaming variable {0} to {1}".format(t.name, varname))
                print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
                # Dump as binary; np.save adds the ".npy" extension and writes the
                # file relative to the current working directory.
                np.save(varname, sess.run(t))


if __name__ == "__main__":
    main()