Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 1 | // Copyright (c) 2023-2024, ARM Limited. |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "generate_dot_product.h" |
| 16 | #include "generate_utils.h" |
| 17 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 18 | #include <cmath> |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 19 | #include <cstdint> |
| 20 | |
| 21 | namespace |
| 22 | { |
| 23 | |
// Input index global variables
//
// Symbolic names for the input position `p` handed to each generator's
// constructor; the generators compare their stored position against these
// to choose which formula to apply.
constexpr uint32_t P0 = 0;
constexpr uint32_t P1 = 1;
constexpr uint32_t P2 = 2;
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 28 | |
// Unused helper function
//
// Accepts any number of arguments of any type and does nothing with them.
// Call it with a parameter that is deliberately not used to make that
// intent explicit at the call site.
template <typename... Args>
constexpr void unused(Args&&...) noexcept
{}
| 33 | |
// Primitive generator class
//
// Yields a new value on function operator access and increases the
// index by one.
//
// The sequence is fully determined by the seed S: the 32-bit state r is
// advanced with r = r * m + 1 (wrapping unsigned arithmetic), and each
// state is mapped to a float in the range [-1, 1].
class PrimitiveGenerator
{
public:
    // S - seed; the multiplier is m = (8 * S + 1) * 0x705A5E75 and the
    //     initial state is r = m + 1, both modulo 2^32.
    explicit PrimitiveGenerator(uint32_t S)
        : _m((8u * S + 1u) * 0x705A5E75u)
        , _r(_m + 1u)    // _m is declared before _r, so it is already initialised here
        , _index(0)
    {}

    // Returns the value for the current index, then moves on to the next one.
    [[nodiscard]] float operator()()
    {
        // Bit 31 of the state selects the sign, the low 31 bits the magnitude
        const float sign      = (_r >> 31) == 0 ? +1.f : -1.f;
        const float magnitude = static_cast<float>(_r & 0x7FFFFFFFu);
        const float pseudo    = sign * magnitude / static_cast<float>(0x7FFFFFFFu);

        // Move index and calculate r value for the next index
        ++_index;
        _r = _r * _m + 1u;

        return pseudo;
    }

    // Number of values produced so far, i.e. the index of the next value.
    [[nodiscard]] uint32_t nextIndex() const noexcept
    {
        return _index;
    }

private:
    uint32_t _m;        // multiplier derived from the seed
    uint32_t _r;        // current state
    uint32_t _index;    // count of values already returned
};
| 74 | |
| 75 | //----------------------------------------------------------------------------// |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 76 | // State generators - equivalent to tosa_mi_data() in the TOSA specification |
| 77 | // |
| 78 | // Each call to the generator returns the next generated value with an |
| 79 | // auto incrementing index |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 80 | //----------------------------------------------------------------------------// |
| 81 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 82 | // Test set 0 generator |
| 83 | // The aim of this generator is to check that sum of products with zero gives zero result. |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 84 | class GeneratorS0 : public TosaReference::IDotProductGenerator |
| 85 | { |
| 86 | public: |
| 87 | GeneratorS0(uint32_t p) |
| 88 | : _p(p) |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 89 | , _set_data0(2 * 0) |
| 90 | , _set_data1(2 * 0 + 1) |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 91 | {} |
| 92 | float operator()(uint32_t k) override |
| 93 | { |
| 94 | unused(k); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 95 | const float s0 = _set_data0(); |
| 96 | const float s1 = _set_data1(); |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 97 | if (_p == P0) |
| 98 | return s0 < 0.f ? 0.f : s1; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 99 | else if (_p == P1) |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 100 | return s0 < 0.f ? s1 : 0.f; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 101 | else |
| 102 | return 0.f; |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 103 | } |
Tai Ly | 8ead6c4 | 2024-02-14 22:35:44 +0000 | [diff] [blame^] | 104 | uint32_t nextIndex() override |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 105 | { |
| 106 | ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0") |
| 107 | return _set_data0.nextIndex(); |
| 108 | } |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 109 | |
| 110 | private: |
| 111 | uint32_t _p; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 112 | PrimitiveGenerator _set_data0; |
| 113 | PrimitiveGenerator _set_data1; |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 114 | }; |
| 115 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 116 | // Test set 1 generator |
| 117 | // The aim of this test set is to check values with large exponents. |
| 118 | class GeneratorS1 : public TosaReference::IDotProductGenerator |
| 119 | { |
| 120 | public: |
| 121 | GeneratorS1(uint32_t p, uint32_t KS, float B) |
| 122 | : _p(p) |
| 123 | , _KS(KS) |
| 124 | , _B(B) |
| 125 | , _set_data(3 * 1 + p) |
| 126 | {} |
| 127 | float operator()(uint32_t k) override |
| 128 | { |
| 129 | unused(k); |
| 130 | const float s = _set_data(); |
| 131 | float v = 0.75f + 0.25f * s; |
| 132 | if (_p != P2) |
| 133 | return (_B / std::sqrt(_KS + 1)) * v; |
| 134 | else |
| 135 | return (_B * _B / (_KS + 1)) * v; |
| 136 | } |
Tai Ly | 8ead6c4 | 2024-02-14 22:35:44 +0000 | [diff] [blame^] | 137 | uint32_t nextIndex() override |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 138 | { |
| 139 | return _set_data.nextIndex(); |
| 140 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 141 | |
| 142 | private: |
| 143 | uint32_t _p; |
| 144 | uint32_t _KS; |
| 145 | float _B; |
| 146 | PrimitiveGenerator _set_data; |
| 147 | }; |
| 148 | |
| 149 | // Test set 2 generator |
| 150 | // The aim of this test set is to check rounding error when accumulating small values |
| 151 | // onto a large value. In this case the small values are of similar magnitude. If the |
| 152 | // implementation changes the order of the sum, then the test data must also be reordered |
| 153 | // so that the largest values occur first in the sum. |
| 154 | class GeneratorS2 : public TosaReference::IDotProductGenerator |
| 155 | { |
| 156 | public: |
| 157 | GeneratorS2(uint32_t p, uint32_t KS) |
| 158 | : _p(p) |
| 159 | , _KS(KS) |
| 160 | , _set_data(2 * 2 + p) |
| 161 | {} |
| 162 | float operator()(uint32_t k) override |
| 163 | { |
| 164 | const float s = _set_data(); |
| 165 | if (_p != P2) |
| 166 | return k == 0 ? 1.f : s / std::sqrt(_KS); |
| 167 | else |
| 168 | return 0.f; |
| 169 | } |
Tai Ly | 8ead6c4 | 2024-02-14 22:35:44 +0000 | [diff] [blame^] | 170 | uint32_t nextIndex() override |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 171 | { |
| 172 | return _set_data.nextIndex(); |
| 173 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 174 | |
| 175 | private: |
| 176 | uint32_t _p; |
| 177 | uint32_t _KS; |
| 178 | PrimitiveGenerator _set_data; |
| 179 | }; |
| 180 | |
| 181 | // Test set 3 generator |
| 182 | // The aim of this test set is to check rounding error when accumulating small values |
| 183 | // onto a large value. In this case the small values are of varying magnitude. If the |
| 184 | // implementation changes the order of the sum, then the test data must also be reordered |
| 185 | // so that the largest values occur first in the sum. |
| 186 | class GeneratorS3 : public TosaReference::IDotProductGenerator |
| 187 | { |
| 188 | public: |
| 189 | GeneratorS3(uint32_t p) |
| 190 | : _p(p) |
| 191 | , _set_data(2 * 3 + p) |
| 192 | {} |
| 193 | float operator()(uint32_t k) override |
| 194 | { |
| 195 | const float s0 = _set_data(); |
| 196 | const float s1 = _set_data(); |
| 197 | if (_p != P2) |
| 198 | return k == 0 ? 16.f : std::exp(2 * s0) * s1; |
| 199 | else |
| 200 | return 0.f; |
| 201 | } |
Tai Ly | 8ead6c4 | 2024-02-14 22:35:44 +0000 | [diff] [blame^] | 202 | uint32_t nextIndex() override |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 203 | { |
| 204 | return _set_data.nextIndex(); |
| 205 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 206 | |
| 207 | private: |
| 208 | uint32_t _p; |
| 209 | PrimitiveGenerator _set_data; |
| 210 | }; |
| 211 | |
| 212 | // Test set 4 generator |
| 213 | // The aim of this test set is to check a mixture of zero and non-zero products. |
| 214 | class GeneratorS4 : public TosaReference::IDotProductGenerator |
| 215 | { |
| 216 | public: |
| 217 | GeneratorS4(uint32_t p, uint32_t KS, float B) |
| 218 | : _p(p) |
| 219 | , _KS(KS) |
| 220 | , _B(B) |
| 221 | , _set_data0(2 * 4 + 0) |
| 222 | , _set_data1(2 * 4 + 1) |
| 223 | {} |
| 224 | float operator()(uint32_t k) override |
| 225 | { |
| 226 | const float s0 = _set_data0(); |
| 227 | const float s1 = _set_data1(); |
| 228 | if (_p == P0) |
Jeremy Johnson | 0601f80 | 2023-11-08 16:28:09 +0000 | [diff] [blame] | 229 | if (k == _KS / 2) |
| 230 | { |
| 231 | return s0 < 0 ? -0.5f : +0.5f; |
| 232 | } |
| 233 | else |
| 234 | { |
| 235 | return s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1; |
| 236 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 237 | else if (_p == P1) |
Jeremy Johnson | 0601f80 | 2023-11-08 16:28:09 +0000 | [diff] [blame] | 238 | if (k == _KS / 2) |
| 239 | { |
| 240 | return s0 < 0 ? +0.5f : -0.5f; |
| 241 | } |
| 242 | else |
| 243 | { |
| 244 | return s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f; |
| 245 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 246 | else |
| 247 | return 0.f; |
| 248 | } |
Tai Ly | 8ead6c4 | 2024-02-14 22:35:44 +0000 | [diff] [blame^] | 249 | uint32_t nextIndex() override |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 250 | { |
| 251 | ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4") |
| 252 | return _set_data0.nextIndex(); |
| 253 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 254 | |
| 255 | private: |
| 256 | uint32_t _p; |
| 257 | uint32_t _KS; |
| 258 | float _B; |
| 259 | PrimitiveGenerator _set_data0; |
| 260 | PrimitiveGenerator _set_data1; |
| 261 | }; |
| 262 | |
| 263 | // Test set 5 generator |
| 264 | // The aim of this test set is to check signed inputs of large range. |
| 265 | class GeneratorS5 : public TosaReference::IDotProductGenerator |
| 266 | { |
| 267 | public: |
| 268 | GeneratorS5(uint32_t p, uint32_t KS, float B) |
| 269 | : _p(p) |
| 270 | , _KS(KS) |
| 271 | , _B(B) |
| 272 | , _set_data(3 * 5 + p) |
| 273 | {} |
| 274 | float operator()(uint32_t k) override |
| 275 | { |
| 276 | unused(k); |
| 277 | const float s = _set_data(); |
| 278 | if (_p != P2) |
| 279 | return (_B / std::sqrt(_KS + 1)) * s; |
| 280 | else |
Jeremy Johnson | d1a08ce | 2023-10-18 17:22:21 +0100 | [diff] [blame] | 281 | return 0.f; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 282 | } |
Tai Ly | 8ead6c4 | 2024-02-14 22:35:44 +0000 | [diff] [blame^] | 283 | uint32_t nextIndex() override |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 284 | { |
| 285 | return _set_data.nextIndex(); |
| 286 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 287 | |
| 288 | private: |
| 289 | uint32_t _p; |
| 290 | uint32_t _KS; |
| 291 | float _B; |
| 292 | PrimitiveGenerator _set_data; |
| 293 | }; |
| 294 | |
| 295 | float getBoundParameter(const DType& dataType, const DType& accType) |
| 296 | { |
| 297 | // Work out the bounds parameter value B for the given data and accumulator types |
| 298 | // Returns value > 0.f on success |
| 299 | float B = 0.f; |
| 300 | if (dataType == DType::DType_FP16) |
| 301 | { |
| 302 | if (accType == DType::DType_FP16) |
| 303 | B = 255.875f; // (1<<8) - (1/8); |
| 304 | else if (accType == DType::DType_FP32) |
| 305 | B = 65504.f; // (1<<16) - (1<<5); |
| 306 | } |
| 307 | else if (dataType == DType::DType_BF16) |
| 308 | { |
| 309 | if (accType == DType::DType_FP32) |
| 310 | B = 18374686479671623680.f; // (1<<64) - (1<<56) |
| 311 | } |
| 312 | else if (dataType == DType::DType_FP32) |
| 313 | { |
| 314 | if (accType == DType::DType_FP32) |
| 315 | B = 18446742974197923840.f; // (1<<64) - (1<<40) |
| 316 | } |
| 317 | return B; |
| 318 | } |
| 319 | |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 320 | } // namespace |
| 321 | |
| 322 | namespace TosaReference |
| 323 | { |
| 324 | |
| 325 | std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg) |
| 326 | { |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 327 | // Generators can only support 3 inputs |
| 328 | if (cfg.inputPos > 2) |
| 329 | return nullptr; |
| 330 | |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 331 | const DotProductInfo& dpinfo = cfg.dotProductInfo; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 332 | |
| 333 | float B = getBoundParameter(cfg.dataType, dpinfo.accType); |
| 334 | if (B > 0.f) |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 335 | { |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 336 | auto param = cfg.inputPos; |
| 337 | if (cfg.opType == Op_FFT2D) |
| 338 | { |
| 339 | // We only use param of zero for FFT2D tensors |
| 340 | param = 0; |
| 341 | } |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 342 | // Create the generator |
| 343 | switch (dpinfo.s) |
| 344 | { |
| 345 | case 0: |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 346 | return std::make_unique<GeneratorS0>(param); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 347 | case 1: |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 348 | return std::make_unique<GeneratorS1>(param, dpinfo.ks, B); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 349 | case 2: |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 350 | return std::make_unique<GeneratorS2>(param, dpinfo.ks); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 351 | case 3: |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 352 | return std::make_unique<GeneratorS3>(param); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 353 | case 4: |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 354 | return std::make_unique<GeneratorS4>(param, dpinfo.ks, B); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 355 | case 5: |
Jeremy Johnson | c833081 | 2024-01-18 16:57:28 +0000 | [diff] [blame] | 356 | return std::make_unique<GeneratorS5>(param, dpinfo.ks, B); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 357 | default: |
Jeremy Johnson | fc5e34e | 2023-10-24 14:45:12 +0100 | [diff] [blame] | 358 | WARNING("[Generator][DP] Unsupported dot product test series for generator."); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 359 | return nullptr; |
| 360 | } |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 361 | } |
Jeremy Johnson | fc5e34e | 2023-10-24 14:45:12 +0100 | [diff] [blame] | 362 | WARNING("[Generator][DP] Unsupported data types for generator."); |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 363 | return nullptr; |
| 364 | } |
| 365 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame] | 366 | } // namespace TosaReference |