blob: 92b828e067e19626a7d3b8d099b757d7325dcdac [file] [log] [blame]
Jan Eilers53ef7952021-06-02 12:01:25 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <reference/workloads/Decoders.hpp>
Jan Eilers53ef7952021-06-02 12:01:25 +01007
8#include <fmt/format.h>
9
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010010#include <doctest/doctest.h>
Jan Eilers53ef7952021-06-02 12:01:25 +010011
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010012#include <chrono>
Jan Eilers53ef7952021-06-02 12:01:25 +010013
14template<typename T>
15void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
16{
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010017 CHECK(vec1.size() == vec2.size());
Jan Eilers53ef7952021-06-02 12:01:25 +010018
19 bool mismatch = false;
Rob Hughes96fd98c2021-07-28 13:50:12 +010020 for (uint32_t i = 0; i < vec1.size(); ++i)
Jan Eilers53ef7952021-06-02 12:01:25 +010021 {
22 if (vec1[i] != vec2[i])
23 {
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010024 MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
25 i,
26 vec1[i],
27 vec2[i]));
28
Jan Eilers53ef7952021-06-02 12:01:25 +010029 mismatch = true;
30 }
31 }
32
33 if (mismatch)
34 {
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010035 FAIL("Error in CompareVector. Vectors don't match.");
Jan Eilers53ef7952021-06-02 12:01:25 +010036 }
37}
38
39using namespace armnn;
40
41// Basically a per axis decoder but without any decoding/quantization
42class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
43{
44public:
45 MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
46 : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
47 {}
48
49 int8_t Get() const override
50 {
51 return *m_Iterator;
52 }
53
54 virtual std::vector<float> DecodeTensor(const TensorShape &tensorShape,
55 bool isDepthwise = false) override
56 {
57 IgnoreUnused(tensorShape, isDepthwise);
58 return std::vector<float>{};
59 };
60
61 // Iterates over data using operator[] and returns vector
62 std::vector<int8_t> Loop()
63 {
64 std::vector<int8_t> vec;
65 for (uint32_t i = 0; i < m_NumElements; ++i)
66 {
67 this->operator[](i);
68 vec.emplace_back(Get());
69 }
70 return vec;
71 }
72
73 unsigned int GetAxisIndex()
74 {
75 return m_AxisIndex;
76 }
77 unsigned int m_NumElements;
78};
79
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010080TEST_SUITE("RefPerAxisIterator")
81{
Jan Eilers53ef7952021-06-02 12:01:25 +010082// Test Loop (Equivalent to DecodeTensor) and Axis = 0
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010083TEST_CASE("PerAxisIteratorTest1")
Jan Eilers53ef7952021-06-02 12:01:25 +010084{
85 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
86 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
87
88 // test axis=0
89 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
90 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
91 std::vector<int8_t> output = iterator.Loop();
92 CompareVector(output, expOutput);
93
94 // Set iterator to index and check if the axis index is correct
95 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010096 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +010097
98 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +010099 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100100
101 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100102 CHECK(iterator.GetAxisIndex() == 2u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100103}
104
105// Test Axis = 1
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100106TEST_CASE("PerAxisIteratorTest2")
Jan Eilers53ef7952021-06-02 12:01:25 +0100107{
108 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
109 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
110
111 // test axis=1
112 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
113 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
114 std::vector<int8_t> output = iterator.Loop();
115 CompareVector(output, expOutput);
116
117 // Set iterator to index and check if the axis index is correct
118 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100119 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100120
121 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100122 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100123
124 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100125 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100126}
127
128// Test Axis = 2
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100129TEST_CASE("PerAxisIteratorTest3")
Jan Eilers53ef7952021-06-02 12:01:25 +0100130{
131 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
132 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
133
134 // test axis=2
135 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
136 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
137 std::vector<int8_t> output = iterator.Loop();
138 CompareVector(output, expOutput);
139
140 // Set iterator to index and check if the axis index is correct
141 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100142 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100143
144 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100145 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100146
147 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100148 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100149}
150
151// Test Axis = 3
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100152TEST_CASE("PerAxisIteratorTest4")
Jan Eilers53ef7952021-06-02 12:01:25 +0100153{
154 std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
155 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
156
157 // test axis=3
158 std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
159 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
160 std::vector<int8_t> output = iterator.Loop();
161 CompareVector(output, expOutput);
162
163 // Set iterator to index and check if the axis index is correct
164 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100165 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100166
167 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100168 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100169
170 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100171 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100172}
173
Jan Eilers53ef7952021-06-02 12:01:25 +0100174// Test Axis = 1. Different tensor shape
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100175TEST_CASE("PerAxisIteratorTest5")
Jan Eilers53ef7952021-06-02 12:01:25 +0100176{
177 using namespace armnn;
178 std::vector<int8_t> input =
179 {
180 0, 1, 2, 3,
181 4, 5, 6, 7,
182 8, 9, 10, 11,
183 12, 13, 14, 15
184 };
185
186 std::vector<int8_t> expOutput =
187 {
188 0, 1, 2, 3,
189 4, 5, 6, 7,
190 8, 9, 10, 11,
191 12, 13, 14, 15
192 };
193
194 TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
195 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
196 std::vector<int8_t> output = iterator.Loop();
197 CompareVector(output, expOutput);
198
199 // Set iterator to index and check if the axis index is correct
200 iterator[5];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100201 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100202
203 iterator[1];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100204 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100205
206 iterator[10];
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100207 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100208}
209
210// Test the increment and decrement operator
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100211TEST_CASE("PerAxisIteratorTest7")
Jan Eilers53ef7952021-06-02 12:01:25 +0100212{
213 using namespace armnn;
214 std::vector<int8_t> input =
215 {
216 0, 1, 2, 3,
217 4, 5, 6, 7,
218 8, 9, 10, 11
219 };
220
221 std::vector<int8_t> expOutput =
222 {
223 0, 1, 2, 3,
224 4, 5, 6, 7,
225 8, 9, 10, 11
226 };
227
228 TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
229 auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
230
231 iterator += 3;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100232 CHECK(iterator.Get() == expOutput[3]);
233 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100234
235 iterator += 3;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100236 CHECK(iterator.Get() == expOutput[6]);
237 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100238
239 iterator -= 2;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100240 CHECK(iterator.Get() == expOutput[4]);
241 CHECK(iterator.GetAxisIndex() == 0u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100242
243 iterator -= 1;
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100244 CHECK(iterator.Get() == expOutput[3]);
245 CHECK(iterator.GetAxisIndex() == 1u);
Jan Eilers53ef7952021-06-02 12:01:25 +0100246}
247
Matthew Sloyan7a00eaa2021-06-20 18:45:05 +0100248}