#
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
| 5 | from argparse import ArgumentParser |
| 6 | import os |
| 7 | import logging |
| 8 | from pathlib import Path |
| 9 | from typing import List |
| 10 | |
| 11 | import ethosu_driver as driver |
| 12 | try: |
| 13 | import numpy as np |
| 14 | with_numpy = True |
| 15 | except ImportError: |
| 16 | with_numpy = False |
| 17 | |
| 18 | |
def read_bin_file_to_buf(file_path: str) -> bytearray:
    """Load the whole content of a binary file into a mutable byte buffer.

    Args:
        file_path: Path of the file to read.

    Returns:
        A ``bytearray`` holding every byte read from the file.
    """
    with open(file_path, mode='rb') as bin_file:
        raw_content = bin_file.read()
    return bytearray(raw_content)
| 22 | |
| 23 | |
def read_npy_file_to_buf(file_path: str) -> bytearray:
    """Load a ``.npy`` file and expose its content as a flat int8 buffer.

    The loaded array is cast to ``int8`` in C order and flattened before the
    buffer backing it is handed back.

    Args:
        file_path: Path of the ``.npy`` file to load.

    Returns:
        The ``data`` buffer of the flattened int8 array.
        NOTE(review): ``ndarray.data`` is a ``memoryview`` rather than the
        annotated ``bytearray`` - confirm the annotation against the callers.
    """
    loaded = np.load(file_path)
    as_int8 = loaded.astype(dtype=np.int8, order='C')
    flat = as_int8.flatten()
    return flat.data
| 27 | |
| 28 | |
def read_ifms(ifm_files: List[str], use_npy: bool = False):
    """Lazily yield one buffer per input feature map file.

    Args:
        ifm_files: Paths of the input feature map files, read in order.
        use_npy: When true each file is read with ``read_npy_file_to_buf``,
            otherwise with ``read_bin_file_to_buf``.

    Yields:
        The buffer produced by the selected reader for each file.
    """
    if use_npy:
        loader = read_npy_file_to_buf
    else:
        loader = read_bin_file_to_buf
    # map() is lazy, so each file is only read when its buffer is requested.
    yield from map(loader, ifm_files)
| 33 | |
| 34 | |
def write_npy(dir: str, file_name: str, data: memoryview):
    """Store a buffer as an int8 array in ``<dir>/<file_name>.npy``.

    A file already present at the target path is removed before saving.

    Args:
        dir: Directory the file is written to.
        file_name: File name without the ``.npy`` extension.
        data: Buffer interpreted as a sequence of int8 values.
    """
    target_path = os.path.join(dir, "{}.npy".format(file_name))
    int8_values = np.frombuffer(data, dtype=np.int8)
    if os.path.isfile(target_path):
        os.remove(target_path)
    np.save(target_path, int8_values)
    logging.info("File saved to {}".format(target_path))
| 42 | |
| 43 | |
def write_bin_file(dir: str, file_name: str, data: memoryview):
    """Dump a buffer verbatim into ``<dir>/<file_name>.bin``.

    A file already present at the target path is removed before writing.

    Args:
        dir: Directory the file is written to.
        file_name: File name without the ``.bin`` extension.
        data: Buffer whose bytes are written to the file.
    """
    target_path = os.path.join(dir, "{}.bin".format(file_name))
    if os.path.isfile(target_path):
        os.remove(target_path)
    with open(target_path, "wb") as out_file:
        out_file.write(data)
    logging.info("File saved to {}".format(target_path))
| 51 | |
| 52 | |
def write_ofm(buf: memoryview, ofm_index: int, model_path: str, output_dir: str, use_npy: bool = False):
    """Write one output feature map buffer to ``output_dir``.

    The file stem is ``<model file name>_ofm_<ofm_index>``; the extension is
    added by the writer that is selected.

    Args:
        buf: Output feature map data.
        ofm_index: Index of the output, used in the file name.
        model_path: Path of the model; only its final component is used.
        output_dir: Directory the file is written to.
        use_npy: When true the buffer is stored with ``write_npy``, otherwise
            with ``write_bin_file``.
    """
    ofm_name = "{}_ofm_{}".format(Path(model_path).name, ofm_index)
    if use_npy:
        write_npy(output_dir, ofm_name, buf)
    else:
        write_bin_file(output_dir, ofm_name, buf)
| 58 | |
| 59 | |
def main():
    """Run a single inference on an NPU device from the command line.

    Parses the command line, reads the input feature maps, runs the model
    through ``driver.InferenceRunner``, writes every output feature map to the
    output directory and logs the PMU counters and total cycle count.

    numpy input/output is only used when ``--npy`` is non-zero AND numpy could
    be imported; otherwise files are handled as raw byte arrays.
    """
    # Named log_format so the builtin format() is not shadowed.
    log_format = "%(asctime)s %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    parser = ArgumentParser()
    parser.add_argument("--device", help="Npu device name. Default: ethosu0", default="ethosu0")
    parser.add_argument("--model", help="Tflite model file path", required=True)
    parser.add_argument("--timeout", help="Inference timeout in seconds, Default: infinite", default=-1, type=int)
    parser.add_argument("--inputs", nargs='+', help="list of files containing input feature maps", required=True)
    parser.add_argument("--output_dir", help="directory to store inference results, output feature maps. "
                                             "Default: current directory", default=os.getcwd())
    parser.add_argument("--npy", help="Use npy input/output", default=0, type=int)
    parser.add_argument("--profile_counters", help="Performance counters to profile", nargs=4, type=int, required=True)
    args = parser.parse_args()

    npy_requested = bool(args.npy)
    if npy_requested and not with_numpy:
        # The fallback is kept, but it is made visible: .npy files would
        # otherwise be silently fed to the device as raw bytes.
        logging.warning("numpy is not available, npy input/output request is ignored")
    use_numpy = with_numpy and npy_requested
    if use_numpy:
        logging.info("Running with numpy inputs/outputs")
    else:
        logging.info("Running with byte array inputs/outputs")

    # @TODO: Discuss if this is needed anymore. Remove this commented line, if not.
    # driver.reset()

    # read_ifms is a generator: the files are only read when the list is built
    # for runner.run() below.
    ifms_data = read_ifms(args.inputs, use_numpy)

    runner = driver.InferenceRunner(args.device, args.model)
    runner.set_enabled_counters(args.profile_counters)
    # argparse already converted --timeout to int.
    ofm_buffers = runner.run(list(ifms_data), args.timeout)

    for index, buffer_out in enumerate(ofm_buffers):
        logging.info("Output buffer size: %s", buffer_out.size())
        write_ofm(buffer_out.data(), index, args.model, args.output_dir, use_numpy)

    inference_pmu_counters = runner.get_pmu_counters()

    # Profiling
    total_cycles = runner.get_pmu_total_cycles()
    for pmu, value in inference_pmu_counters:
        logging.info("\tNPU %d counter: %d", pmu, value)
    logging.info("\tNPU TOTAL cycles: %d", total_cycles)