# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Contains helper functions that can be used across the example apps."""

import os
import errno
from pathlib import Path

import numpy as np
import pyarmnn as ann


def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
    """Creates a dictionary of labels from the input labels file.

    Args:
        labels_file_path: Path to the file containing labels to map model outputs.
        include_rgb: Adds randomly generated RGB values to the values of the
            dictionary. Used for plotting bounding boxes of different colours.

    Returns:
        Dictionary with classification indices for keys and labels for values.

    Raises:
        FileNotFoundError:
            Provided `labels_file_path` does not exist.
    """
    labels_file = Path(labels_file_path)
    if not labels_file.is_file():
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path
        )

    labels = {}
    with open(labels_file, "r") as f:
        for idx, line in enumerate(f, 0):
            if include_rgb:
                labels[idx] = line.strip("\n"), tuple(np.random.random(size=3) * 255)
            else:
                labels[idx] = line.strip("\n")
    return labels
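
# A minimal usage sketch (the "labels.txt" file and the values shown are illustrative only,
# not part of this module):
#   labels = dict_labels("labels.txt", include_rgb=True)
#   class_name, colour = labels[0]   # e.g. ("person", (12.3, 200.1, 87.6))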


def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
    """
    Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
    input tensors.

    Args:
        audio_data: The audio data to process.
        input_binding_info: The model input binding info.
        mfcc_preprocessor: The MFCC preprocessor instance.

    Returns:
        input_tensors: The prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor.
    """

    data_type = input_binding_info[1].GetDataType()
    input_tensor = mfcc_preprocessor.extract_features(audio_data)
    if data_type != ann.DataType_Float32:
        input_tensor = quantize_input(input_tensor, input_binding_info)
    input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
    return input_tensors
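
# A usage sketch, assuming a loaded network that exposes its input binding info and an MFCC
# preprocessor from the example apps (names below are illustrative, not a confirmed API):
#   input_tensors = prepare_input_tensors(audio_block, input_binding_info, mfcc_preprocessor)
#   output = executor.run(input_tensors)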


def quantize_input(data, input_binding_info):
    """Quantize the float input to (u)int8 ready for inputting to model."""
    if data.ndim != 2:
        raise RuntimeError("Audio data must have 2 dimensions for quantization")

    quant_scale = input_binding_info[1].GetQuantizationScale()
    quant_offset = input_binding_info[1].GetQuantizationOffset()
    data_type = input_binding_info[1].GetDataType()

    if data_type == ann.DataType_QAsymmS8:
        data_type = np.int8
    elif data_type == ann.DataType_QAsymmU8:
        data_type = np.uint8
    else:
        raise ValueError("Could not quantize data to required data type")

    d_min = np.iinfo(data_type).min
    d_max = np.iinfo(data_type).max

    for row in range(data.shape[0]):
        for col in range(data.shape[1]):
            data[row, col] = (data[row, col] / quant_scale) + quant_offset
            data[row, col] = np.clip(data[row, col], d_min, d_max)
    data = data.astype(data_type)
    return data
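
# Worked example of the affine quantization applied above (illustrative numbers only):
# with quant_scale = 0.5 and quant_offset = 10, a feature value of 2.0 maps to
# 2.0 / 0.5 + 10 = 14, which is then clipped to the (u)int8 range and cast.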


def dequantize_output(data, output_binding_info):
    """Dequantize the (u)int8 output to float."""

    if output_binding_info[1].IsQuantized():
        if data.ndim != 2:
            raise RuntimeError("Data must have 2 dimensions for quantization")

        quant_scale = output_binding_info[1].GetQuantizationScale()
        quant_offset = output_binding_info[1].GetQuantizationOffset()

        data = data.astype(float)
        for row in range(data.shape[0]):
            for col in range(data.shape[1]):
                data[row, col] = (data[row, col] - quant_offset) * quant_scale
    return data
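
# This is the inverse of quantize_input (up to clipping and integer truncation):
# real_value = (quantized_value - quant_offset) * quant_scale.
# Illustrative round trip with scale 0.5 and offset 10: 14 -> (14 - 10) * 0.5 = 2.0.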