blob: 73816f858843a823281275b674216d7f627a7e21 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ParserFlatbuffersSerializeFixture.hpp"
7#include <armnnDeserializer/IDeserializer.hpp>
8
9#include <doctest/doctest.h>
10
11#include <string>
12
13TEST_SUITE("Deserializer_ReverseV2")
14{
15 struct ReverseV2Fixture : public ParserFlatbuffersSerializeFixture
16 {
17 explicit ReverseV2Fixture(const std::string& inputShape,
18 const std::string& outputShape,
19 const std::string& dataType,
20 const std::string& axis)
21 {
22 m_JsonString = R"(
23 {
24 inputIds: [0],
25 outputIds: [2],
26 layers: [
27 {
28 layer_type: "InputLayer",
29 layer: {
30 base: {
31 layerBindingId: 0,
32 base: {
33 index: 0,
34 layerName: "InputLayer",
35 layerType: "Input",
36 inputSlots: [{
37 index: 0,
38 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
39 }],
40 outputSlots: [{
41 index: 0,
42 tensorInfo: {
43 dimensions: )" + inputShape + R"(,
44 dataType: )" + dataType + R"(
45 }
46 }]
47 }
48 }
49 }
50 },
51 {
52 layer_type: "ReverseV2Layer",
53 layer: {
54 base: {
55 index: 1,
56 layerName: "ReverseV2Layer",
57 layerType: "ReverseV2",
58 inputSlots: [{
59 index: 0,
60 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
61 }],
62 outputSlots: [{
63 index: 0,
64 tensorInfo: {
65 dimensions: )" + outputShape + R"(,
66 dataType: )" + dataType + R"(
67 }
68 }]
69 },
70 descriptor: {
71 axis: )" + axis + R"(
72 }
73 }
74 },
75 {
76 layer_type: "OutputLayer",
77 layer: {
78 base:{
79 layerBindingId: 2,
80 base: {
81 index: 2,
82 layerName: "OutputLayer",
83 layerType: "Output",
84 inputSlots: [{
85 index: 0,
86 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
87 }],
88 outputSlots: [{
89 index: 0,
90 tensorInfo: {
91 dimensions: )" + outputShape + R"(,
92 dataType: )" + dataType + R"(
93 },
94 }],
95 }
96 }
97 }
98 }
99 ]
100 }
101 )";
102
103 SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
104 }
105 };
106
107 // Test cases
108
109 struct SimpleReverseV2FixtureFloat32 : ReverseV2Fixture
110 {
111 SimpleReverseV2FixtureFloat32()
112 : ReverseV2Fixture("[ 2, 2 ]",
113 "[ 2, 2 ]",
114 "Float32",
115 "[1]")
116 {}
117 };
118
119 TEST_CASE_FIXTURE(SimpleReverseV2FixtureFloat32, "SimpleReverseV2TestFloat32")
120 {
121 RunTest<4, armnn::DataType::Float32>(
122 0,
123 { 1.0f, 2.0f,
124 3.0f, 4.0f },
125 { 2.0f, 1.0f,
126 4.0f, 3.0f }
127 );
128 }
129
130 struct SimpleReverseV2FixtureFloat32ZeroAxis : ReverseV2Fixture
131 {
132 SimpleReverseV2FixtureFloat32ZeroAxis()
133 : ReverseV2Fixture("[ 2, 2 ]",
134 "[ 2, 2 ]",
135 "Float32",
136 "[0]")
137 {}
138 };
139
140 TEST_CASE_FIXTURE(SimpleReverseV2FixtureFloat32ZeroAxis, "SimpleReverseV2TestFloat32ZeroAxis")
141 {
142 RunTest<4, armnn::DataType::Float32>(
143 0,
144 { 1.0f, 2.0f,
145 3.0f, 4.0f },
146 { 3.0f, 4.0f,
147 1.0f, 2.0f }
148 );
149 }
150
151 struct SimpleReverseV2FixtureFloat32NegativeAxis : ReverseV2Fixture
152 {
153 SimpleReverseV2FixtureFloat32NegativeAxis()
154 : ReverseV2Fixture("[ 3, 3 ]",
155 "[ 3, 3 ]",
156 "Float32",
157 "[-1]")
158 {}
159 };
160
161 TEST_CASE_FIXTURE(SimpleReverseV2FixtureFloat32NegativeAxis, "SimpleReverseV2TestFloat32NegativeAxis")
162 {
163 RunTest<4, armnn::DataType::Float32>(
164 0,
165 { 1.0f, 2.0f, 3.0f,
166 4.0f, 5.0f, 6.0f,
167 7.0f, 8.0f, 9.0f },
168 { 3.0f, 2.0f, 1.0f,
169 6.0f, 5.0f, 4.0f,
170 9.0f, 8.0f, 7.0f }
171 );
172 }
173
174 struct SimpleReverseV2FixtureFloat32ThreeAxis : ReverseV2Fixture
175 {
176 SimpleReverseV2FixtureFloat32ThreeAxis()
177 : ReverseV2Fixture("[ 3, 3, 3 ]",
178 "[ 3, 3, 3 ]",
179 "Float32",
180 "[0, 2, 1]")
181 {}
182 };
183
184 TEST_CASE_FIXTURE(SimpleReverseV2FixtureFloat32ThreeAxis, "SimpleReverseV2TestFloat32ThreeAxis")
185 {
186 RunTest<4, armnn::DataType::Float32>(
187 0,
188 { 1.0f, 2.0f, 3.0f,
189 4.0f, 5.0f, 6.0f,
190 7.0f, 8.0f, 9.0f,
191
192 11.0f, 12.0f, 13.0f,
193 14.0f, 15.0f, 16.0f,
194 17.0f, 18.0f, 19.0f,
195
196 21.0f, 22.0f, 23.0f,
197 24.0f, 25.0f, 26.0f,
198 27.0f, 28.0f, 29.0f },
199 { 29.0f, 28.0f, 27.0f,
200 26.0f, 25.0f, 24.0f,
201 23.0f, 22.0f, 21.0f,
202
203 19.0f, 18.0f, 17.0f,
204 16.0f, 15.0f, 14.0f,
205 13.0f, 12.0f, 11.0f,
206
207 9.0f, 8.0f, 7.0f,
208 6.0f, 5.0f, 4.0f,
209 3.0f, 2.0f, 1.0f }
210 );
211 }
212
213 struct SimpleReverseV2FixtureQuantisedAsymm8ThreeAxis : ReverseV2Fixture
214 {
215 SimpleReverseV2FixtureQuantisedAsymm8ThreeAxis()
216 : ReverseV2Fixture("[ 3, 3, 3 ]",
217 "[ 3, 3, 3 ]",
218 "QuantisedAsymm8",
219 "[0, 2, 1]")
220 {}
221 };
222
223 TEST_CASE_FIXTURE(SimpleReverseV2FixtureQuantisedAsymm8ThreeAxis, "SimpleReverseV2TestQuantisedAsymm8ThreeAxis")
224 {
225 RunTest<4, armnn::DataType::QAsymmU8>(
226 0,
227 { 1, 2, 3,
228 4, 5, 6,
229 7, 8, 9,
230
231 11, 12, 13,
232 14, 15, 16,
233 17, 18, 19,
234
235 21, 22, 23,
236 24, 25, 26,
237 27, 28, 29 },
238 { 29, 28, 27,
239 26, 25, 24,
240 23, 22, 21,
241
242 19, 18, 17,
243 16, 15, 14,
244 13, 12, 11,
245
246 9, 8, 7,
247 6, 5, 4,
248 3, 2, 1 }
249 );
250 }
251}