Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 1 | // Copyright (c) 2023, ARM Limited. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "generate_dot_product.h" |
| 16 | #include "generate_utils.h" |
| 17 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 18 | #include <cmath> |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 19 | #include <cstdint> |
| 20 | |
| 21 | namespace |
| 22 | { |
| 23 | |
// Input positions understood by the generators below; each generator
// compares the position it was constructed with against these values
inline constexpr uint32_t P0{ 0 };
inline constexpr uint32_t P1{ 1 };
inline constexpr uint32_t P2{ 2 };
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 28 | |
// Helper that accepts any number of arguments and does nothing with them;
// used to mark values as deliberately unused
template <typename... Ts>
inline void unused(Ts&&... /*ignored*/)
{
    // Intentionally empty
}
| 33 | |
// Primitive generator class
//
// Produces a deterministic pseudo-random sequence of floats in the range
// [-1.0, +1.0], seeded by S.  Each use of the function call operator yields
// the value for the current index and then moves on to the next index.
// All state arithmetic relies on uint32_t wrap-around (modulo 2^32).
class PrimitiveGenerator
{
public:
    // S - seed value; the multiplier and the initial state are derived from it
    explicit PrimitiveGenerator(uint32_t S)
        : _S(S)
        , _m((8 * S + 1) * 0x705A5E75)
        , _r(_m + 1)    // _m is declared before _r, so it is already initialized here
        , _index(0)
    {}

    // Returns the value for the current index and advances the generator
    [[nodiscard]] float operator()()
    {
        // The top bit of the state selects the sign, the low 31 bits the magnitude
        const float sign      = (_r >> 31) == 0 ? 1.f : -1.f;
        const float magnitude = static_cast<float>(_r & 0x7FFFFFFF);
        const float pseudo    = sign * magnitude / static_cast<float>(0x7FFFFFFF);

        // Move index and calculate r value for the next index
        ++_index;
        _r = _r * _m + 1;

        return pseudo;
    }

    // Number of values produced so far
    [[nodiscard]] uint32_t index() const
    {
        return _index;
    }

private:
    uint32_t _S;        // seed the generator was created with
    uint32_t _m;        // multiplier of the recurrence
    uint32_t _r;        // current state
    uint32_t _index;    // count of values already returned
};
| 74 | |
| 75 | //----------------------------------------------------------------------------// |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 76 | // State generators - equivalent to tosa_mi_data() in the TOSA specification |
| 77 | // |
| 78 | // Each call to the generator returns the next generated value with an |
| 79 | // auto incrementing index |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 80 | //----------------------------------------------------------------------------// |
| 81 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 82 | // Test set 0 generator |
| 83 | // The aim of this generator is to check that sum of products with zero gives zero result. |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 84 | class GeneratorS0 : public TosaReference::IDotProductGenerator |
| 85 | { |
| 86 | public: |
| 87 | GeneratorS0(uint32_t p) |
| 88 | : _p(p) |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 89 | , _set_data0(2 * 0) |
| 90 | , _set_data1(2 * 0 + 1) |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 91 | {} |
| 92 | float operator()(uint32_t k) override |
| 93 | { |
| 94 | unused(k); |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 95 | const float s0 = _set_data0(); |
| 96 | const float s1 = _set_data1(); |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 97 | if (_p == P0) |
| 98 | return s0 < 0.f ? 0.f : s1; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 99 | else if (_p == P1) |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 100 | return s0 < 0.f ? s1 : 0.f; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 101 | else |
| 102 | return 0.f; |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 103 | } |
| 104 | |
| 105 | private: |
| 106 | uint32_t _p; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 107 | PrimitiveGenerator _set_data0; |
| 108 | PrimitiveGenerator _set_data1; |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 109 | }; |
| 110 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 111 | // Test set 1 generator |
| 112 | // The aim of this test set is to check values with large exponents. |
| 113 | class GeneratorS1 : public TosaReference::IDotProductGenerator |
| 114 | { |
| 115 | public: |
| 116 | GeneratorS1(uint32_t p, uint32_t KS, float B) |
| 117 | : _p(p) |
| 118 | , _KS(KS) |
| 119 | , _B(B) |
| 120 | , _set_data(3 * 1 + p) |
| 121 | {} |
| 122 | float operator()(uint32_t k) override |
| 123 | { |
| 124 | unused(k); |
| 125 | const float s = _set_data(); |
| 126 | float v = 0.75f + 0.25f * s; |
| 127 | if (_p != P2) |
| 128 | return (_B / std::sqrt(_KS + 1)) * v; |
| 129 | else |
| 130 | return (_B * _B / (_KS + 1)) * v; |
| 131 | } |
| 132 | |
| 133 | private: |
| 134 | uint32_t _p; |
| 135 | uint32_t _KS; |
| 136 | float _B; |
| 137 | PrimitiveGenerator _set_data; |
| 138 | }; |
| 139 | |
| 140 | // Test set 2 generator |
| 141 | // The aim of this test set is to check rounding error when accumulating small values |
| 142 | // onto a large value. In this case the small values are of similar magnitude. If the |
| 143 | // implementation changes the order of the sum, then the test data must also be reordered |
| 144 | // so that the largest values occur first in the sum. |
| 145 | class GeneratorS2 : public TosaReference::IDotProductGenerator |
| 146 | { |
| 147 | public: |
| 148 | GeneratorS2(uint32_t p, uint32_t KS) |
| 149 | : _p(p) |
| 150 | , _KS(KS) |
| 151 | , _set_data(2 * 2 + p) |
| 152 | {} |
| 153 | float operator()(uint32_t k) override |
| 154 | { |
| 155 | const float s = _set_data(); |
| 156 | if (_p != P2) |
| 157 | return k == 0 ? 1.f : s / std::sqrt(_KS); |
| 158 | else |
| 159 | return 0.f; |
| 160 | } |
| 161 | |
| 162 | private: |
| 163 | uint32_t _p; |
| 164 | uint32_t _KS; |
| 165 | PrimitiveGenerator _set_data; |
| 166 | }; |
| 167 | |
| 168 | // Test set 3 generator |
| 169 | // The aim of this test set is to check rounding error when accumulating small values |
| 170 | // onto a large value. In this case the small values are of varying magnitude. If the |
| 171 | // implementation changes the order of the sum, then the test data must also be reordered |
| 172 | // so that the largest values occur first in the sum. |
| 173 | class GeneratorS3 : public TosaReference::IDotProductGenerator |
| 174 | { |
| 175 | public: |
| 176 | GeneratorS3(uint32_t p) |
| 177 | : _p(p) |
| 178 | , _set_data(2 * 3 + p) |
| 179 | {} |
| 180 | float operator()(uint32_t k) override |
| 181 | { |
| 182 | const float s0 = _set_data(); |
| 183 | const float s1 = _set_data(); |
| 184 | if (_p != P2) |
| 185 | return k == 0 ? 16.f : std::exp(2 * s0) * s1; |
| 186 | else |
| 187 | return 0.f; |
| 188 | } |
| 189 | |
| 190 | private: |
| 191 | uint32_t _p; |
| 192 | PrimitiveGenerator _set_data; |
| 193 | }; |
| 194 | |
| 195 | // Test set 4 generator |
| 196 | // The aim of this test set is to check a mixture of zero and non-zero products. |
| 197 | class GeneratorS4 : public TosaReference::IDotProductGenerator |
| 198 | { |
| 199 | public: |
| 200 | GeneratorS4(uint32_t p, uint32_t KS, float B) |
| 201 | : _p(p) |
| 202 | , _KS(KS) |
| 203 | , _B(B) |
| 204 | , _set_data0(2 * 4 + 0) |
| 205 | , _set_data1(2 * 4 + 1) |
| 206 | {} |
| 207 | float operator()(uint32_t k) override |
| 208 | { |
| 209 | const float s0 = _set_data0(); |
| 210 | const float s1 = _set_data1(); |
| 211 | if (_p == P0) |
| 212 | return (k == _KS / 2) ? +0.5f : s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1; |
| 213 | else if (_p == P1) |
| 214 | return (k == _KS / 2) ? -0.5f : s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f; |
| 215 | else |
| 216 | return 0.f; |
| 217 | } |
| 218 | |
| 219 | private: |
| 220 | uint32_t _p; |
| 221 | uint32_t _KS; |
| 222 | float _B; |
| 223 | PrimitiveGenerator _set_data0; |
| 224 | PrimitiveGenerator _set_data1; |
| 225 | }; |
| 226 | |
| 227 | // Test set 5 generator |
| 228 | // The aim of this test set is to check signed inputs of large range. |
| 229 | class GeneratorS5 : public TosaReference::IDotProductGenerator |
| 230 | { |
| 231 | public: |
| 232 | GeneratorS5(uint32_t p, uint32_t KS, float B) |
| 233 | : _p(p) |
| 234 | , _KS(KS) |
| 235 | , _B(B) |
| 236 | , _set_data(3 * 5 + p) |
| 237 | {} |
| 238 | float operator()(uint32_t k) override |
| 239 | { |
| 240 | unused(k); |
| 241 | const float s = _set_data(); |
| 242 | if (_p != P2) |
| 243 | return (_B / std::sqrt(_KS + 1)) * s; |
| 244 | else |
| 245 | return (_B * _B / (_KS + 1)) * s; |
| 246 | } |
| 247 | |
| 248 | private: |
| 249 | uint32_t _p; |
| 250 | uint32_t _KS; |
| 251 | float _B; |
| 252 | PrimitiveGenerator _set_data; |
| 253 | }; |
| 254 | |
| 255 | float getBoundParameter(const DType& dataType, const DType& accType) |
| 256 | { |
| 257 | // Work out the bounds parameter value B for the given data and accumulator types |
| 258 | // Returns value > 0.f on success |
| 259 | float B = 0.f; |
| 260 | if (dataType == DType::DType_FP16) |
| 261 | { |
| 262 | if (accType == DType::DType_FP16) |
| 263 | B = 255.875f; // (1<<8) - (1/8); |
| 264 | else if (accType == DType::DType_FP32) |
| 265 | B = 65504.f; // (1<<16) - (1<<5); |
| 266 | } |
| 267 | else if (dataType == DType::DType_BF16) |
| 268 | { |
| 269 | if (accType == DType::DType_FP32) |
| 270 | B = 18374686479671623680.f; // (1<<64) - (1<<56) |
| 271 | } |
| 272 | else if (dataType == DType::DType_FP32) |
| 273 | { |
| 274 | if (accType == DType::DType_FP32) |
| 275 | B = 18446742974197923840.f; // (1<<64) - (1<<40) |
| 276 | } |
| 277 | return B; |
| 278 | } |
| 279 | |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 280 | } // namespace |
| 281 | |
| 282 | namespace TosaReference |
| 283 | { |
| 284 | |
| 285 | std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg) |
| 286 | { |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 287 | // Generators can only support 3 inputs |
| 288 | if (cfg.inputPos > 2) |
| 289 | return nullptr; |
| 290 | |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 291 | const DotProductInfo& dpinfo = cfg.dotProductInfo; |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 292 | |
| 293 | float B = getBoundParameter(cfg.dataType, dpinfo.accType); |
| 294 | if (B > 0.f) |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 295 | { |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 296 | // Create the generator |
| 297 | switch (dpinfo.s) |
| 298 | { |
| 299 | case 0: |
| 300 | return std::make_unique<GeneratorS0>(cfg.inputPos); |
| 301 | case 1: |
| 302 | return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B); |
| 303 | case 2: |
| 304 | return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks); |
| 305 | case 3: |
| 306 | return std::make_unique<GeneratorS3>(cfg.inputPos); |
| 307 | case 4: |
| 308 | return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B); |
| 309 | case 5: |
| 310 | return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B); |
| 311 | default: |
| 312 | return nullptr; |
| 313 | } |
Jeremy Johnson | b20b0c9 | 2023-10-04 14:17:55 +0100 | [diff] [blame] | 314 | } |
| 315 | return nullptr; |
| 316 | } |
| 317 | |
Jeremy Johnson | 59b307d | 2023-10-04 14:17:26 +0100 | [diff] [blame^] | 318 | } // namespace TosaReference |