| #!/usr/bin/env python |
| """Extracts trainable parameters from Tensorflow models and stores them in numpy arrays. |
| Usage |
python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file
| |
Saves each variable to a {variable_name}.npy binary file in the current working directory; path separators in variable names are replaced with '_'.
| |
Note that since TensorFlow version 0.11, checkpoint values are no longer stored in a single
{model_name}.ckpt file but in an index file plus one or more data shards:
{model_name}.index
{model_name}.data-{shard}-of-{num_shards}
For checkpoints written by TensorFlow >= 0.11, pass only the prefix {model_name} to the -m option;
for checkpoints written by TensorFlow < 0.11, pass the full file name {model_name}.ckpt to the -m option.
| |
Also note that this script only extracts variables that are in the 'trainable_variables' graph collection
(tf.GraphKeys.TRAINABLE_VARIABLES). Variables are added to this collection by default unless the user
specifies otherwise (e.g. trainable=False). If that default was changed, or parameters from other
collections are needed, replace tf.GraphKeys.TRAINABLE_VARIABLES in the code accordingly.
| |
| Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3. |
| """ |
| import argparse |
| import numpy as np |
| import os |
| import tensorflow as tf |
| |
| |
def variable_file_name(var_name):
    """Return a file-system-safe base name for a TensorFlow variable name.

    TensorFlow separates name scopes with '/' on every platform. Checking only
    ``os.path.sep`` (a backslash on Windows) would therefore leave '/' in the
    name and make ``np.save`` try to write into a sub-directory. Both '/' and
    the platform's own separator(s) are replaced with '_'. All other
    characters are kept, so names produced on POSIX systems are unchanged.

    Args:
        var_name: The variable name, e.g. ``t.name`` of a graph variable.

    Returns:
        The name with every path separator replaced by '_'.
    """
    file_name = var_name
    for sep in ('/', os.path.sep, os.path.altsep):
        if sep:  # os.path.altsep is None on POSIX
            file_name = file_name.replace(sep, '_')
    return file_name


def extract_trainable_variables(model_file, net_file, output_dir=os.curdir):
    """Restore a TensorFlow model and save each trainable variable as a .npy file.

    Args:
        model_file: Checkpoint passed to ``Saver.restore`` (the prefix for
            TensorFlow >= 0.11, the ``.ckpt`` file name for older versions).
        net_file: Path to the MetaGraph file describing the network.
        output_dir: Directory the ``{variable_name}.npy`` files are written
            to; created if it does not exist. Defaults to the current
            working directory.
    """
    # Load Tensorflow Net
    saver = tf.train.import_meta_graph(net_file)
    with tf.Session() as sess:
        saver.restore(sess, model_file)
        print('Model restored.')

        # os.makedirs(..., exist_ok=True) is unavailable on Python 2.7.
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        saved_names = set()
        for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            varname = variable_file_name(t.name)
            if varname != t.name:
                print("Renaming variable {0} to {1}".format(t.name, varname))
            if varname in saved_names:
                # Two variables mapped to the same file name; without this
                # notice np.save would silently overwrite the earlier one.
                print("Warning: {0}.npy was already written and will be "
                      "overwritten".format(varname))
            saved_names.add(varname)
            print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
            # Dump as binary; np.save appends the .npy extension.
            np.save(os.path.join(output_dir, varname), sess.run(t))


def main(argv=None):
    """Parse command-line arguments and run the extraction.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # The first positional argument of ArgumentParser is `prog`, so the
    # description must be passed by keyword.
    parser = argparse.ArgumentParser(description='Extract Tensorflow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True,
                        help='Path to Tensorflow checkpoint binary file. For Tensorflow '
                             'version >= 0.11, only include model name; for Tensorflow '
                             'version < 0.11, include model name with ".ckpt" extension')
    parser.add_argument('-n', dest='netFile', type=str, required=True,
                        help='Path to Tensorflow MetaGraph file')
    parser.add_argument('-o', dest='outputDir', type=str, default=os.curdir,
                        help='Directory to write the .npy files to '
                             '(default: current directory)')
    args = parser.parse_args(argv)
    extract_trainable_variables(args.modelFile, args.netFile, args.outputDir)


if __name__ == "__main__":
    main()