blob: 53bef3af56701f8bee4809e8f54f51d6ca916bcd [file] [log] [blame]
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +01001// Copyright (c) 2023, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "generate_dot_product.h"
16#include "generate_utils.h"
17
Jeremy Johnson59b307d2023-10-04 14:17:26 +010018#include <cmath>
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010019#include <cstdint>
20
21namespace
22{
23
// Input position indices. Each generator stores the input position it was
// created for (cfg.inputPos) and compares it against these values.
constexpr uint32_t P0{ 0 };
constexpr uint32_t P1{ 1 };
constexpr uint32_t P2{ 2 };
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010028
29// Unused helper function
30template <typename... Args>
31inline void unused(Args&&...)
32{}
33
// Primitive generator class
//
// Deterministic pseudo-random sequence for one data set number S, driven by
// the 32-bit linear congruential recurrence r' = r * m + 1 (mod 2^32), with
// m = (8 * S + 1) * 0x705A5E75 and initial state r = m + 1.
// Each call to operator() returns a value in [-1.0, 1.0] for the current index
// and then advances the index by one. The sign comes from bit 31 of the state
// and the magnitude from the low 31 bits.
class PrimitiveGenerator
{
public:
    explicit PrimitiveGenerator(uint32_t S)
        : _m((8u * S + 1u) * 0x705A5E75u)
        , _r(_m + 1u)    // _m is declared (and so initialised) before _r
        , _index(0)
    {}

    // Returns the value for the current index, then moves to the next index.
    [[nodiscard]] float operator()()
    {
        const float sign   = (_r >> 31) == 0 ? +1.f : -1.f;
        const float pseudo = sign * static_cast<float>(_r & 0x7FFFFFFFu) / static_cast<float>(0x7FFFFFFFu);

        // Move index and calculate r value for the next index
        ++_index;
        _r = _r * _m + 1u;

        return pseudo;
    }

    // Number of values produced so far, i.e. the index of the next value.
    [[nodiscard]] uint32_t index() const
    {
        return _index;
    }

private:
    uint32_t _m;        // Multiplier derived from the set number S
    uint32_t _r;        // Recurrence state for the current index
    uint32_t _index;    // Index of the next value to be produced
};
74
//----------------------------------------------------------------------------//
// State generators - equivalent to tosa_mi_data() in the TOSA specification
//
// Each call to a generator returns the next generated value and automatically
// advances the generator's internal index.
//----------------------------------------------------------------------------//
81
Jeremy Johnson59b307d2023-10-04 14:17:26 +010082// Test set 0 generator
83// The aim of this generator is to check that sum of products with zero gives zero result.
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010084class GeneratorS0 : public TosaReference::IDotProductGenerator
85{
86public:
87 GeneratorS0(uint32_t p)
88 : _p(p)
Jeremy Johnson59b307d2023-10-04 14:17:26 +010089 , _set_data0(2 * 0)
90 , _set_data1(2 * 0 + 1)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010091 {}
92 float operator()(uint32_t k) override
93 {
94 unused(k);
Jeremy Johnson59b307d2023-10-04 14:17:26 +010095 const float s0 = _set_data0();
96 const float s1 = _set_data1();
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010097 if (_p == P0)
98 return s0 < 0.f ? 0.f : s1;
Jeremy Johnson59b307d2023-10-04 14:17:26 +010099 else if (_p == P1)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100100 return s0 < 0.f ? s1 : 0.f;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100101 else
102 return 0.f;
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100103 }
104
105private:
106 uint32_t _p;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100107 PrimitiveGenerator _set_data0;
108 PrimitiveGenerator _set_data1;
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100109};
110
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100111// Test set 1 generator
112// The aim of this test set is to check values with large exponents.
113class GeneratorS1 : public TosaReference::IDotProductGenerator
114{
115public:
116 GeneratorS1(uint32_t p, uint32_t KS, float B)
117 : _p(p)
118 , _KS(KS)
119 , _B(B)
120 , _set_data(3 * 1 + p)
121 {}
122 float operator()(uint32_t k) override
123 {
124 unused(k);
125 const float s = _set_data();
126 float v = 0.75f + 0.25f * s;
127 if (_p != P2)
128 return (_B / std::sqrt(_KS + 1)) * v;
129 else
130 return (_B * _B / (_KS + 1)) * v;
131 }
132
133private:
134 uint32_t _p;
135 uint32_t _KS;
136 float _B;
137 PrimitiveGenerator _set_data;
138};
139
140// Test set 2 generator
141// The aim of this test set is to check rounding error when accumulating small values
142// onto a large value. In this case the small values are of similar magnitude. If the
143// implementation changes the order of the sum, then the test data must also be reordered
144// so that the largest values occur first in the sum.
145class GeneratorS2 : public TosaReference::IDotProductGenerator
146{
147public:
148 GeneratorS2(uint32_t p, uint32_t KS)
149 : _p(p)
150 , _KS(KS)
151 , _set_data(2 * 2 + p)
152 {}
153 float operator()(uint32_t k) override
154 {
155 const float s = _set_data();
156 if (_p != P2)
157 return k == 0 ? 1.f : s / std::sqrt(_KS);
158 else
159 return 0.f;
160 }
161
162private:
163 uint32_t _p;
164 uint32_t _KS;
165 PrimitiveGenerator _set_data;
166};
167
168// Test set 3 generator
169// The aim of this test set is to check rounding error when accumulating small values
170// onto a large value. In this case the small values are of varying magnitude. If the
171// implementation changes the order of the sum, then the test data must also be reordered
172// so that the largest values occur first in the sum.
173class GeneratorS3 : public TosaReference::IDotProductGenerator
174{
175public:
176 GeneratorS3(uint32_t p)
177 : _p(p)
178 , _set_data(2 * 3 + p)
179 {}
180 float operator()(uint32_t k) override
181 {
182 const float s0 = _set_data();
183 const float s1 = _set_data();
184 if (_p != P2)
185 return k == 0 ? 16.f : std::exp(2 * s0) * s1;
186 else
187 return 0.f;
188 }
189
190private:
191 uint32_t _p;
192 PrimitiveGenerator _set_data;
193};
194
195// Test set 4 generator
196// The aim of this test set is to check a mixture of zero and non-zero products.
197class GeneratorS4 : public TosaReference::IDotProductGenerator
198{
199public:
200 GeneratorS4(uint32_t p, uint32_t KS, float B)
201 : _p(p)
202 , _KS(KS)
203 , _B(B)
204 , _set_data0(2 * 4 + 0)
205 , _set_data1(2 * 4 + 1)
206 {}
207 float operator()(uint32_t k) override
208 {
209 const float s0 = _set_data0();
210 const float s1 = _set_data1();
211 if (_p == P0)
212 return (k == _KS / 2) ? +0.5f : s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1;
213 else if (_p == P1)
214 return (k == _KS / 2) ? -0.5f : s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f;
215 else
216 return 0.f;
217 }
218
219private:
220 uint32_t _p;
221 uint32_t _KS;
222 float _B;
223 PrimitiveGenerator _set_data0;
224 PrimitiveGenerator _set_data1;
225};
226
227// Test set 5 generator
228// The aim of this test set is to check signed inputs of large range.
229class GeneratorS5 : public TosaReference::IDotProductGenerator
230{
231public:
232 GeneratorS5(uint32_t p, uint32_t KS, float B)
233 : _p(p)
234 , _KS(KS)
235 , _B(B)
236 , _set_data(3 * 5 + p)
237 {}
238 float operator()(uint32_t k) override
239 {
240 unused(k);
241 const float s = _set_data();
242 if (_p != P2)
243 return (_B / std::sqrt(_KS + 1)) * s;
244 else
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100245 return 0.f;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100246 }
247
248private:
249 uint32_t _p;
250 uint32_t _KS;
251 float _B;
252 PrimitiveGenerator _set_data;
253};
254
255float getBoundParameter(const DType& dataType, const DType& accType)
256{
257 // Work out the bounds parameter value B for the given data and accumulator types
258 // Returns value > 0.f on success
259 float B = 0.f;
260 if (dataType == DType::DType_FP16)
261 {
262 if (accType == DType::DType_FP16)
263 B = 255.875f; // (1<<8) - (1/8);
264 else if (accType == DType::DType_FP32)
265 B = 65504.f; // (1<<16) - (1<<5);
266 }
267 else if (dataType == DType::DType_BF16)
268 {
269 if (accType == DType::DType_FP32)
270 B = 18374686479671623680.f; // (1<<64) - (1<<56)
271 }
272 else if (dataType == DType::DType_FP32)
273 {
274 if (accType == DType::DType_FP32)
275 B = 18446742974197923840.f; // (1<<64) - (1<<40)
276 }
277 return B;
278}
279
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100280} // namespace
281
282namespace TosaReference
283{
284
285std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
286{
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100287 // Generators can only support 3 inputs
288 if (cfg.inputPos > 2)
289 return nullptr;
290
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100291 const DotProductInfo& dpinfo = cfg.dotProductInfo;
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100292
293 float B = getBoundParameter(cfg.dataType, dpinfo.accType);
294 if (B > 0.f)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100295 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100296 // Create the generator
297 switch (dpinfo.s)
298 {
299 case 0:
300 return std::make_unique<GeneratorS0>(cfg.inputPos);
301 case 1:
302 return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
303 case 2:
304 return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
305 case 3:
306 return std::make_unique<GeneratorS3>(cfg.inputPos);
307 case 4:
308 return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
309 case 5:
310 return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
311 default:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100312 WARNING("[Generator][DP] Unsupported dot product test series for generator.");
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100313 return nullptr;
314 }
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100315 }
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100316 WARNING("[Generator][DP] Unsupported data types for generator.");
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100317 return nullptr;
318}
319
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100320} // namespace TosaReference