blob: 8ae4e7e0d144f040461fef46727e87d922812391 [file] [log] [blame]
//
// Copyright © 2021, 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ArgMinMaxTestHelper.hpp"

#include <doctest/doctest.h>

#include <cstdint>
#include <vector>

10namespace armnnDelegate
11{
12
Colm Donelaneff204a2023-11-28 15:46:09 +000013void ArgMaxFP32Test(int axisValue)
Sadik Armagandc032fc2021-01-19 17:24:21 +000014{
15 // Set input data
16 std::vector<int32_t> inputShape { 1, 3, 2, 4 };
17 std::vector<int32_t> outputShape { 1, 3, 4 };
18 std::vector<int32_t> axisShape { 1 };
19
20 std::vector<float> inputValues = { 1.0f, 2.0f, 3.0f, 4.0f,
21 5.0f, 6.0f, 7.0f, 8.0f,
22
23 10.0f, 20.0f, 30.0f, 40.0f,
24 50.0f, 60.0f, 70.0f, 80.0f,
25
26 100.0f, 200.0f, 300.0f, 400.0f,
27 500.0f, 600.0f, 700.0f, 800.0f };
28
29 std::vector<int32_t> expectedOutputValues = { 1, 1, 1, 1,
30 1, 1, 1, 1,
31 1, 1, 1, 1 };
32
33 ArgMinMaxTest<float, int32_t>(tflite::BuiltinOperator_ARG_MAX,
34 ::tflite::TensorType_FLOAT32,
Sadik Armagandc032fc2021-01-19 17:24:21 +000035 inputShape,
36 axisShape,
37 outputShape,
38 inputValues,
39 expectedOutputValues,
40 axisValue,
41 ::tflite::TensorType_INT32);
42}
43
Colm Donelaneff204a2023-11-28 15:46:09 +000044void ArgMinFP32Test(int axisValue)
Sadik Armagandc032fc2021-01-19 17:24:21 +000045{
46 // Set input data
47 std::vector<int32_t> inputShape { 1, 3, 2, 4 };
48 std::vector<int32_t> outputShape { 1, 3, 2 };
49 std::vector<int32_t> axisShape { 1 };
50
51 std::vector<float> inputValues = { 1.0f, 2.0f, 3.0f, 4.0f,
52 5.0f, 6.0f, 7.0f, 8.0f,
53
54 10.0f, 20.0f, 30.0f, 40.0f,
55 50.0f, 60.0f, 70.0f, 80.0f,
56
57 100.0f, 200.0f, 300.0f, 400.0f,
58 500.0f, 600.0f, 700.0f, 800.0f };
59
60 std::vector<int32_t> expectedOutputValues = { 0, 0,
61 0, 0,
62 0, 0 };
63
64 ArgMinMaxTest<float, int32_t>(tflite::BuiltinOperator_ARG_MIN,
65 ::tflite::TensorType_FLOAT32,
Sadik Armagandc032fc2021-01-19 17:24:21 +000066 inputShape,
67 axisShape,
68 outputShape,
69 inputValues,
70 expectedOutputValues,
71 axisValue,
72 ::tflite::TensorType_INT32);
73}
74
Colm Donelaneff204a2023-11-28 15:46:09 +000075void ArgMaxUint8Test(int axisValue)
Sadik Armagandc032fc2021-01-19 17:24:21 +000076{
77 // Set input data
78 std::vector<int32_t> inputShape { 1, 1, 1, 5 };
79 std::vector<int32_t> outputShape { 1, 1, 1 };
80 std::vector<int32_t> axisShape { 1 };
81
82 std::vector<uint8_t> inputValues = { 5, 2, 8, 10, 9 };
83
84 std::vector<int32_t> expectedOutputValues = { 3 };
85
86 ArgMinMaxTest<uint8_t, int32_t>(tflite::BuiltinOperator_ARG_MAX,
87 ::tflite::TensorType_UINT8,
Sadik Armagandc032fc2021-01-19 17:24:21 +000088 inputShape,
89 axisShape,
90 outputShape,
91 inputValues,
92 expectedOutputValues,
93 axisValue,
94 ::tflite::TensorType_INT32);
95}
96
Colm Donelaneff204a2023-11-28 15:46:09 +000097TEST_SUITE("ArgMinMax_Tests")
Sadik Armagandc032fc2021-01-19 17:24:21 +000098{
99
Colm Donelaneff204a2023-11-28 15:46:09 +0000100 TEST_CASE("ArgMaxFP32Test_Test")
101 {
102 ArgMaxFP32Test(2);
103 }
104
105 TEST_CASE("ArgMinFP32Test_Test")
106 {
107 ArgMinFP32Test(3);
108 }
109
110 TEST_CASE("ArgMaxUint8Test_Test")
111 {
112 ArgMaxUint8Test(-1);
113 }
Sadik Armagandc032fc2021-01-19 17:24:21 +0000114}
115
Sadik Armagandc032fc2021-01-19 17:24:21 +0000116} // namespace armnnDelegate