blob: f266c16537bd296cba471051f07ff9e0987641e0 [file] [log] [blame]
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3
4import os
Raviv Shalev97ddc062021-12-07 15:18:09 +02005import pytest
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +00006import cv2
Raviv Shalev97ddc062021-12-07 15:18:09 +02007import numpy as np
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +00008
9from context import network_executor
Raviv Shalev97ddc062021-12-07 15:18:09 +020010from context import network_executor_tflite
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +000011from context import cv_utils
12
@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
def test_execute_network(test_data_folder, executor_name):
    """Run ssd_mobilenet_v1.tflite on a sample image through the chosen executor.

    Builds either the Arm NN or the TFLite (Arm NN delegate) network executor,
    preprocesses ``messi5.jpg`` to the executor's input dtype/shape, runs
    inference, and checks that the first detected class is 0 (a person, per
    the original test comment). It also checks the executor's reported input
    dtype and shape for this model.

    Args:
        test_data_folder: pytest fixture (defined outside this file — presumably
            in conftest.py; verify) giving the directory that holds the model,
            the delegate library and the test image.
        executor_name: which executor to construct, ``"armnn"`` or ``"tflite"``.

    Raises:
        ValueError: if ``executor_name`` is not one of the supported names.
    """
    model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
    backends = ["CpuAcc", "CpuRef"]
    if executor_name == "armnn":
        executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
    elif executor_name == "tflite":
        delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
        executor = network_executor_tflite.TFLiteNetworkExecutor(
            model_path, backends, delegate_path
        )
    else:
        # Raising a bare str is itself a TypeError in Python 3 and would hide
        # this message; raise a real exception instead.
        raise ValueError(f"unsupported executor_name: {executor_name}")

    image_path = os.path.join(test_data_folder, "messi5.jpg")
    img = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as an obscure error later.
    assert img is not None, f"could not read test image: {image_path}"

    # NOTE(review): the meaning of the trailing positional True is defined by
    # cv_utils.preprocess (not visible here) — confirm and consider a keyword.
    resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True)

    output_result = executor.run([resized_img])

    # Ensure it detects a person
    classes = output_result[1]
    assert classes[0][0] == 0

    # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network
    assert executor.get_data_type() == np.uint8
    assert executor.get_shape() == (1, 300, 300, 3)