blob: d437f4da7778b1214dc42d6edfbfe25122bf838f [file] [log] [blame]
Jan Eilers53ef7952021-06-02 12:01:25 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <reference/workloads/Decoders.hpp>
Jan Eilers53ef7952021-06-02 12:01:25 +01007
Teresa Charlin5306dc82023-10-30 22:29:58 +00008#include <armnn/utility/IgnoreUnused.hpp>
9
Jan Eilers53ef7952021-06-02 12:01:25 +010010#include <fmt/format.h>
11
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010012#include <doctest/doctest.h>
Jan Eilers53ef7952021-06-02 12:01:25 +010013
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010014#include <chrono>
Jan Eilers53ef7952021-06-02 12:01:25 +010015
16template<typename T>
17void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
18{
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010019 CHECK(vec1.size() == vec2.size());
Jan Eilers53ef7952021-06-02 12:01:25 +010020
21 bool mismatch = false;
Rob Hughes96fd98c2021-07-28 13:50:12 +010022 for (uint32_t i = 0; i < vec1.size(); ++i)
Jan Eilers53ef7952021-06-02 12:01:25 +010023 {
24 if (vec1[i] != vec2[i])
25 {
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010026 MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
27 i,
28 vec1[i],
29 vec2[i]));
30
Jan Eilers53ef7952021-06-02 12:01:25 +010031 mismatch = true;
32 }
33 }
34
35 if (mismatch)
36 {
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010037 FAIL("Error in CompareVector. Vectors don't match.");
Jan Eilers53ef7952021-06-02 12:01:25 +010038 }
39}
40
41using namespace armnn;
42
43// Basically a per axis decoder but without any decoding/quantization
44class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
45{
46public:
47 MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
48 : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
49 {}
50
51 int8_t Get() const override
52 {
53 return *m_Iterator;
54 }
55
56 virtual std::vector<float> DecodeTensor(const TensorShape &tensorShape,
57 bool isDepthwise = false) override
58 {
59 IgnoreUnused(tensorShape, isDepthwise);
60 return std::vector<float>{};
61 };
62
63 // Iterates over data using operator[] and returns vector
64 std::vector<int8_t> Loop()
65 {
66 std::vector<int8_t> vec;
67 for (uint32_t i = 0; i < m_NumElements; ++i)
68 {
69 this->operator[](i);
70 vec.emplace_back(Get());
71 }
72 return vec;
73 }
74
75 unsigned int GetAxisIndex()
76 {
77 return m_AxisIndex;
78 }
79 unsigned int m_NumElements;
80};
81
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010082TEST_SUITE("RefPerAxisIterator")
83{
Jan Eilers53ef7952021-06-02 12:01:25 +010084// Test Loop (Equivalent to DecodeTensor) and Axis = 0
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010085TEST_CASE("PerAxisIteratorTest1")
Jan Eilers53ef7952021-06-02 12:01:25 +010086{
87 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
88 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
89
90 // test axis=0
91 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
92 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
93 std::vector<int8_t> output = iterator.Loop();
94 CompareVector(output, expOutput);
95
96 // Set iterator to index and check if the axis index is correct
97 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010098 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +010099
100 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100101 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100102
103 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100104 CHECK(iterator.GetAxisIndex() == 2u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100105}
106
107// Test Axis = 1
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100108TEST_CASE("PerAxisIteratorTest2")
Jan Eilers53ef7952021-06-02 12:01:25 +0100109{
110 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
111 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
112
113 // test axis=1
114 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
115 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
116 std::vector<int8_t> output = iterator.Loop();
117 CompareVector(output, expOutput);
118
119 // Set iterator to index and check if the axis index is correct
120 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100121 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100122
123 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100124 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100125
126 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100127 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100128}
129
130// Test Axis = 2
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100131TEST_CASE("PerAxisIteratorTest3")
Jan Eilers53ef7952021-06-02 12:01:25 +0100132{
133 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
134 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
135
136 // test axis=2
137 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
138 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
139 std::vector<int8_t> output = iterator.Loop();
140 CompareVector(output, expOutput);
141
142 // Set iterator to index and check if the axis index is correct
143 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100144 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100145
146 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100147 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100148
149 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100150 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100151}
152
153// Test Axis = 3
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100154TEST_CASE("PerAxisIteratorTest4")
Jan Eilers53ef7952021-06-02 12:01:25 +0100155{
156 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
157 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
158
159 // test axis=3
160 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
161 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
162 std::vector<int8_t> output = iterator.Loop();
163 CompareVector(output, expOutput);
164
165 // Set iterator to index and check if the axis index is correct
166 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100167 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100168
169 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100170 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100171
172 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100173 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100174}
175
Jan Eilers53ef7952021-06-02 12:01:25 +0100176// Test Axis = 1. Different tensor shape
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100177TEST_CASE("PerAxisIteratorTest5")
Jan Eilers53ef7952021-06-02 12:01:25 +0100178{
179 using namespace armnn;
180 std::vector<int8_t> input =
181 {
182 0, 1, 2, 3,
183 4, 5, 6, 7,
184 8, 9, 10, 11,
185 12, 13, 14, 15
186 };
187
188 std::vector<int8_t> expOutput =
189 {
190 0, 1, 2, 3,
191 4, 5, 6, 7,
192 8, 9, 10, 11,
193 12, 13, 14, 15
194 };
195
196 TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
197 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
198 std::vector<int8_t> output = iterator.Loop();
199 CompareVector(output, expOutput);
200
201 // Set iterator to index and check if the axis index is correct
202 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100203 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100204
205 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100206 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100207
208 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100209 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100210}
211
212// Test the increment and decrement operator
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100213TEST_CASE("PerAxisIteratorTest7")
Jan Eilers53ef7952021-06-02 12:01:25 +0100214{
215 using namespace armnn;
216 std::vector<int8_t> input =
217 {
218 0, 1, 2, 3,
219 4, 5, 6, 7,
220 8, 9, 10, 11
221 };
222
223 std::vector<int8_t> expOutput =
224 {
225 0, 1, 2, 3,
226 4, 5, 6, 7,
227 8, 9, 10, 11
228 };
229
230 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
231 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
232
233 iterator += 3;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100234 CHECK(iterator.Get() == expOutput[3]);
235 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100236
237 iterator += 3;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100238 CHECK(iterator.Get() == expOutput[6]);
239 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100240
241 iterator -= 2;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100242 CHECK(iterator.Get() == expOutput[4]);
243 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100244
245 iterator -= 1;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100246 CHECK(iterator.Get() == expOutput[3]);
247 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100248}
249
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100250}