blob: cb2c91cba7f560e511813193e6668619d6f94302 [file] [log] [blame]
Pavel Macenauer59e057f2020-04-15 14:17:26 +00001#!/usr/bin/env python3
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00002# Copyright 2020 NXP
3# SPDX-License-Identifier: MIT
4
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00005import numpy as np
6import pyarmnn as ann
7import example_utils as eu
8import os
9
# Command-line options; args.model_dir and args.data_dir are used below.
args = eu.parse_command_line()

# Names of the model and labels files stored inside the downloadable archive,
# followed by the archive's own name and its download location.
model_filename, labels_filename, archive_filename = (
    'mobilenet_v1_1.0_224_quant.tflite',
    'labels_mobilenet_quant_v1_224.txt',
    'mobilenet_v1_1.0_224_quant_and_labels.zip',
)
archive_url = (
    'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
)

# Resolve the model and labels paths via the shared helper. It receives the
# archive name and URL, presumably so it can fetch/extract them into
# args.model_dir when they are missing — confirm in example_utils.
model_filename, labels_filename = eu.get_model_and_labels(
    args.model_dir,
    model_filename,
    labels_filename,
    archive_filename,
    archive_url,
)
Pavel Macenauerd0fedae2020-04-15 14:52:57 +000022
# Collect the input image paths from the data directory given on the command line.
image_filenames = eu.get_images(args.data_dir)

# All 3 resources (labels file, model file and at least one image) must exist
# to proceed further. Explicit checks are used instead of `assert`, because
# asserts are stripped when Python runs with -O, which would let a missing
# file slip through and fail later with a far less helpful error.
for required_file in (labels_filename, model_filename):
    if not os.path.exists(required_file):
        raise FileNotFoundError('Required file not found: {}'.format(required_file))
if not image_filenames:
    raise FileNotFoundError('No input images found (data_dir: {})'.format(args.data_dir))
for im in image_filenames:
    if not os.path.exists(im):
        raise FileNotFoundError('Image file not found: {}'.format(im))
Pavel Macenauerd0fedae2020-04-15 14:52:57 +000031
# Create a network from the model file
net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)

# Load input information from the model.
# TFLite models carry all the needed input information in the model itself,
# unlike some other formats.
input_names = parser.GetSubgraphInputTensorNames(graph_id)
# MobileNet has exactly one input tensor; checked explicitly (not with
# `assert`) so the check still runs under `python -O`.
if len(input_names) != 1:
    raise RuntimeError('Expected exactly 1 input tensor, got {}'.format(len(input_names)))

input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
# TFLite image inputs use NHWC layout, i.e. (batch, height, width, channels):
# index 1 is the height and index 2 is the width. (The original read these
# swapped, which only went unnoticed because this model's input is square.)
input_shape = input_binding_info[1].GetShape()
input_height = input_shape[1]
input_width = input_shape[2]
Pavel Macenauerd0fedae2020-04-15 14:52:57 +000043
# Load output binding information from the model.
output_names = parser.GetSubgraphOutputTensorNames(graph_id)
# Exactly one output tensor is expected; checked explicitly (not with
# `assert`) so the check still runs under `python -O`.
if len(output_names) != 1:
    raise RuntimeError('Expected exactly 1 output tensor, got {}'.format(len(output_names)))
output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
47output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
Pavel Macenauerd0fedae2020-04-15 14:52:57 +000048
# Read the class labels, then load the images resized to the network's input
# width and height (evaluated in this order, labels first).
labels, images = (
    eu.load_labels(labels_filename),
    eu.load_images(image_filenames, input_width, input_height),
)

# Hand the loaded network, images, labels and the input/output bindings to the
# shared inference helper.
eu.run_inference(
    runtime, net_id, images, labels,
    input_binding_info, output_binding_info,
)