blob: 13ae1008cd552d1b734c79bf973160d87fc4bc0d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include <armnn/TensorFwd.hpp>
8#include <boost/test/unit_test.hpp>
9#include <boost/multi_array.hpp>
10#include <vector>
11#include <array>
12
13#include <boost/assert.hpp>
14#include <boost/test/tools/floating_point_comparison.hpp>
15#include <boost/random/uniform_real_distribution.hpp>
16#include <boost/random/mersenne_twister.hpp>
17#include <boost/numeric/conversion/cast.hpp>
18
19#include "armnn/Tensor.hpp"
20
21#include "backends/test/QuantizeHelper.hpp"
22
23#include <cmath>
24
// Absolute tolerance used when comparing unquantized floats and one operand is exactly zero,
// because a percentage-based tolerance against zero degenerates to an exact match.
constexpr float g_FloatCloseToZeroTolerance = 0.000001f;
telsoa014fcda012018-03-09 14:13:49 +000026
// Element comparer selected by SelectiveCompare. This primary template handles quantized
// element types: two values are considered equal when they differ by at most one step.
template<typename T, bool isQuantized = true>
struct SelectiveComparer
{
    static bool Compare(T a, T b)
    {
        // Take larger minus smaller so the difference is non-negative even for unsigned T.
        const T larger = std::max(a, b);
        const T smaller = std::min(a, b);
        return (larger - smaller) <= 1;
    }
};
36
37template<typename T>
38struct SelectiveComparer<T, false>
39{
40 static bool Compare(T a, T b)
41 {
telsoa01c577f2c2018-08-31 09:22:23 +010042 // If a or b is zero, percent_tolerance does an exact match, so compare to a small, constant tolerance instead.
telsoa014fcda012018-03-09 14:13:49 +000043 if (a == 0.0f || b == 0.0f)
44 {
45 return std::abs(a - b) <= g_FloatCloseToZeroTolerance;
46 }
Francis Murtagh8c5e3dc2018-08-30 17:18:37 +010047
48 if (std::isinf(a) && a == b)
49 {
50 return true;
51 }
52
53 if (std::isnan(a) && std::isnan(b))
54 {
55 return true;
56 }
57
telsoa014fcda012018-03-09 14:13:49 +000058 // For unquantized floats we use a tolerance of 1%.
59 boost::math::fpc::close_at_tolerance<float> comparer(boost::math::fpc::percent_tolerance(1.0f));
60 return comparer(a, b);
61 }
62};
63
64template<typename T>
65bool SelectiveCompare(T a, T b)
66{
67 return SelectiveComparer<T, armnn::IsQuantizedType<T>()>::Compare(a, b);
68};
69
70
71
72template <typename T, std::size_t n>
73boost::test_tools::predicate_result CompareTensors(const boost::multi_array<T, n>& a,
74 const boost::multi_array<T, n>& b)
75{
telsoa01c577f2c2018-08-31 09:22:23 +010076 // Checks they are same shape.
telsoa014fcda012018-03-09 14:13:49 +000077 for (unsigned int i=0; i<n; i++)
78 {
79 if (a.shape()[i] != b.shape()[i])
80 {
81 boost::test_tools::predicate_result res(false);
82 res.message() << "Different shapes ["
83 << a.shape()[i]
84 << "!="
85 << b.shape()[i]
86 << "]";
87 return res;
88 }
89 }
90
telsoa01c577f2c2018-08-31 09:22:23 +010091 // Now compares element-wise.
telsoa014fcda012018-03-09 14:13:49 +000092
telsoa01c577f2c2018-08-31 09:22:23 +010093 // Fun iteration over n dimensions.
telsoa014fcda012018-03-09 14:13:49 +000094 std::array<unsigned int, n> indices;
95 for (unsigned int i = 0; i < n; i++)
96 {
97 indices[i] = 0;
98 }
99
100 std::stringstream errorString;
101 int numFailedElements = 0;
102 constexpr int maxReportedDifferences = 3;
103
104 while (true)
105 {
106 bool comparison = SelectiveCompare(a(indices), b(indices));
107 if (!comparison)
108 {
109 ++numFailedElements;
110
111 if (numFailedElements <= maxReportedDifferences)
112 {
113 if (numFailedElements >= 2)
114 {
115 errorString << ", ";
116 }
117 errorString << "[";
118 for (unsigned int i = 0; i < n; ++i)
119 {
120 errorString << indices[i];
121 if (i != n - 1)
122 {
123 errorString << ",";
124 }
125 }
126 errorString << "]";
127
128 errorString << " (" << +a(indices) << " != " << +b(indices) << ")";
129 }
130 }
131
132 ++indices[n - 1];
133 for (unsigned int i=n-1; i>0; i--)
134 {
135 if (indices[i] == a.shape()[i])
136 {
137 indices[i] = 0;
138 ++indices[i - 1];
139 }
140 }
141
142 if (indices[0] == a.shape()[0])
143 {
144 break;
145 }
146 }
147
148 boost::test_tools::predicate_result comparisonResult(true);
149 if (numFailedElements > 0)
150 {
151 comparisonResult = false;
152 comparisonResult.message() << numFailedElements << " different values at: ";
153 if (numFailedElements > maxReportedDifferences)
154 {
155 errorString << ", ... (and " << (numFailedElements - maxReportedDifferences) << " other differences)";
156 }
157 comparisonResult.message() << errorString.str();
158 }
159
160 return comparisonResult;
161}
162
163
telsoa01c577f2c2018-08-31 09:22:23 +0100164// Creates a boost::multi_array with the shape defined by the given TensorInfo.
telsoa014fcda012018-03-09 14:13:49 +0000165template <typename T, std::size_t n>
166boost::multi_array<T, n> MakeTensor(const armnn::TensorInfo& tensorInfo)
167{
168 std::array<unsigned int, n> shape;
169
170 for (unsigned int i = 0; i < n; i++)
171 {
172 shape[i] = tensorInfo.GetShape()[i];
173 }
174
175 return boost::multi_array<T, n>(shape);
176}
177
telsoa01c577f2c2018-08-31 09:22:23 +0100178// Creates a boost::multi_array with the shape defined by the given TensorInfo and contents defined by the given vector.
telsoa014fcda012018-03-09 14:13:49 +0000179template <typename T, std::size_t n>
180boost::multi_array<T, n> MakeTensor(const armnn::TensorInfo& tensorInfo, const std::vector<T>& flat)
181{
182 BOOST_ASSERT_MSG(flat.size() == tensorInfo.GetNumElements(), "Wrong number of components supplied to tensor");
183
184 std::array<unsigned int, n> shape;
185
186 for (unsigned int i = 0; i < n; i++)
187 {
188 shape[i] = tensorInfo.GetShape()[i];
189 }
190
191 boost::const_multi_array_ref<T, n> arrayRef(&flat[0], shape);
192 return boost::multi_array<T, n>(arrayRef);
193}
194
195template <typename T, std::size_t n>
196boost::multi_array<T, n> MakeRandomTensor(const armnn::TensorInfo& tensorInfo,
197 unsigned int seed,
198 float min = -10.0f,
199 float max = 10.0f)
200{
201 boost::random::mt19937 gen(seed);
202 boost::random::uniform_real_distribution<float> dist(min, max);
203
204 std::vector<float> init(tensorInfo.GetNumElements());
205 for (unsigned int i = 0; i < init.size(); i++)
206 {
207 init[i] = dist(gen);
208 }
209 float qScale = tensorInfo.GetQuantizationScale();
210 int32_t qOffset = tensorInfo.GetQuantizationOffset();
211 return MakeTensor<T, n>(tensorInfo, QuantizedVector<T>(qScale, qOffset, init));
212}