// blob: ad99b4828186eeadf85b4a1ee0c021e72b3b5bb5
// (blame: Matthew Sloyan, commit 28f177c, 2021-04-09 14:38:52 +0100)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "ParserFlatbuffersFixture.hpp"
#include "../TfLiteParser.hpp"

#include <iostream>
#include <string>

13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct ArgMinMaxFixture : public ParserFlatbuffersFixture
16{
17 explicit ArgMinMaxFixture(const std::string& operatorCode,
18 const std::string& inputShape,
19 const std::string& outputShape,
20 const std::string& axisData)
21 {
22 m_JsonString = R"(
23 {
24 "version": 3,
25 "operator_codes": [ { "builtin_code": )" + operatorCode + R"( } ],
26 "subgraphs": [ {
27 "tensors": [
28 {
29 "shape": )" + inputShape + R"(,
30 "type": "FLOAT32",
31 "buffer": 0,
32 "name": "inputTensor",
33 "quantization": {
34 "min": [ 0.0 ],
35 "max": [ 255.0 ],
36 "scale": [ 1.0 ],
37 "zero_point": [ 0 ],
38 }
39 },
40 {
41 "shape": )" + outputShape + R"( ,
42 "type": "INT32",
43 "buffer": 1,
44 "name": "outputTensor",
45 "quantization": {
46 "min": [ 0.0 ],
47 "max": [ 255.0 ],
48 "scale": [ 1.0 ],
49 "zero_point": [ 0 ],
50 }
51 },
52 {
53 "shape": [ 1 ],
54 "type": "INT32",
55 "buffer": 2,
56 "name": "axis",
57 "quantization": {
58 "min": [ 0.0 ],
59 "max": [ 255.0 ],
60 "scale": [ 1.0 ],
61 "zero_point": [ 0 ],
62 }
63 },
64 ],
65 "inputs": [ 0 ],
66 "outputs": [ 1 ],
67 "operators": [
68 {
69 "opcode_index": 0,
70 "inputs": [ 0 , 2 ],
71 "outputs": [ 1 ],
72 "custom_options_format": "FLEXBUFFERS"
73 }
74 ],
75 } ],
76 "buffers" : [
77 { },
78 { },
79 { "data": )" + axisData + R"(, },
80 ]
81 }
82 )";
83
84 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
85 }
86};
87
88struct SimpleArgMaxFixture : public ArgMinMaxFixture
89{
90 SimpleArgMaxFixture() : ArgMinMaxFixture("ARG_MAX",
91 "[ 1, 1, 1, 5 ]",
92 "[ 1, 1, 1 ]",
93 "[ 3, 0, 0, 0 ]") {}
94};
95
96BOOST_FIXTURE_TEST_CASE(ParseSimpleArgMax, SimpleArgMaxFixture)
97{
98 RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
99 0,
100 {{ "inputTensor", { 6.0f, 2.0f, 8.0f, 10.0f, 9.0f } } },
101 {{ "outputTensor", { 3l } } });
102}
103
104struct ArgMaxFixture : public ArgMinMaxFixture
105{
106 ArgMaxFixture() : ArgMinMaxFixture("ARG_MAX",
107 "[ 3, 2, 1, 4 ]",
108 "[ 2, 1, 4 ]",
109 "[ 0, 0, 0, 0 ]") {}
110};
111
112BOOST_FIXTURE_TEST_CASE(ParseArgMax, ArgMaxFixture)
113{
114 RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
115 0,
116 {{ "inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f,
117 8.0f, 7.0f, 6.0f, 6.0f,
118 100.0f, 20.0f, 300.0f, 40.0f,
119 500.0f, 476.0f, 450.0f, 426.0f,
120 50.0f, 60.0f, 70.0f, 80.0f,
121 10.0f, 200.0f, 30.0f, 400.0f } } },
122 {{ "outputTensor", { 1, 2, 1, 2,
123 1, 1, 1, 1 } } });
124}
125
126struct SimpleArgMinFixture : public ArgMinMaxFixture
127{
128 SimpleArgMinFixture() : ArgMinMaxFixture("ARG_MIN",
129 "[ 1, 1, 1, 5 ]",
130 "[ 1, 1, 1 ]",
131 "[ 3, 0, 0, 0 ]") {}
132};
133
134BOOST_FIXTURE_TEST_CASE(ParseSimpleArgMin, SimpleArgMinFixture)
135{
136 RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
137 0,
138 {{ "inputTensor", { 6.0f, 2.0f, 8.0f, 10.0f, 9.0f } } },
139 {{ "outputTensor", { 1l } } });
140}
141
142struct ArgMinFixture : public ArgMinMaxFixture
143{
144 ArgMinFixture() : ArgMinMaxFixture("ARG_MIN",
145 "[ 3, 2, 1, 4 ]",
146 "[ 2, 1, 4 ]",
147 "[ 0, 0, 0, 0 ]") {}
148};
149
150BOOST_FIXTURE_TEST_CASE(ParseArgMin, ArgMinFixture)
151{
152 RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
153 0,
154 {{ "inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f,
155 8.0f, 7.0f, 6.0f, 6.0f,
156 100.0f, 20.0f, 300.0f, 40.0f,
157 500.0f, 476.0f, 450.0f, 426.0f,
158 50.0f, 60.0f, 70.0f, 80.0f,
159 10.0f, 200.0f, 30.0f, 400.0f } } },
160 {{ "outputTensor", { 0, 0, 0, 0,
161 0, 0, 0, 0 } } });
162}
163
164BOOST_AUTO_TEST_SUITE_END()