blob: 73889863f0cdf889d4f51f012dbf3433b570b6ab [file] [log] [blame]
Sadik Armagandc032fc2021-01-19 17:24:21 +00001//
Teresa Charlinad1b3d72023-03-14 12:10:28 +00002// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagandc032fc2021-01-19 17:24:21 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "ArgMinMaxTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9
10#include <flatbuffers/flatbuffers.h>
Sadik Armagandc032fc2021-01-19 17:24:21 +000011
12#include <doctest/doctest.h>
13
14namespace armnnDelegate
15{
16
Colm Donelaneff204a2023-11-28 15:46:09 +000017void ArgMaxFP32Test(int axisValue)
Sadik Armagandc032fc2021-01-19 17:24:21 +000018{
19 // Set input data
20 std::vector<int32_t> inputShape { 1, 3, 2, 4 };
21 std::vector<int32_t> outputShape { 1, 3, 4 };
22 std::vector<int32_t> axisShape { 1 };
23
24 std::vector<float> inputValues = { 1.0f, 2.0f, 3.0f, 4.0f,
25 5.0f, 6.0f, 7.0f, 8.0f,
26
27 10.0f, 20.0f, 30.0f, 40.0f,
28 50.0f, 60.0f, 70.0f, 80.0f,
29
30 100.0f, 200.0f, 300.0f, 400.0f,
31 500.0f, 600.0f, 700.0f, 800.0f };
32
33 std::vector<int32_t> expectedOutputValues = { 1, 1, 1, 1,
34 1, 1, 1, 1,
35 1, 1, 1, 1 };
36
37 ArgMinMaxTest<float, int32_t>(tflite::BuiltinOperator_ARG_MAX,
38 ::tflite::TensorType_FLOAT32,
Sadik Armagandc032fc2021-01-19 17:24:21 +000039 inputShape,
40 axisShape,
41 outputShape,
42 inputValues,
43 expectedOutputValues,
44 axisValue,
45 ::tflite::TensorType_INT32);
46}
47
Colm Donelaneff204a2023-11-28 15:46:09 +000048void ArgMinFP32Test(int axisValue)
Sadik Armagandc032fc2021-01-19 17:24:21 +000049{
50 // Set input data
51 std::vector<int32_t> inputShape { 1, 3, 2, 4 };
52 std::vector<int32_t> outputShape { 1, 3, 2 };
53 std::vector<int32_t> axisShape { 1 };
54
55 std::vector<float> inputValues = { 1.0f, 2.0f, 3.0f, 4.0f,
56 5.0f, 6.0f, 7.0f, 8.0f,
57
58 10.0f, 20.0f, 30.0f, 40.0f,
59 50.0f, 60.0f, 70.0f, 80.0f,
60
61 100.0f, 200.0f, 300.0f, 400.0f,
62 500.0f, 600.0f, 700.0f, 800.0f };
63
64 std::vector<int32_t> expectedOutputValues = { 0, 0,
65 0, 0,
66 0, 0 };
67
68 ArgMinMaxTest<float, int32_t>(tflite::BuiltinOperator_ARG_MIN,
69 ::tflite::TensorType_FLOAT32,
Sadik Armagandc032fc2021-01-19 17:24:21 +000070 inputShape,
71 axisShape,
72 outputShape,
73 inputValues,
74 expectedOutputValues,
75 axisValue,
76 ::tflite::TensorType_INT32);
77}
78
Colm Donelaneff204a2023-11-28 15:46:09 +000079void ArgMaxUint8Test(int axisValue)
Sadik Armagandc032fc2021-01-19 17:24:21 +000080{
81 // Set input data
82 std::vector<int32_t> inputShape { 1, 1, 1, 5 };
83 std::vector<int32_t> outputShape { 1, 1, 1 };
84 std::vector<int32_t> axisShape { 1 };
85
86 std::vector<uint8_t> inputValues = { 5, 2, 8, 10, 9 };
87
88 std::vector<int32_t> expectedOutputValues = { 3 };
89
90 ArgMinMaxTest<uint8_t, int32_t>(tflite::BuiltinOperator_ARG_MAX,
91 ::tflite::TensorType_UINT8,
Sadik Armagandc032fc2021-01-19 17:24:21 +000092 inputShape,
93 axisShape,
94 outputShape,
95 inputValues,
96 expectedOutputValues,
97 axisValue,
98 ::tflite::TensorType_INT32);
99}
100
Colm Donelaneff204a2023-11-28 15:46:09 +0000101TEST_SUITE("ArgMinMax_Tests")
Sadik Armagandc032fc2021-01-19 17:24:21 +0000102{
103
Colm Donelaneff204a2023-11-28 15:46:09 +0000104 TEST_CASE("ArgMaxFP32Test_Test")
105 {
106 ArgMaxFP32Test(2);
107 }
108
109 TEST_CASE("ArgMinFP32Test_Test")
110 {
111 ArgMinFP32Test(3);
112 }
113
114 TEST_CASE("ArgMaxUint8Test_Test")
115 {
116 ArgMaxUint8Test(-1);
117 }
Sadik Armagandc032fc2021-01-19 17:24:21 +0000118}
119
Sadik Armagandc032fc2021-01-19 17:24:21 +0000120} // namespace armnnDelegate