blob: 9ce32ffec287681d0e8903288a2593dc07524846 [file] [log] [blame]
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +01001// Copyright (c) 2023, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "generate_dot_product.h"
16#include "generate_utils.h"
17
Jeremy Johnson59b307d2023-10-04 14:17:26 +010018#include <cmath>
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010019#include <cstdint>
20
21namespace
22{
23
// Input position identifiers.
// A generator is constructed with one of these values to say which input
// of the dot product operation it produces data for.
inline constexpr uint32_t P0{ 0 };
inline constexpr uint32_t P1{ 1 };
inline constexpr uint32_t P2{ 2 };
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010028
// Helper for marking values as intentionally unused.
// Accepts any number of arguments of any type and does nothing with them.
template <typename... Ts>
inline void unused(Ts&&...)
{
}
33
// Primitive generator class
//
// Deterministic pseudo-random source selected by a set number S.
// Each call to the function operator yields the value for the current
// index, in the range [-1.0, 1.0], and then advances the index by one.
class PrimitiveGenerator
{
public:
    // S selects the sequence: the multiplier is derived from S and the
    // initial state is the multiplier plus one (unsigned 32-bit wrap-around).
    explicit PrimitiveGenerator(uint32_t S)
        : _S(S)
        , _m((8 * _S + 1) * 0x705A5E75u)
        , _r(_m + 1)
        , _index(0)
    {}

    // Returns the pseudo-random value for the current index and moves on
    // to the next one.
    [[nodiscard]] float operator()()
    {
        // The top bit of the state gives the sign, the remaining 31 bits
        // give the magnitude which is scaled down by the 31-bit maximum.
        const float sign      = (_r >> 31) == 0 ? +1.f : -1.f;
        const float magnitude = static_cast<float>(_r & 0x7FFFFFFFu);
        const float pseudo    = sign * magnitude / static_cast<float>(0x7FFFFFFF);

        // Move index and calculate r value for the next index
        ++_index;
        _r = _r * _m + 1;

        return pseudo;
    }

    // Number of values produced so far.
    uint32_t index() const noexcept
    {
        return _index;
    }

private:
    // Declaration order matters: _m is initialised from _S and _r from _m.
    uint32_t _S;
    uint32_t _m;
    uint32_t _r;
    uint32_t _index;
};
74
75//----------------------------------------------------------------------------//
Jeremy Johnson59b307d2023-10-04 14:17:26 +010076// State generators - equivalent to tosa_mi_data() in the TOSA specification
77//
78// Each call to the generator returns the next generated value with an
79// auto incrementing index
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010080//----------------------------------------------------------------------------//
81
Jeremy Johnson59b307d2023-10-04 14:17:26 +010082// Test set 0 generator
83// The aim of this generator is to check that sum of products with zero gives zero result.
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010084class GeneratorS0 : public TosaReference::IDotProductGenerator
85{
86public:
87 GeneratorS0(uint32_t p)
88 : _p(p)
Jeremy Johnson59b307d2023-10-04 14:17:26 +010089 , _set_data0(2 * 0)
90 , _set_data1(2 * 0 + 1)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010091 {}
92 float operator()(uint32_t k) override
93 {
94 unused(k);
Jeremy Johnson59b307d2023-10-04 14:17:26 +010095 const float s0 = _set_data0();
96 const float s1 = _set_data1();
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010097 if (_p == P0)
98 return s0 < 0.f ? 0.f : s1;
Jeremy Johnson59b307d2023-10-04 14:17:26 +010099 else if (_p == P1)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100100 return s0 < 0.f ? s1 : 0.f;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100101 else
102 return 0.f;
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100103 }
104
105private:
106 uint32_t _p;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100107 PrimitiveGenerator _set_data0;
108 PrimitiveGenerator _set_data1;
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100109};
110
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100111// Test set 1 generator
112// The aim of this test set is to check values with large exponents.
113class GeneratorS1 : public TosaReference::IDotProductGenerator
114{
115public:
116 GeneratorS1(uint32_t p, uint32_t KS, float B)
117 : _p(p)
118 , _KS(KS)
119 , _B(B)
120 , _set_data(3 * 1 + p)
121 {}
122 float operator()(uint32_t k) override
123 {
124 unused(k);
125 const float s = _set_data();
126 float v = 0.75f + 0.25f * s;
127 if (_p != P2)
128 return (_B / std::sqrt(_KS + 1)) * v;
129 else
130 return (_B * _B / (_KS + 1)) * v;
131 }
132
133private:
134 uint32_t _p;
135 uint32_t _KS;
136 float _B;
137 PrimitiveGenerator _set_data;
138};
139
140// Test set 2 generator
141// The aim of this test set is to check rounding error when accumulating small values
142// onto a large value. In this case the small values are of similar magnitude. If the
143// implementation changes the order of the sum, then the test data must also be reordered
144// so that the largest values occur first in the sum.
145class GeneratorS2 : public TosaReference::IDotProductGenerator
146{
147public:
148 GeneratorS2(uint32_t p, uint32_t KS)
149 : _p(p)
150 , _KS(KS)
151 , _set_data(2 * 2 + p)
152 {}
153 float operator()(uint32_t k) override
154 {
155 const float s = _set_data();
156 if (_p != P2)
157 return k == 0 ? 1.f : s / std::sqrt(_KS);
158 else
159 return 0.f;
160 }
161
162private:
163 uint32_t _p;
164 uint32_t _KS;
165 PrimitiveGenerator _set_data;
166};
167
168// Test set 3 generator
169// The aim of this test set is to check rounding error when accumulating small values
170// onto a large value. In this case the small values are of varying magnitude. If the
171// implementation changes the order of the sum, then the test data must also be reordered
172// so that the largest values occur first in the sum.
173class GeneratorS3 : public TosaReference::IDotProductGenerator
174{
175public:
176 GeneratorS3(uint32_t p)
177 : _p(p)
178 , _set_data(2 * 3 + p)
179 {}
180 float operator()(uint32_t k) override
181 {
182 const float s0 = _set_data();
183 const float s1 = _set_data();
184 if (_p != P2)
185 return k == 0 ? 16.f : std::exp(2 * s0) * s1;
186 else
187 return 0.f;
188 }
189
190private:
191 uint32_t _p;
192 PrimitiveGenerator _set_data;
193};
194
195// Test set 4 generator
196// The aim of this test set is to check a mixture of zero and non-zero products.
197class GeneratorS4 : public TosaReference::IDotProductGenerator
198{
199public:
200 GeneratorS4(uint32_t p, uint32_t KS, float B)
201 : _p(p)
202 , _KS(KS)
203 , _B(B)
204 , _set_data0(2 * 4 + 0)
205 , _set_data1(2 * 4 + 1)
206 {}
207 float operator()(uint32_t k) override
208 {
209 const float s0 = _set_data0();
210 const float s1 = _set_data1();
211 if (_p == P0)
Jeremy Johnson0601f802023-11-08 16:28:09 +0000212 if (k == _KS / 2)
213 {
214 return s0 < 0 ? -0.5f : +0.5f;
215 }
216 else
217 {
218 return s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1;
219 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100220 else if (_p == P1)
Jeremy Johnson0601f802023-11-08 16:28:09 +0000221 if (k == _KS / 2)
222 {
223 return s0 < 0 ? +0.5f : -0.5f;
224 }
225 else
226 {
227 return s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f;
228 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100229 else
230 return 0.f;
231 }
232
233private:
234 uint32_t _p;
235 uint32_t _KS;
236 float _B;
237 PrimitiveGenerator _set_data0;
238 PrimitiveGenerator _set_data1;
239};
240
241// Test set 5 generator
242// The aim of this test set is to check signed inputs of large range.
243class GeneratorS5 : public TosaReference::IDotProductGenerator
244{
245public:
246 GeneratorS5(uint32_t p, uint32_t KS, float B)
247 : _p(p)
248 , _KS(KS)
249 , _B(B)
250 , _set_data(3 * 5 + p)
251 {}
252 float operator()(uint32_t k) override
253 {
254 unused(k);
255 const float s = _set_data();
256 if (_p != P2)
257 return (_B / std::sqrt(_KS + 1)) * s;
258 else
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100259 return 0.f;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100260 }
261
262private:
263 uint32_t _p;
264 uint32_t _KS;
265 float _B;
266 PrimitiveGenerator _set_data;
267};
268
269float getBoundParameter(const DType& dataType, const DType& accType)
270{
271 // Work out the bounds parameter value B for the given data and accumulator types
272 // Returns value > 0.f on success
273 float B = 0.f;
274 if (dataType == DType::DType_FP16)
275 {
276 if (accType == DType::DType_FP16)
277 B = 255.875f; // (1<<8) - (1/8);
278 else if (accType == DType::DType_FP32)
279 B = 65504.f; // (1<<16) - (1<<5);
280 }
281 else if (dataType == DType::DType_BF16)
282 {
283 if (accType == DType::DType_FP32)
284 B = 18374686479671623680.f; // (1<<64) - (1<<56)
285 }
286 else if (dataType == DType::DType_FP32)
287 {
288 if (accType == DType::DType_FP32)
289 B = 18446742974197923840.f; // (1<<64) - (1<<40)
290 }
291 return B;
292}
293
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100294} // namespace
295
296namespace TosaReference
297{
298
299std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
300{
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100301 // Generators can only support 3 inputs
302 if (cfg.inputPos > 2)
303 return nullptr;
304
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100305 const DotProductInfo& dpinfo = cfg.dotProductInfo;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100306
307 float B = getBoundParameter(cfg.dataType, dpinfo.accType);
308 if (B > 0.f)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100309 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100310 // Create the generator
311 switch (dpinfo.s)
312 {
313 case 0:
314 return std::make_unique<GeneratorS0>(cfg.inputPos);
315 case 1:
316 return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
317 case 2:
318 return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
319 case 3:
320 return std::make_unique<GeneratorS3>(cfg.inputPos);
321 case 4:
322 return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
323 case 5:
324 return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
325 default:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100326 WARNING("[Generator][DP] Unsupported dot product test series for generator.");
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100327 return nullptr;
328 }
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100329 }
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100330 WARNING("[Generator][DP] Unsupported data types for generator.");
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100331 return nullptr;
332}
333
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100334} // namespace TosaReference