blob: b6e7b8e82bd6ae7a03303a3878a3639d864f1da0 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010024#ifndef __ARM_COMPUTE_TEST_VALIDATION_H__
25#define __ARM_COMPUTE_TEST_VALIDATION_H__
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010027#include "arm_compute/core/FixedPoint.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Types.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010029#include "tests/IAccessor.h"
30#include "tests/SimpleTensor.h"
John Richardsonf89a49f2017-09-05 11:21:56 +010031#include "tests/Types.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010032#include "tests/Utils.h"
33#include "tests/framework/Asserts.h"
34#include "tests/framework/Exceptions.h"
Anthony Barbier2a07e182017-08-04 18:20:27 +010035#include "utils/TypePrinter.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010037#include <iomanip>
38#include <ios>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039#include <vector>
40
41namespace arm_compute
42{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010043namespace test
44{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045namespace validation
46{
/** Class representing an absolute tolerance value.
 *
 * @tparam T Type in which the tolerance is stored.
 */
template <typename T>
class AbsoluteTolerance
{
public:
    /** Underlying type. */
    using value_type = T;

    /** Default constructor.
     *
     * Leaves the tolerance at its in-class default, std::numeric_limits<T>::epsilon().
     * For built-in integral types that value is 0.
     */
    AbsoluteTolerance() = default;

    /** Constructor.
     *
     * @param[in] value Absolute tolerance value.
     */
    constexpr explicit AbsoluteTolerance(value_type value)
        : _value(value)
    {
    }

    /** Implicit conversion to the underlying type.
     *
     * @return The stored tolerance value.
     */
    constexpr operator value_type() const
    {
        return _value;
    }

private:
    /** Stored tolerance; epsilon of the underlying type unless set by the constructor. */
    value_type _value = std::numeric_limits<value_type>::epsilon();
};
79
/** Class representing a relative tolerance value.
 *
 * @tparam T Type in which the tolerance is stored.
 */
template <typename T>
class RelativeTolerance
{
public:
    /** Underlying type. */
    using value_type = T;

    /** Default constructor.
     *
     * Leaves the tolerance at its in-class default, std::numeric_limits<T>::epsilon().
     * For built-in integral types that value is 0.
     */
    RelativeTolerance() = default;

    /** Constructor.
     *
     * @param[in] value Relative tolerance value.
     */
    constexpr explicit RelativeTolerance(value_type value)
        : _value(value)
    {
    }

    /** Implicit conversion to the underlying type.
     *
     * @return The stored tolerance value.
     */
    constexpr operator value_type() const
    {
        return _value;
    }

private:
    /** Stored tolerance; epsilon of the underlying type unless set by the constructor. */
    value_type _value = std::numeric_limits<T>::epsilon();
};
112
113/** Print AbsoluteTolerance type. */
114template <typename T>
115inline ::std::ostream &operator<<(::std::ostream &os, const AbsoluteTolerance<T> &tolerance)
116{
117 os << static_cast<typename AbsoluteTolerance<T>::value_type>(tolerance);
118
119 return os;
120}
121
122/** Print RelativeTolerance type. */
steniu013e05e4e2017-08-25 17:18:01 +0100123template <typename T>
124inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance<T> &tolerance)
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100125{
steniu013e05e4e2017-08-25 17:18:01 +0100126 os << static_cast<typename RelativeTolerance<T>::value_type>(tolerance);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100127
128 return os;
129}
130
131template <typename T>
132bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &dimensions2)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100133{
134 if(dimensions1.num_dimensions() != dimensions2.num_dimensions())
135 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100136 return false;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100137 }
138
139 for(unsigned int i = 0; i < dimensions1.num_dimensions(); ++i)
140 {
141 if(dimensions1[i] != dimensions2[i])
142 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100143 return false;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100144 }
145 }
146
147 return true;
148}
149
150/** Validate valid regions.
151 *
152 * - Dimensionality has to be the same.
153 * - Anchors have to match.
154 * - Shapes have to match.
155 */
156void validate(const arm_compute::ValidRegion &region, const arm_compute::ValidRegion &reference);
157
158/** Validate padding.
159 *
160 * Padding on all sides has to be the same.
161 */
162void validate(const arm_compute::PaddingSize &padding, const arm_compute::PaddingSize &reference);
163
164/** Validate tensors.
165 *
166 * - Dimensionality has to be the same.
167 * - All values have to match.
168 *
169 * @note: wrap_range allows cases where reference tensor rounds up to the wrapping point, causing it to wrap around to
170 * zero while the test tensor stays at wrapping point to pass. This may permit true erroneous cases (difference between
171 * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
172 * other test cases.
173 */
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100174template <typename T, typename U = AbsoluteTolerance<T>>
175void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(), float tolerance_number = 0.f);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100176
177/** Validate tensors with valid region.
178 *
179 * - Dimensionality has to be the same.
180 * - All values have to match.
181 *
182 * @note: wrap_range allows cases where reference tensor rounds up to the wrapping point, causing it to wrap around to
183 * zero while the test tensor stays at wrapping point to pass. This may permit true erroneous cases (difference between
184 * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
185 * other test cases.
186 */
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100187template <typename T, typename U = AbsoluteTolerance<T>>
188void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(), float tolerance_number = 0.f);
Isabella Gottardi62031532017-07-04 11:21:28 +0100189
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100190/** Validate tensors against constant value.
191 *
192 * - All values have to match.
193 */
194void validate(const IAccessor &tensor, const void *reference_value);
195
196/** Validate border against a constant value.
197 *
198 * - All border values have to match the specified value if mode is CONSTANT.
199 * - All border values have to be replicated if mode is REPLICATE.
200 * - Nothing is validated for mode UNDEFINED.
201 */
202void validate(const IAccessor &tensor, BorderSize border_size, const BorderMode &border_mode, const void *border_value);
203
204/** Validate classified labels against expected ones.
205 *
206 * - All values should match
207 */
208void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned int> expected_labels);
steniu01960b0842017-06-23 11:44:34 +0100209
210/** Validate float value.
211 *
212 * - All values should match
213 */
Moritz Pflanzer7655a672017-09-23 11:57:33 +0100214template <typename T, typename U = AbsoluteTolerance<T>>
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100215void validate(T target, T reference, U tolerance = AbsoluteTolerance<T>());
steniu01960b0842017-06-23 11:44:34 +0100216
/** Common state shared by the tolerance-specific compare specialisations.
 *
 * @tparam T Tolerance type. It must expose a nested @c value_type and be
 *           constructible from 0.
 */
template <typename T>
struct compare_base
{
    /** Value type carried by the tolerance. */
    using element_type = typename T::value_type;

    /** Store the operands of a comparison.
     *
     * @param[in] target    Value under test.
     * @param[in] reference Expected value.
     * @param[in] tolerance Tolerance to apply. Defaults to a tolerance constructed from 0.
     */
    compare_base(element_type target, element_type reference, T tolerance = T(0))
        : _target{ target }, _reference{ reference }, _tolerance{ tolerance }
    {
    }

    element_type _target{};    /**< Value under test. */
    element_type _reference{}; /**< Expected value. */
    T            _tolerance{}; /**< Tolerance applied by the derived comparison. */
};
Michele Di Giorgioef4b4ae2017-07-04 17:19:43 +0100229
Moritz Pflanzer5b61fd32017-09-12 15:51:33 +0100230template <typename T>
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100231struct compare;
232
233template <typename U>
Moritz Pflanzer5b61fd32017-09-12 15:51:33 +0100234struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>>
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100235{
236 using compare_base<AbsoluteTolerance<U>>::compare_base;
237
238 operator bool() const
239 {
240 if(!std::isfinite(this->_target) || !std::isfinite(this->_reference))
241 {
242 return false;
243 }
244 else if(this->_target == this->_reference)
245 {
246 return true;
247 }
248
Moritz Pflanzer5b61fd32017-09-12 15:51:33 +0100249 using comparison_type = typename std::conditional<std::is_integral<U>::value, int64_t, U>::type;
250
251 const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target) - static_cast<comparison_type>(this->_reference)));
252
253 return abs_difference <= static_cast<comparison_type>(this->_tolerance);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100254 }
255};
256
257template <typename U>
Moritz Pflanzer5b61fd32017-09-12 15:51:33 +0100258struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>>
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100259{
steniu013e05e4e2017-08-25 17:18:01 +0100260 using compare_base<RelativeTolerance<U>>::compare_base;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100261
262 operator bool() const
263 {
steniu013e05e4e2017-08-25 17:18:01 +0100264 if(!std::isfinite(this->_target) || !std::isfinite(this->_reference))
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100265 {
266 return false;
267 }
steniu013e05e4e2017-08-25 17:18:01 +0100268 else if(this->_target == this->_reference)
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100269 {
270 return true;
271 }
272
Moritz Pflanzerff1c3602017-09-22 12:41:25 +0100273 const U epsilon = (std::is_same<half, typename std::remove_cv<U>::type>::value || (this->_reference == 0)) ? static_cast<U>(0.01) : static_cast<U>(1e-05);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100274
steniu013e05e4e2017-08-25 17:18:01 +0100275 if(std::abs(static_cast<double>(this->_reference) - static_cast<double>(this->_target)) <= epsilon)
276 {
277 return true;
278 }
279 else
280 {
281 if(static_cast<double>(this->_reference) == 0.0f) // We have checked whether _reference and _target is closing. If _reference is 0 but not closed to _target, it should return false
282 {
283 return false;
284 }
285
286 const double relative_change = std::abs(static_cast<double>(this->_target) - static_cast<double>(this->_reference)) / this->_reference;
287
288 return relative_change <= static_cast<U>(this->_tolerance);
289 }
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100290 }
291};
292
293template <typename T, typename U>
294void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number)
295{
296 // Validate with valid region covering the entire shape
297 validate(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number);
298}
299
300template <typename T, typename U>
301void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value, float tolerance_number)
302{
303 int64_t num_mismatches = 0;
304 int64_t num_elements = 0;
305
306 ARM_COMPUTE_EXPECT_EQUAL(tensor.element_size(), reference.element_size(), framework::LogLevel::ERRORS);
307 ARM_COMPUTE_EXPECT_EQUAL(tensor.data_type(), reference.data_type(), framework::LogLevel::ERRORS);
308
309 if(reference.format() != Format::UNKNOWN)
310 {
311 ARM_COMPUTE_EXPECT_EQUAL(tensor.format(), reference.format(), framework::LogLevel::ERRORS);
312 }
313
314 ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS);
315 ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS);
316
317 const int min_elements = std::min(tensor.num_elements(), reference.num_elements());
318 const int min_channels = std::min(tensor.num_channels(), reference.num_channels());
319
320 // Iterate over all elements within valid region, e.g. U8, S16, RGB888, ...
321 for(int element_idx = 0; element_idx < min_elements; ++element_idx)
322 {
323 const Coordinates id = index2coord(reference.shape(), element_idx);
324
325 if(is_in_valid_region(valid_region, id))
326 {
327 // Iterate over all channels within one element
328 for(int c = 0; c < min_channels; ++c)
329 {
330 const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
331 const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
332
Moritz Pflanzer5b61fd32017-09-12 15:51:33 +0100333 if(!compare<U>(target_value, reference_value, tolerance_value))
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100334 {
335 ARM_COMPUTE_TEST_INFO("id = " << id);
336 ARM_COMPUTE_TEST_INFO("channel = " << c);
337 ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target_value));
338 ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference_value));
339 ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance_value)));
steniu01172c58d2017-08-31 13:49:08 +0100340 framework::ARM_COMPUTE_PRINT_INFO();
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100341
342 ++num_mismatches;
343 }
344
345 ++num_elements;
346 }
347 }
348 }
349
350 if(num_elements > 0)
351 {
352 const int64_t absolute_tolerance_number = tolerance_number * num_elements;
353 const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements * 100.f;
354
355 ARM_COMPUTE_TEST_INFO(num_mismatches << " values (" << std::fixed << std::setprecision(2) << percent_mismatches
356 << "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)");
357 ARM_COMPUTE_EXPECT(num_mismatches <= absolute_tolerance_number, framework::LogLevel::ERRORS);
Michele Di Giorgioef4b4ae2017-07-04 17:19:43 +0100358 }
359}
Giorgio Arenafc2817d2017-06-27 17:26:37 +0100360
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100361template <typename T, typename U>
362void validate(T target, T reference, U tolerance)
363{
364 ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference));
365 ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target));
366 ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance)));
Moritz Pflanzer5b61fd32017-09-12 15:51:33 +0100367 ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100368}
John Richardsonf89a49f2017-09-05 11:21:56 +0100369
370template <typename T, typename U>
371void validate_min_max_loc(const MinMaxLocationValues<T> &target, const MinMaxLocationValues<U> &reference)
372{
373 ARM_COMPUTE_EXPECT_EQUAL(target.min, reference.min, framework::LogLevel::ERRORS);
374 ARM_COMPUTE_EXPECT_EQUAL(target.max, reference.max, framework::LogLevel::ERRORS);
375
376 ARM_COMPUTE_EXPECT_EQUAL(target.min_loc.size(), reference.min_loc.size(), framework::LogLevel::ERRORS);
377 ARM_COMPUTE_EXPECT_EQUAL(target.max_loc.size(), reference.max_loc.size(), framework::LogLevel::ERRORS);
378
379 for(uint32_t i = 0; i < target.min_loc.size(); ++i)
380 {
381 const auto same_coords = std::find_if(reference.min_loc.begin(), reference.min_loc.end(), [&target, i](Coordinates2D coord)
382 {
383 return coord.x == target.min_loc.at(i).x && coord.y == target.min_loc.at(i).y;
384 });
385
386 ARM_COMPUTE_EXPECT(same_coords != reference.min_loc.end(), framework::LogLevel::ERRORS);
387 }
388
389 for(uint32_t i = 0; i < target.max_loc.size(); ++i)
390 {
391 const auto same_coords = std::find_if(reference.max_loc.begin(), reference.max_loc.end(), [&target, i](Coordinates2D coord)
392 {
393 return coord.x == target.max_loc.at(i).x && coord.y == target.max_loc.at(i).y;
394 });
395
396 ARM_COMPUTE_EXPECT(same_coords != reference.max_loc.end(), framework::LogLevel::ERRORS);
397 }
398}
399
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100400} // namespace validation
401} // namespace test
402} // namespace arm_compute
#endif /* __ARM_COMPUTE_TEST_VALIDATION_H__ */