blob: 4b435cab8c62a5db70189bf15798f08a1adc65ba [file] [log] [blame]
Jeremy Johnsonc8330812024-01-18 16:57:28 +00001// Copyright (c) 2023-2024, ARM Limited.
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +01002//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "generate_dot_product.h"
16#include "generate_utils.h"
17
Jeremy Johnson59b307d2023-10-04 14:17:26 +010018#include <cmath>
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010019#include <cstdint>
20
21namespace
22{
23
// Operand positions (the value of cfg.inputPos passed to each generator as `p`).
inline constexpr uint32_t P0{ 0 };
inline constexpr uint32_t P1{ 1 };
inline constexpr uint32_t P2{ 2 };

// Accepts and discards any arguments; used to silence unused-parameter warnings.
template <typename... Ts>
inline void unused(Ts&&... /* ignored */)
{}
33
// Primitive generator class
//
// Deterministic pseudo-random source for one data set S. Each call of the
// function operator yields the value for the current index, in the range
// [-1.0, 1.0], and then increases the index by one.
class PrimitiveGenerator
{
public:
    // S selects the data set; it only seeds the multiplier and is not kept.
    explicit PrimitiveGenerator(uint32_t S)
        : _m((8 * S + 1) * 0x705A5E75u)    // unsigned arithmetic: wraps modulo 2^32
        , _r(_m + 1)
    {}

    [[nodiscard]] float operator()()
    {
        // Top bit of r selects the sign, the low 31 bits give the magnitude.
        const float sign   = (_r >> 31) == 0 ? +1.f : -1.f;
        const float pseudo = sign * static_cast<float>(_r & 0x7FFFFFFFu) / static_cast<float>(0x7FFFFFFFu);

        // Move index and step the linear congruential state (r = r * m + 1, mod 2^32)
        ++_index;
        _r = _r * _m + 1;

        return pseudo;
    }

    // Index of the value the next call to operator() will return
    // (equivalently, the number of values generated so far).
    [[nodiscard]] uint32_t nextIndex() const
    {
        return _index;
    }

private:
    uint32_t _m;             // multiplier derived from the data set number
    uint32_t _r;             // current generator state
    uint32_t _index = 0;    // number of values generated so far
};
74
75//----------------------------------------------------------------------------//
Jeremy Johnson59b307d2023-10-04 14:17:26 +010076// State generators - equivalent to tosa_mi_data() in the TOSA specification
77//
78// Each call to the generator returns the next generated value with an
79// auto incrementing index
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010080//----------------------------------------------------------------------------//
81
Jeremy Johnson59b307d2023-10-04 14:17:26 +010082// Test set 0 generator
83// The aim of this generator is to check that sum of products with zero gives zero result.
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010084class GeneratorS0 : public TosaReference::IDotProductGenerator
85{
86public:
87 GeneratorS0(uint32_t p)
88 : _p(p)
Jeremy Johnson59b307d2023-10-04 14:17:26 +010089 , _set_data0(2 * 0)
90 , _set_data1(2 * 0 + 1)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010091 {}
92 float operator()(uint32_t k) override
93 {
94 unused(k);
Jeremy Johnson59b307d2023-10-04 14:17:26 +010095 const float s0 = _set_data0();
96 const float s1 = _set_data1();
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010097 if (_p == P0)
98 return s0 < 0.f ? 0.f : s1;
Jeremy Johnson59b307d2023-10-04 14:17:26 +010099 else if (_p == P1)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100100 return s0 < 0.f ? s1 : 0.f;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100101 else
102 return 0.f;
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100103 }
Tai Ly8ead6c42024-02-14 22:35:44 +0000104 uint32_t nextIndex() override
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000105 {
106 ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0")
107 return _set_data0.nextIndex();
108 }
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100109
110private:
111 uint32_t _p;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100112 PrimitiveGenerator _set_data0;
113 PrimitiveGenerator _set_data1;
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100114};
115
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100116// Test set 1 generator
117// The aim of this test set is to check values with large exponents.
118class GeneratorS1 : public TosaReference::IDotProductGenerator
119{
120public:
121 GeneratorS1(uint32_t p, uint32_t KS, float B)
122 : _p(p)
123 , _KS(KS)
124 , _B(B)
125 , _set_data(3 * 1 + p)
126 {}
127 float operator()(uint32_t k) override
128 {
129 unused(k);
130 const float s = _set_data();
131 float v = 0.75f + 0.25f * s;
132 if (_p != P2)
133 return (_B / std::sqrt(_KS + 1)) * v;
134 else
135 return (_B * _B / (_KS + 1)) * v;
136 }
Tai Ly8ead6c42024-02-14 22:35:44 +0000137 uint32_t nextIndex() override
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000138 {
139 return _set_data.nextIndex();
140 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100141
142private:
143 uint32_t _p;
144 uint32_t _KS;
145 float _B;
146 PrimitiveGenerator _set_data;
147};
148
149// Test set 2 generator
150// The aim of this test set is to check rounding error when accumulating small values
151// onto a large value. In this case the small values are of similar magnitude. If the
152// implementation changes the order of the sum, then the test data must also be reordered
153// so that the largest values occur first in the sum.
154class GeneratorS2 : public TosaReference::IDotProductGenerator
155{
156public:
157 GeneratorS2(uint32_t p, uint32_t KS)
158 : _p(p)
159 , _KS(KS)
160 , _set_data(2 * 2 + p)
161 {}
162 float operator()(uint32_t k) override
163 {
164 const float s = _set_data();
165 if (_p != P2)
166 return k == 0 ? 1.f : s / std::sqrt(_KS);
167 else
168 return 0.f;
169 }
Tai Ly8ead6c42024-02-14 22:35:44 +0000170 uint32_t nextIndex() override
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000171 {
172 return _set_data.nextIndex();
173 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100174
175private:
176 uint32_t _p;
177 uint32_t _KS;
178 PrimitiveGenerator _set_data;
179};
180
181// Test set 3 generator
182// The aim of this test set is to check rounding error when accumulating small values
183// onto a large value. In this case the small values are of varying magnitude. If the
184// implementation changes the order of the sum, then the test data must also be reordered
185// so that the largest values occur first in the sum.
186class GeneratorS3 : public TosaReference::IDotProductGenerator
187{
188public:
189 GeneratorS3(uint32_t p)
190 : _p(p)
191 , _set_data(2 * 3 + p)
192 {}
193 float operator()(uint32_t k) override
194 {
195 const float s0 = _set_data();
196 const float s1 = _set_data();
197 if (_p != P2)
198 return k == 0 ? 16.f : std::exp(2 * s0) * s1;
199 else
200 return 0.f;
201 }
Tai Ly8ead6c42024-02-14 22:35:44 +0000202 uint32_t nextIndex() override
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000203 {
204 return _set_data.nextIndex();
205 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100206
207private:
208 uint32_t _p;
209 PrimitiveGenerator _set_data;
210};
211
212// Test set 4 generator
213// The aim of this test set is to check a mixture of zero and non-zero products.
214class GeneratorS4 : public TosaReference::IDotProductGenerator
215{
216public:
217 GeneratorS4(uint32_t p, uint32_t KS, float B)
218 : _p(p)
219 , _KS(KS)
220 , _B(B)
221 , _set_data0(2 * 4 + 0)
222 , _set_data1(2 * 4 + 1)
223 {}
224 float operator()(uint32_t k) override
225 {
226 const float s0 = _set_data0();
227 const float s1 = _set_data1();
228 if (_p == P0)
Jeremy Johnson0601f802023-11-08 16:28:09 +0000229 if (k == _KS / 2)
230 {
231 return s0 < 0 ? -0.5f : +0.5f;
232 }
233 else
234 {
235 return s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1;
236 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100237 else if (_p == P1)
Jeremy Johnson0601f802023-11-08 16:28:09 +0000238 if (k == _KS / 2)
239 {
240 return s0 < 0 ? +0.5f : -0.5f;
241 }
242 else
243 {
244 return s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f;
245 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100246 else
247 return 0.f;
248 }
Tai Ly8ead6c42024-02-14 22:35:44 +0000249 uint32_t nextIndex() override
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000250 {
251 ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4")
252 return _set_data0.nextIndex();
253 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100254
255private:
256 uint32_t _p;
257 uint32_t _KS;
258 float _B;
259 PrimitiveGenerator _set_data0;
260 PrimitiveGenerator _set_data1;
261};
262
263// Test set 5 generator
264// The aim of this test set is to check signed inputs of large range.
265class GeneratorS5 : public TosaReference::IDotProductGenerator
266{
267public:
268 GeneratorS5(uint32_t p, uint32_t KS, float B)
269 : _p(p)
270 , _KS(KS)
271 , _B(B)
272 , _set_data(3 * 5 + p)
273 {}
274 float operator()(uint32_t k) override
275 {
276 unused(k);
277 const float s = _set_data();
278 if (_p != P2)
279 return (_B / std::sqrt(_KS + 1)) * s;
280 else
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100281 return 0.f;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100282 }
Tai Ly8ead6c42024-02-14 22:35:44 +0000283 uint32_t nextIndex() override
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000284 {
285 return _set_data.nextIndex();
286 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100287
288private:
289 uint32_t _p;
290 uint32_t _KS;
291 float _B;
292 PrimitiveGenerator _set_data;
293};
294
295float getBoundParameter(const DType& dataType, const DType& accType)
296{
297 // Work out the bounds parameter value B for the given data and accumulator types
298 // Returns value > 0.f on success
299 float B = 0.f;
300 if (dataType == DType::DType_FP16)
301 {
302 if (accType == DType::DType_FP16)
303 B = 255.875f; // (1<<8) - (1/8);
304 else if (accType == DType::DType_FP32)
305 B = 65504.f; // (1<<16) - (1<<5);
306 }
307 else if (dataType == DType::DType_BF16)
308 {
309 if (accType == DType::DType_FP32)
310 B = 18374686479671623680.f; // (1<<64) - (1<<56)
311 }
312 else if (dataType == DType::DType_FP32)
313 {
314 if (accType == DType::DType_FP32)
315 B = 18446742974197923840.f; // (1<<64) - (1<<40)
316 }
317 return B;
318}
319
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100320} // namespace
321
322namespace TosaReference
323{
324
325std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
326{
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100327 // Generators can only support 3 inputs
328 if (cfg.inputPos > 2)
329 return nullptr;
330
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100331 const DotProductInfo& dpinfo = cfg.dotProductInfo;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100332
333 float B = getBoundParameter(cfg.dataType, dpinfo.accType);
334 if (B > 0.f)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100335 {
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000336 auto param = cfg.inputPos;
337 if (cfg.opType == Op_FFT2D)
338 {
339 // We only use param of zero for FFT2D tensors
340 param = 0;
341 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100342 // Create the generator
343 switch (dpinfo.s)
344 {
345 case 0:
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000346 return std::make_unique<GeneratorS0>(param);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100347 case 1:
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000348 return std::make_unique<GeneratorS1>(param, dpinfo.ks, B);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100349 case 2:
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000350 return std::make_unique<GeneratorS2>(param, dpinfo.ks);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100351 case 3:
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000352 return std::make_unique<GeneratorS3>(param);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100353 case 4:
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000354 return std::make_unique<GeneratorS4>(param, dpinfo.ks, B);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100355 case 5:
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000356 return std::make_unique<GeneratorS5>(param, dpinfo.ks, B);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100357 default:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100358 WARNING("[Generator][DP] Unsupported dot product test series for generator.");
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100359 return nullptr;
360 }
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100361 }
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100362 WARNING("[Generator][DP] Unsupported data types for generator.");
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100363 return nullptr;
364}
365
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100366} // namespace TosaReference