MLECO-1253 Adding ASR sample application using the PyArmNN API
Change-Id: I450b23800ca316a5bfd4608c8559cf4f11271c21
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
index 090ce2f..f0ba91e 100644
--- a/python/pyarmnn/examples/image_classification/example_utils.py
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -38,7 +38,8 @@
runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
# Process output
- out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+ # output tensor has a shape (1, 1001)
+ out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
results = np.argsort(out_tensor)[::-1]
print_top_n(5, results, labels, out_tensor)
@@ -121,7 +122,7 @@
return net_id, parser, runtime
-def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from a tflite model file.
Args:
@@ -140,7 +141,7 @@
return net_id, graph_id, parser, runtime
-def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
"""Creates a network from an onnx model file.
Args:
@@ -181,7 +182,7 @@
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
- scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+ scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
preprocess_fn=preprocess_default):
"""Loads images, resizes and performs any additional preprocessing to run inference.
@@ -218,7 +219,6 @@
with open(label_file, 'r') as f:
labels = [l.rstrip() for l in f]
return labels
- return None
def print_top_n(N: int, results: list, labels: list, prob: list):
@@ -299,10 +299,10 @@
download_url = [download_url]
for dl in download_url:
archive = download_file(dl)
- if dl.lower().endswith(".zip"):
- unzip_file(archive)
+ if dl.lower().endswith(".zip"):
+ unzip_file(archive)
except RuntimeError:
- print("Unable to download file ({}).".format(archive_url))
+ print("Unable to download file ({}).".format(download_url))
if not os.path.exists(labels) or not os.path.exists(model):
raise RuntimeError("Unable to provide model and labels.")
@@ -310,7 +310,7 @@
return model, labels
-def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
"""Lists files of a certain format in a folder.
Args:
@@ -338,7 +338,7 @@
"""Gets image.
Args:
- image (str): Image filename
+ image_dir (str): Image filename
image_url (str): Image url
Returns: