blob: 66a4b25fc955a06fa3d9df489f3986506306e481 [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "tests/validation/FixedPoint.h"

#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
31
32namespace arm_compute
33{
34namespace test
35{
36namespace validation
37{
38namespace
39{
40const auto FuncNamesDataset = framework::dataset::make("FunctionNames", { FixedPointOp::ADD,
41 FixedPointOp::SUB,
42 FixedPointOp::MUL,
43 FixedPointOp::EXP,
44 FixedPointOp::LOG,
45 FixedPointOp::INV_SQRT
46 });
John Richardson66bd42a2017-10-12 12:15:23 +010047
48template <typename T>
49void load_array_from_numpy(const std::string &file, std::vector<unsigned long> &shape, std::vector<T> &data) // NOLINT
50{
51 try
52 {
53 npy::LoadArrayFromNumpy(file, shape, data);
54 }
55 catch(const std::runtime_error &e)
56 {
57 throw framework::FileNotFound("Could not load npy file: " + file + " (" + e.what() + ")");
58 }
59}
John Richardson70f946b2017-10-02 16:52:16 +010060} // namespace
61
62TEST_SUITE(UNIT)
63TEST_SUITE(FixedPoint)
64
65// *INDENT-OFF*
66// clang-format off
67DATA_TEST_CASE(FixedPointQS8Inputs, framework::DatasetMode::ALL, combine(
68 FuncNamesDataset,
69 framework::dataset::make("FractionalBits", 1, 7)),
70 func_name, frac_bits)
71// clang-format on
72// *INDENT-ON*
73{
74 std::vector<double> data;
75 std::vector<unsigned long> shape; //NOLINT
76
77 std::string func_name_lower = to_string(func_name);
78 std::transform(func_name_lower.begin(), func_name_lower.end(), func_name_lower.begin(), ::tolower);
79
80 const std::string inputs_file = library->path()
81 + "fixed_point/"
82 + func_name_lower
83 + "_Q8."
84 + support::cpp11::to_string(frac_bits)
85 + ".in.npy";
86
John Richardson66bd42a2017-10-12 12:15:23 +010087 load_array_from_numpy(inputs_file, shape, data);
John Richardson70f946b2017-10-02 16:52:16 +010088
89 // Values stored as doubles so reinterpret as floats
90 const auto *float_val = reinterpret_cast<float *>(&data[0]);
91 const size_t num_elements = data.size() * sizeof(double) / sizeof(float);
92
93 for(unsigned int i = 0; i < num_elements; ++i)
94 {
95 // Convert to fixed point
96 fixed_point_arithmetic::fixed_point<int8_t> in_val(float_val[i], frac_bits);
97
98 // Check that the value didn't change
99 ARM_COMPUTE_EXPECT(static_cast<float>(in_val) == float_val[i], framework::LogLevel::ERRORS);
100 }
101}
102
103//FIXME: Figure out how to handle expected failures properly
104// The last input argument specifies the expected number of failures for a
105// given combination of (function name, number of fractional bits) as defined
106// by the first two arguments.
107
108// *INDENT-OFF*
109// clang-format off
110DATA_TEST_CASE(FixedPointQS8Outputs, framework::DatasetMode::ALL, zip(combine(
111 FuncNamesDataset,
112 framework::dataset::make("FractionalBits", 1, 7)),
113 framework::dataset::make("ExpectedFailures", { 0, 0, 0, 0, 0, 0,
114 0, 0, 0, 0, 0, 0,
115 0, 0, 0, 0, 0, 0,
116 7, 8, 13, 2, 0, 0,
117 0, 0, 0, 0, 0, 0,
118 0, 0, 0, 5, 33, 96 })),
119 func_name, frac_bits, expected_failures)
120// clang-format on
121// *INDENT-ON*
122{
123 std::vector<double> in_data;
124 std::vector<unsigned long> in_shape; //NOLINT
125
126 std::vector<double> out_data;
127 std::vector<unsigned long> out_shape; //NOLINT
128
129 std::string func_name_lower = to_string(func_name);
130 std::transform(func_name_lower.begin(), func_name_lower.end(), func_name_lower.begin(), ::tolower);
131
132 const std::string base_file_name = library->path()
133 + "fixed_point/"
134 + func_name_lower
135 + "_Q8."
136 + support::cpp11::to_string(frac_bits);
137
138 const std::string inputs_file = base_file_name + ".in.npy";
139 const std::string reference_file = base_file_name + ".out.npy";
140
John Richardson66bd42a2017-10-12 12:15:23 +0100141 load_array_from_numpy(inputs_file, in_shape, in_data);
142 load_array_from_numpy(reference_file, out_shape, out_data);
John Richardson70f946b2017-10-02 16:52:16 +0100143
144 ARM_COMPUTE_EXPECT(in_shape.front() == out_shape.front(), framework::LogLevel::ERRORS);
145
146 const float step_size = std::pow(2.f, -frac_bits);
147 int64_t num_mismatches = 0;
148
149 // Values stored as doubles so reinterpret as floats
150 const auto *float_val = reinterpret_cast<float *>(&in_data[0]);
151 const auto *ref_val = reinterpret_cast<float *>(&out_data[0]);
152
153 const size_t num_elements = in_data.size() * sizeof(double) / sizeof(float);
154
155 for(unsigned int i = 0; i < num_elements; ++i)
156 {
157 fixed_point_arithmetic::fixed_point<int8_t> in_val(float_val[i], frac_bits);
158 fixed_point_arithmetic::fixed_point<int8_t> out_val(0.f, frac_bits);
159
160 float tolerance = 0.f;
161
162 if(func_name == FixedPointOp::ADD)
163 {
164 out_val = in_val + in_val;
165 }
166 else if(func_name == FixedPointOp::SUB)
167 {
168 out_val = in_val - in_val; //NOLINT
169 }
170 else if(func_name == FixedPointOp::MUL)
171 {
172 tolerance = 1.f * step_size;
173 out_val = in_val * in_val;
174 }
175 else if(func_name == FixedPointOp::EXP)
176 {
177 tolerance = 2.f * step_size;
178 out_val = fixed_point_arithmetic::exp(in_val);
179 }
180 else if(func_name == FixedPointOp::LOG)
181 {
182 tolerance = 4.f * step_size;
183 out_val = fixed_point_arithmetic::log(in_val);
184 }
185 else if(func_name == FixedPointOp::INV_SQRT)
186 {
187 tolerance = 5.f * step_size;
188 out_val = fixed_point_arithmetic::inv_sqrt(in_val);
189 }
190
191 if(std::abs(static_cast<float>(out_val) - ref_val[i]) > tolerance)
192 {
193 ARM_COMPUTE_TEST_INFO("input = " << in_val);
194 ARM_COMPUTE_TEST_INFO("output = " << out_val);
195 ARM_COMPUTE_TEST_INFO("reference = " << ref_val[i]);
196 ARM_COMPUTE_TEST_INFO("tolerance = " << tolerance);
197
198 ARM_COMPUTE_TEST_INFO((std::abs(static_cast<float>(out_val) - ref_val[i]) <= tolerance));
199
200 ++num_mismatches;
201 }
202 }
203
204 ARM_COMPUTE_EXPECT(num_mismatches == expected_failures, framework::LogLevel::ERRORS);
205}
206
207TEST_SUITE_END()
208TEST_SUITE_END()
209} // namespace validation
210} // namespace test
211} // namespace arm_compute