blob: 6b35f63a00fe5b219526d318bc1dfc9b7a73e600 [file] [log] [blame]
Pavel Macenauer59e057f2020-04-15 14:17:26 +00001#!/usr/bin/env python3
Éanna Ó Catháin6c3dee42020-09-10 13:02:37 +01002# Copyright © 2020 NXP and Contributors. All rights reserved.
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00003# SPDX-License-Identifier: MIT
4
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00005import example_utils as eu
6import os
7
if __name__ == "__main__":
    # Classify the images found in args.data_dir with the quantized MobileNet v1
    # TFLite model, fetching the model/labels archive via eu.get_model_and_labels.
    args = eu.parse_command_line()

    # names of the files in the archive
    labels_filename = 'labels_mobilenet_quant_v1_224.txt'
    model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
    archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'

    archive_url = \
        'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'

    model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
                                                              archive_filename, archive_url)

    image_filenames = eu.get_images(args.data_dir)

    # All 3 resources (labels, model, images) must exist to proceed further.
    # Explicit checks rather than `assert`: asserts are removed under `python -O`,
    # which would let a missing file surface later as an obscure runtime error.
    if not image_filenames:
        raise FileNotFoundError('No input images found in {}'.format(args.data_dir))
    missing = [str(path) for path in [labels_filename, model_filename] + list(image_filenames)
               if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError('Required file(s) not found: {}'.format(', '.join(missing)))

    # Create a network from the model file
    net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)

    # Load input information from the model
    # tflite has all the needed information in the model, unlike other formats
    input_names = parser.GetSubgraphInputTensorNames(graph_id)
    if len(input_names) != 1:  # there should be 1 input tensor in mobilenet
        raise RuntimeError('Expected 1 input tensor in {}, found {}'.format(model_filename, len(input_names)))

    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    # TFLite image tensors use NHWC layout: [batch, height, width, channels].
    # Axis 1 is the height and axis 2 is the width; the two were previously
    # swapped, which only went unnoticed because this model's input is square.
    input_shape = input_binding_info[1].GetShape()
    input_height = input_shape[1]
    input_width = input_shape[2]

    # Load output information from the model and create output tensors
    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    if len(output_names) != 1:  # and only one output tensor
        raise RuntimeError('Expected 1 output tensor in {}, found {}'.format(model_filename, len(output_names)))
    output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])

    # Load labels file
    labels = eu.load_labels(labels_filename)

    # Load images and resize to expected size
    images = eu.load_images(image_filenames, input_width, input_height)

    eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)