blob: 831ad74b1924b3f3c6c736f0e263501e60dc5928 [file] [log] [blame]
Tai Lyce911a22024-03-21 17:01:14 +00001// Copyright (c) 2024, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef TOSA_FLOAT_UTILS_H_
16#define TOSA_FLOAT_UTILS_H_
17
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
#if defined(__has_include)
#if __has_include(<version>)
#include <version>
#endif
#endif
#if defined(__cpp_lib_bit_cast)
#include <bit>
#endif // defined(__cpp_lib_bit_cast)
25
26namespace tosa
27{
28
namespace float_support
{

/// \brief Tag type selecting the raw-bit-pattern constructor of float_t,
/// keeping it distinct from the value-converting constructors.
struct hidden
{};

static_assert(sizeof(float) == sizeof(int32_t), "float must be 32 bits wide");

#if defined(__cpp_lib_bit_cast)
#define BITCAST_CONSTEXPR constexpr inline

/// \brief Return the binary32 bit pattern of \p f.
constexpr inline int32_t get_bits(const float& f)
{
    return std::bit_cast<int32_t>(f);
}

/// \brief Return the float whose binary32 bit pattern is \p i.
constexpr inline float from_bits(const int32_t& i)
{
    return std::bit_cast<float>(i);
}

#else
#define BITCAST_CONSTEXPR inline

/// \brief Legacy helper retained for source compatibility.
///
/// Reading the member that was not last written is undefined behaviour in
/// standard C++, so get_bits/from_bits below no longer use this union.
union ufloat32
{
    constexpr ufloat32(const float& x)
        : f(x)
    {}
    constexpr ufloat32(const int32_t& x)
        : i(x)
    {}

    float f;
    int32_t i;
};

/// \brief Return the binary32 bit pattern of \p f.
///
/// std::memcpy is the portable, well-defined way to reinterpret object
/// representations when std::bit_cast is unavailable.
inline int32_t get_bits(const float& f)
{
    int32_t bits = 0;
    std::memcpy(&bits, &f, sizeof(bits));
    return bits;
}

/// \brief Return the float whose binary32 bit pattern is \p i.
inline float from_bits(const int32_t& i)
{
    float value = 0.0f;
    std::memcpy(&value, &i, sizeof(value));
    return value;
}
#endif

} // namespace float_support
74
75template <typename storage_t,
76 size_t n_exp_bits,
77 bool has_nan,
78 bool with_denorm,
79 bool with_infinity,
80 std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
81class float_t
82{
83 storage_t m_data = 0;
84
85public:
86 static constexpr size_t n_exponent_bits = n_exp_bits;
87 static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits;
88 static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1;
89
90 /// \brief Construct a floating point type with the given bit
91 /// representation.
92 static constexpr float_t from_bits(storage_t bits)
93 {
94 return float_t(float_support::hidden(), bits);
95 }
96
97 /// \brief Construct a float from the given sign, exponent and significand
98 /// bits.
99 static constexpr float_t from_bits(bool pm, storage_t e, storage_t s)
100 {
101 storage_t bits = pm ? 1 : 0;
102
103 bits <<= n_exp_bits;
104 bits |= e;
105
106 bits <<= n_significand_bits;
107 if (with_denorm || e)
108 bits |= s;
109
110 return float_t(float_support::hidden(), bits);
111 }
112
113 /// \brief (Hidden) Construct a float type from a given bit pattern
114 constexpr float_t(const float_support::hidden&, storage_t bits)
115 : m_data(bits)
116 {}
117
118 constexpr float_t()
119 : m_data(0)
120 {}
121 constexpr float_t(const float_t& other)
122 : m_data(other.m_data)
123 {}
124
125 /// \brief Cast to a different floating point representation.
126 template <typename other_storage_t,
127 size_t other_n_exp_bits,
128 bool other_has_nan,
129 bool other_has_denorm,
130 bool other_has_infinity>
131 constexpr inline
132 operator float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>() const
133 {
134 using other_float_t =
135 float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>;
136
137 // Shortcut for types which are fundamentally similar (e.g., bf16 ->
138 // fp32)
139 if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) &&
140 has_nan == other_has_nan)
141 {
142 return other_float_t::from_bits(static_cast<other_storage_t>(m_data)
143 << (sizeof(other_storage_t) - sizeof(storage_t)) * 8);
144 }
145
146 // Get initial values for the new floating point type
147 const bool sign_bit = m_data < 0;
148 int64_t new_exponent_bits = 0;
149 uint64_t new_significand = 0;
150
151 if (is_nan() || is_infinity())
152 {
153 new_exponent_bits = (1 << other_n_exp_bits) - 1;
154
155 if (is_nan())
156 {
157 if constexpr (other_has_infinity)
158 {
159 // Copy across the `not_quiet bit`; set the LSB. Don't
160 // attempt to copy across any of the rest of the payload.
161 new_significand =
162 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits);
163 }
164 else
165 {
166 new_significand = (1ul << other_float_t::n_significand_bits) - 1;
167 }
168 }
169 else if constexpr (!other_has_infinity)
170 {
171 new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
172 }
173 }
174 else if (!is_zero())
175 {
176 const int64_t this_exponent_bits = exponent_bits();
177 {
178 constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias;
179 new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
180 }
181 new_significand = this->significand() << (64 - n_significand_bits);
182
183 // Normalise subnormals
184 if (this_exponent_bits == 0)
185 {
186 // Shift the most-significant 1 out of the magnitude to convert
187 // it to a significand. Modify the exponent accordingly.
188 uint8_t shift = __builtin_clzl(new_significand) + 1;
189 new_exponent_bits -= shift;
190 new_significand <<= shift;
191 }
192
193 // Align the significand for the output type
194 uint32_t shift = 64 - other_float_t::n_significand_bits;
195 const bool other_is_subnormal = new_exponent_bits <= 0;
196 if (other_is_subnormal)
197 {
198 shift += 1 - new_exponent_bits;
199 new_exponent_bits = 0;
200 }
201
202 const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1);
203 new_significand = shift == 64 ? 0 : new_significand >> shift;
204
205 // Reinsert the most-significant-one if this is a subnormal in the
206 // output type.
207 new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift);
208
209 // Apply rounding based on the bits shifted out of the significand
210 const uint64_t shift_half = 1ll << (shift - 1);
211 if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
212 {
213 new_significand += 1;
214
215 // Handle the case that the significand overflowed due to
216 // rounding
217 constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1;
218 if (new_significand > max_significand)
219 {
220 new_significand = 0;
221 new_exponent_bits++;
222 }
223 }
224
225 // Saturate to infinity if the exponent is larger than can be
226 // represented in the output type. This can only occur if the size
227 // of the exponent of the new type is not greater than the exponent
228 // of the old type.
229 if constexpr (other_n_exp_bits <= n_exp_bits)
230 {
231 constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1;
232 if (new_exponent_bits >= inf_exp_bits)
233 {
234 new_exponent_bits = inf_exp_bits;
235 new_significand =
236 other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
237 }
238 }
239 }
240
241 return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand);
242 }
243
244 /// \brief Convert from a 32-bit floating point value
245 BITCAST_CONSTEXPR
246 float_t(const float& f)
247 {
248 // If this format exactly represents the binary32 format then get
249 // the bits from the provided float; otherwise get a binary32
250 // representation and then convert to this format.
251 if constexpr (represents_binary32())
252 m_data = float_support::get_bits(f);
253 else
254 m_data = static_cast<float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_infinity>>(
255 static_cast<float_t<int32_t, 8, true, true, true>>(f))
256 .m_data;
257 }
258
259 /// \brief Cast to a 32-bit floating point value
260 BITCAST_CONSTEXPR operator float() const
261 {
262 // If this format exactly represents the binary32 format then return
263 // a float; otherwise get a binary32 representation and then return
264 // a float.
265 if constexpr (represents_binary32())
266 return float_support::from_bits(m_data);
267 else
268 return static_cast<float>(this->operator float_t<int32_t, 8, true, true, true>());
269 }
270
271 /// \brief Return whether this type represents the IEEE754 binary32
272 /// format
273 constexpr static inline bool represents_binary32()
274 {
275 return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && has_nan && with_denorm && with_infinity;
276 }
277
278 constexpr auto operator-() const
279 {
280 return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1)));
281 }
282
283 constexpr bool is_subnormal() const
284 {
285 return exponent_bits() == 0 && significand() != 0;
286 }
287
288 constexpr bool is_zero() const
289 {
290 return exponent_bits() == 0 && significand() == 0;
291 }
292
293 constexpr bool is_nan() const
294 {
295 return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) &&
296 ((with_infinity && significand()) ||
297 (!with_infinity && significand() == (1ul << n_significand_bits) - 1));
298 }
299
300 constexpr bool is_infinity() const
301 {
302 return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand());
303 }
304
305 constexpr inline const storage_t& bits() const
306 {
307 return m_data;
308 }
309
310 /// \brief Get the exponent
311 constexpr inline int64_t exponent() const
312 {
313 return std::max<int64_t>(exponent_bits(), 1ul) - exponent_bias;
314 }
315
316 /// \brief Get the bits from the exponent
317 constexpr inline uint64_t exponent_bits() const
318 {
319 constexpr uint64_t mask = (1ul << n_exp_bits) - 1;
320 return (m_data >> n_significand_bits) & mask;
321 }
322
323 constexpr inline uint64_t significand() const
324 {
325 return m_data & ((1ul << n_significand_bits) - 1);
326 }
327
328 constexpr inline bool operator==(const float_t& other) const
329 {
330 return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits());
331 }
332
333 constexpr inline float_t& operator+=(const float_t& rhs)
334 {
335 this->m_data = static_cast<float_t>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
336 return *this;
337 }
338};
339
// BITCAST_CONSTEXPR expands to `constexpr inline` when std::bit_cast is
// available and to plain `inline` otherwise; float_t uses it for members whose
// constexpr-ness depends on std::bit_cast. It is undefined here so it does not
// leak to includers.
// TODO: consider exporting it if other headers need the same conditional
// constexpr.
#undef BITCAST_CONSTEXPR
342
namespace float_support
{

// std::log10 and std::ceil are not constexpr before C++23, so the limits
// below use the rational approximation log10(2) ~= 643 / 2136. 2136 is the
// denominator of the next continued-fraction convergent of log10(2) after
// 485/146, so for every integer |x| < 2136 the approximation error
// (< 7.1e-5) is smaller than the distance of x * log10(2) from the nearest
// integer (>= 4.5e-4). Hence floor/ceil are computed exactly for any
// practical floating-point format.

/// \brief floor(x * log10(2)), exact for |x| < 2136.
constexpr int floor_log10_2_times(int x)
{
    const int scaled = x * 643;
    // Integer division truncates toward zero; step down for negative values.
    return scaled / 2136 - ((scaled % 2136 != 0 && scaled < 0) ? 1 : 0);
}

/// \brief ceil(x * log10(2)), exact for |x| < 2136.
constexpr int ceil_log10_2_times(int x)
{
    const int scaled = x * 643;
    // Integer division truncates toward zero; step up for positive values.
    return scaled / 2136 + ((scaled % 2136 != 0 && scaled > 0) ? 1 : 0);
}

/// \brief Decimal digits that survive a round trip: floor((digits - 1) * log10(2)).
template <int digits>
struct digits10
{
    constexpr static inline int value = floor_log10_2_times(digits - 1);
};

/// \brief Decimal digits needed to round-trip: ceil(digits * log10(2)) + 1.
template <int digits>
struct max_digits10
{
    constexpr static inline int value = ceil_log10_2_times(digits) + 1;
};

/// \brief Lowest n with 10^n normal. The smallest normal value is
/// 2^(min_exponent - 1), so this is ceil((min_exponent - 1) * log10(2)).
template <int min_exponent>
struct min_exponent10
{
    constexpr static inline int value = ceil_log10_2_times(min_exponent - 1);
};

/// \brief floor(max_exponent * log10(2)). This matches floor(log10(max()))
/// for IEEE-style formats; it can exceed it by one if 2^max_exponent lies
/// just above a power of ten.
template <int max_exponent>
struct max_exponent10
{
    constexpr static inline int value = floor_log10_2_times(max_exponent);
};

template <int d>
inline constexpr int digits10_v = digits10<d>::value;
template <int d>
inline constexpr int max_digits10_v = max_digits10<d>::value;

template <int e>
inline constexpr int min_exponent10_v = min_exponent10<e>::value;

template <int e>
inline constexpr int max_exponent10_v = max_exponent10<e>::value;

} // namespace float_support
429
430} // namespace tosa
431
432namespace std
433{
434
435template <typename storage_t, size_t n_exp_bits, bool has_nan, bool has_denorm, bool has_inf>
436struct is_floating_point<tosa::float_t<storage_t, n_exp_bits, has_nan, has_denorm, has_inf>>
437 : std::integral_constant<bool, true>
438{};
439
440template <typename storage_t, size_t n_exp_bits, bool has_nan, bool with_denorm, bool with_inf>
441class numeric_limits<tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>>
442{
443 using this_float_t = tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>;
444
445public:
446 static constexpr bool is_specialized = true;
447
448 static constexpr auto min() noexcept
449 {
450 return this_float_t::from_bits(false, 1, 0);
451 }
452
453 static constexpr auto max() noexcept
454 {
455 return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2,
456 (1 << this_float_t::n_significand_bits) - 1);
457 }
458
459 static constexpr auto lowest() noexcept
460 {
461 return -max();
462 }
463
464 static constexpr int digits = this_float_t::n_significand_bits + 1;
465 static constexpr int digits10 = tosa::float_support::digits10_v<digits>;
466 static constexpr int max_digits10 = tosa::float_support::max_digits10_v<digits>;
467
468 static constexpr bool is_signed = true;
469 static constexpr bool is_integer = false;
470 static constexpr bool is_exact = false;
471 static constexpr int radix = 2;
472
473 static constexpr auto epsilon() noexcept
474 {
475 return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0);
476 }
477
478 static constexpr auto round_error() noexcept
479 {
480 return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0);
481 }
482
483 static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1;
484 static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v<min_exponent>;
485 static constexpr int max_exponent = this_float_t::exponent_bias + 1;
486 static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v<max_exponent>;
487
488 static constexpr bool has_infinity = with_inf;
489 static constexpr bool has_quiet_NaN = has_nan;
490 static constexpr bool has_signaling_NaN = true;
491 static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent;
492 static constexpr bool has_denorm_loss = false;
493
494 static constexpr auto infinity() noexcept
495 {
496 if constexpr (with_inf)
497 {
498 return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0);
499 }
500 else
501 {
502 return this_float_t::from_bits(false, 0, 0);
503 }
504 }
505
506 static constexpr auto quiet_NaN() noexcept
507 {
508 return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1,
509 1 << (this_float_t::n_significand_bits - 1) | 1);
510 }
511
512 static constexpr auto signaling_NaN() noexcept
513 {
514 return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1);
515 }
516
517 static constexpr auto denorm_min() noexcept
518 {
519 return this_float_t::from_bits(false, 0, 1);
520 }
521
522 static constexpr bool is_iec559 = false;
523 static constexpr bool is_bounded = false;
524 static constexpr bool is_modulo = false;
525
526 static constexpr bool traps = false;
527 static constexpr bool tinyness_before = false;
528 static constexpr float_round_style round_style = round_to_nearest;
529};
530
531} // namespace std
532
533#endif // TOSA_FLOAT_UTILS_H_