blob: 91b17145be66a6d7628838fc776c9e0445b96a8a [file] [log] [blame]
Moritz Pflanzerc7d15032017-07-18 16:21:16 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TEST_VALIDATION_H__
25#define __ARM_COMPUTE_TEST_VALIDATION_H__
26
27#include "SimpleTensor.h"
28#include "arm_compute/core/FixedPoint.h"
29#include "arm_compute/core/Types.h"
30#include "framework/Asserts.h"
Moritz Pflanzer2ac50402017-07-24 15:52:54 +010031#include "framework/Exceptions.h"
Moritz Pflanzerc7d15032017-07-18 16:21:16 +010032#include "tests/IAccessor.h"
33#include "tests/TypePrinter.h"
34#include "tests/Utils.h"
35
36#include <iomanip>
Moritz Pflanzerb9e9cff2017-07-27 15:15:55 +010037#include <ios>
Moritz Pflanzerc7d15032017-07-18 16:21:16 +010038#include <vector>
39
40namespace arm_compute
41{
42namespace test
43{
44namespace validation
45{
46template <typename T>
47bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &dimensions2)
48{
49 if(dimensions1.num_dimensions() != dimensions2.num_dimensions())
50 {
51 return false;
52 }
53
54 for(unsigned int i = 0; i < dimensions1.num_dimensions(); ++i)
55 {
56 if(dimensions1[i] != dimensions2[i])
57 {
58 return false;
59 }
60 }
61
62 return true;
63}
64
/** Validate valid regions.
 *
 * - Dimensionality has to be the same.
 * - Anchors have to match.
 * - Shapes have to match.
 *
 * Only declared in this header; the definition lives elsewhere.
 *
 * @param[in] region    Valid region under test.
 * @param[in] reference Valid region that @p region is compared against.
 */
void validate(const arm_compute::ValidRegion &region, const arm_compute::ValidRegion &reference);
72
/** Validate padding.
 *
 * Padding on all sides has to be the same.
 *
 * Only declared in this header; the definition lives elsewhere.
 *
 * @param[in] padding   Padding under test.
 * @param[in] reference Padding that @p padding is compared against.
 */
void validate(const arm_compute::PaddingSize &padding, const arm_compute::PaddingSize &reference);
78
/** Validate tensors.
 *
 * - Dimensionality has to be the same.
 * - All values have to match.
 *
 * The valid region used for the comparison is derived from the full shape of
 * @p tensor.
 *
 * @param[in] tensor           Accessor of the tensor under test.
 * @param[in] reference        Reference tensor the values are compared against.
 * @param[in] tolerance_value  Absolute tolerance handed to is_equal() for every compared value. Defaults to 0.
 * @param[in] tolerance_number Fraction of the compared values that is allowed to mismatch. Defaults to 0.
 *
 * @note NOTE(review): an earlier note here described a `wrap_range` argument;
 *       this signature has no such parameter, so that text no longer applies.
 */
template <typename T, typename U = T>
void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(0), float tolerance_number = 0.f);
Moritz Pflanzerc7d15032017-07-18 16:21:16 +010091
/** Validate tensors with valid region.
 *
 * - Dimensionality has to be the same.
 * - All values have to match.
 *
 * Only elements that lie inside @p valid_region are compared.
 *
 * @param[in] tensor           Accessor of the tensor under test.
 * @param[in] reference        Reference tensor the values are compared against.
 * @param[in] valid_region     Region whose elements take part in the comparison.
 * @param[in] tolerance_value  Absolute tolerance handed to is_equal() for every compared value. Defaults to 0.
 * @param[in] tolerance_number Fraction of the compared values that is allowed to mismatch. Defaults to 0.
 *
 * @note NOTE(review): an earlier note here described a `wrap_range` argument;
 *       this signature has no such parameter, so that text no longer applies.
 */
template <typename T, typename U = T>
void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f);
Moritz Pflanzerc7d15032017-07-18 16:21:16 +0100104
/** Validate tensors against constant value.
 *
 * - All values have to match.
 *
 * Only declared in this header; the definition lives elsewhere.
 *
 * @param[in] tensor          Accessor of the tensor under test.
 * @param[in] reference_value Untyped pointer to the expected value.
 *                            NOTE(review): presumably interpreted using the
 *                            tensor's data type; confirm against the definition.
 */
void validate(const IAccessor &tensor, const void *reference_value);
110
/** Validate border against a constant value.
 *
 * - All border values have to match the specified value if mode is CONSTANT.
 * - All border values have to be replicated if mode is REPLICATE.
 * - Nothing is validated for mode UNDEFINED.
 *
 * Only declared in this header; the definition lives elsewhere.
 *
 * @param[in] tensor       Accessor of the tensor whose border is checked.
 * @param[in] border_size  Size of the border to check.
 * @param[in] border_mode  Border mode that selects which of the rules above applies.
 * @param[in] border_value Untyped pointer to the expected border value for mode CONSTANT.
 *                         NOTE(review): presumably interpreted using the
 *                         tensor's data type; confirm against the definition.
 */
void validate(const IAccessor &tensor, BorderSize border_size, const BorderMode &border_mode, const void *border_value);
118
/** Validate classified labels against expected ones.
 *
 * - All values should match
 *
 * Only declared in this header; the definition lives elsewhere. Both vectors
 * are taken by value.
 *
 * @param[in] classified_labels Labels produced by the code under test.
 * @param[in] expected_labels   Labels that @p classified_labels is compared against.
 */
void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned int> expected_labels);
124
/** Validate float value.
 *
 * - All values should match
 *
 * The comparison itself is delegated to is_equal().
 *
 * @param[in] target                   Value under test.
 * @param[in] ref                      Reference value.
 * @param[in] tolerance_abs_error      Absolute tolerance forwarded to is_equal(). Defaults to the epsilon of T.
 * @param[in] tolerance_relative_error Relative tolerance forwarded to is_equal(). Defaults to 0.0001f.
 */
template <typename T, typename U = T>
void validate(T target, T ref, U tolerance_abs_error = std::numeric_limits<T>::epsilon(), double tolerance_relative_error = 0.0001f);
131
/** Check whether two values are equal within an absolute or a relative tolerance.
 *
 * The checks are applied in this order:
 * -# If either value is not finite the result is false.
 * -# If the values compare equal the result is true.
 * -# If |target - ref| does not exceed @p max_absolute_error the result is true.
 * -# Otherwise the result is whether |target - ref| divided by the larger of
 *    |target| and |ref| does not exceed @p max_relative_error.
 *
 * @param[in] target             Value under test.
 * @param[in] ref                Reference value.
 * @param[in] max_absolute_error Largest absolute difference that is still accepted. Defaults to the epsilon of T.
 * @param[in] max_relative_error Largest relative difference that is still accepted. Defaults to 0.0001f.
 *
 * @return true if the values are considered equal, false otherwise.
 */
template <typename T, typename U = T>
bool is_equal(T target, T ref, U max_absolute_error = std::numeric_limits<T>::epsilon(), double max_relative_error = 0.0001f)
{
    if(!std::isfinite(target) || !std::isfinite(ref))
    {
        return false;
    }

    // No need further check if they are equal
    if(ref == target)
    {
        return true;
    }

    // Do the arithmetic in double so that the subtraction cannot wrap around for
    // unsigned types and is not rounded for narrower floating point types
    const double target_value   = static_cast<double>(target);
    const double ref_value      = static_cast<double>(ref);
    const double absolute_error = std::abs(target_value - ref_value);

    // The absolute check has to look at the signed difference: comparing the
    // magnitudes instead would accept pairs such as 5 and -5. It also covers
    // values close to zero that have different signs.
    if(absolute_error <= static_cast<double>(max_absolute_error))
    {
        return true;
    }

    // Normalise by the value with the larger magnitude. It cannot be zero here,
    // as two zeros have already been accepted by the equality check above.
    const double target_magnitude = std::abs(target_value);
    const double ref_magnitude    = std::abs(ref_value);
    const double scale            = (target_magnitude > ref_magnitude) ? target_magnitude : ref_magnitude;

    return (absolute_error / scale) <= max_relative_error;
}
165
166template <typename T, typename U>
167void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number)
168{
169 // Validate with valid region covering the entire shape
170 validate(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number);
171}
172
173template <typename T, typename U>
174void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value, float tolerance_number)
175{
176 int64_t num_mismatches = 0;
177 int64_t num_elements = 0;
178
Moritz Pflanzer2ac50402017-07-24 15:52:54 +0100179 ARM_COMPUTE_EXPECT_EQUAL(tensor.element_size(), reference.element_size(), framework::LogLevel::ERRORS);
180 ARM_COMPUTE_EXPECT_EQUAL(tensor.format(), reference.format(), framework::LogLevel::ERRORS);
181 ARM_COMPUTE_EXPECT_EQUAL(tensor.data_type(), reference.data_type(), framework::LogLevel::ERRORS);
182 ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS);
183 ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS);
Moritz Pflanzerc7d15032017-07-18 16:21:16 +0100184
185 const int min_elements = std::min(tensor.num_elements(), reference.num_elements());
186 const int min_channels = std::min(tensor.num_channels(), reference.num_channels());
187
188 // Iterate over all elements within valid region, e.g. U8, S16, RGB888, ...
189 for(int element_idx = 0; element_idx < min_elements; ++element_idx)
190 {
191 const Coordinates id = index2coord(reference.shape(), element_idx);
192
193 if(is_in_valid_region(valid_region, id))
194 {
195 // Iterate over all channels within one element
196 for(int c = 0; c < min_channels; ++c)
197 {
198 const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
199 const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
200
201 if(!is_equal(target_value, reference_value, tolerance_value))
202 {
203 ARM_COMPUTE_TEST_INFO("id = " << id);
204 ARM_COMPUTE_TEST_INFO("channel = " << c);
205 ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << target_value);
206 ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << reference_value);
Moritz Pflanzerb9e9cff2017-07-27 15:15:55 +0100207 ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance_value);
Moritz Pflanzer2ac50402017-07-24 15:52:54 +0100208 ARM_COMPUTE_EXPECT_EQUAL(target_value, reference_value, framework::LogLevel::DEBUG);
Moritz Pflanzerc7d15032017-07-18 16:21:16 +0100209
210 ++num_mismatches;
211 }
212
213 ++num_elements;
214 }
215 }
216 }
217
218 if(num_elements > 0)
219 {
220 const int64_t absolute_tolerance_number = tolerance_number * num_elements;
221 const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements * 100.f;
222
Moritz Pflanzerb9e9cff2017-07-27 15:15:55 +0100223 ARM_COMPUTE_TEST_INFO(num_mismatches << " values (" << std::fixed << std::setprecision(2) << percent_mismatches
Moritz Pflanzerc7d15032017-07-18 16:21:16 +0100224 << "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)");
Moritz Pflanzer2ac50402017-07-24 15:52:54 +0100225 ARM_COMPUTE_EXPECT(num_mismatches <= absolute_tolerance_number, framework::LogLevel::ERRORS);
Moritz Pflanzerc7d15032017-07-18 16:21:16 +0100226 }
227}
228
/** Validate a single value against a reference value.
 *
 * The result of is_equal() is reported through ARM_COMPUTE_EXPECT at the
 * ERRORS log level. The reference, the target and the absolute tolerance are
 * attached as test info beforehand; the relative tolerance is not logged.
 *
 * @param[in] target                   Value under test.
 * @param[in] ref                      Reference value.
 * @param[in] tolerance_abs_error      Absolute tolerance forwarded to is_equal().
 * @param[in] tolerance_relative_error Relative tolerance forwarded to is_equal().
 */
template <typename T, typename U>
void validate(T target, T ref, U tolerance_abs_error, double tolerance_relative_error)
{
    // Decide equality using both tolerances
    const bool equal = is_equal(target, ref, tolerance_abs_error, tolerance_relative_error);

    ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << ref);
    ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << target);
    ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance_abs_error);
    ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS);
}
239} // namespace validation
240} // namespace test
241} // namespace arm_compute
#endif /* __ARM_COMPUTE_TEST_VALIDATION_H__ */