
// Copyright (c) 2020-2024, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

/*
 *   Filename:     src/arith_util.h
 *   Description:
 *    Arithmetic utility macros and helpers, including:
 *      fp16 (float16_t) type alias
 *      bitwise operations
 *      fixed-point arithmetic
 *      fp16 type conversion (in binary translation)
 *      fp16 arithmetic (disguised with fp32 now)
 *    and the arithmetic helpers listed in Section 4.3.1 of the spec
 */

#ifndef ARITH_UTIL_H
#define ARITH_UTIL_H

#include <fenv.h>
#include <math.h>
#define __STDC_LIMIT_MACROS    //enable min/max of plain data type
#include "dtype.h"
#include "func_config.h"
#include "func_debug.h"
#include "half.hpp"
#include "inttypes.h"
#include "ops/template_types.h"
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstring>
#include <limits>
#include <stdint.h>
#include <type_traits>
#include <typeinfo>

using namespace tosa;
using namespace std;
using namespace TosaReference;

// Counts the bits that are set to 1 in `val`.
//
// Each loop iteration clears the lowest set bit (`val &= val - 1`), so the
// loop body runs once per set bit rather than once per bit position.
// Returns 0 for val == 0.
inline size_t _count_one(uint64_t val)
{
    size_t count = 0;
    while (val != 0)
    {
        val &= val - 1;    // drop the lowest set bit
        ++count;
    }
    return count;
}

// Returns floor(log2(val)), i.e. the index of the highest set bit of `val`.
//
// Returns 0 for val == 0 and for val == 1.
// The value is examined through the unsigned counterpart of T: right-shifting
// a negative signed value keeps the sign bit set, so a signed loop would never
// reach zero. A negative input is therefore treated as its bit pattern
// (e.g. int32_t(-1) yields 31).
template <typename T>
inline size_t _integer_log2(T val)
{
    using U = typename std::make_unsigned<T>::type;

    U remaining   = static_cast<U>(val);
    size_t result = 0;
    while (remaining >>= 1)
    {
        ++result;
    }
    return result;
}

// Counts the consecutive 0 bits at the most-significant end of `val`.
//
// The width considered is sizeof(T) * 8 bits; a zero input yields that full
// width. The scan runs on the unsigned counterpart of T and uses right shifts
// only, so signed inputs never shift a 1 into or past the sign bit.
template <typename T>
inline size_t _count_leading_zeros(T val)
{
    using U = typename std::make_unsigned<T>::type;

    const size_t width = sizeof(T) * 8;
    const U bits       = static_cast<U>(val);
    size_t count       = 0;
    // Walk from the most significant bit downwards until the first 1 bit.
    for (size_t pos = width; pos-- > 0;)
    {
        if ((bits >> pos) & 1u)
            break;
        ++count;
    }
    return count;
}

// Counts the consecutive 1 bits at the most-significant end of `val`.
//
// The width considered is sizeof(T) * 8 bits; an all-ones input yields that
// full width. The scan runs on the unsigned counterpart of T and uses right
// shifts only, so negative signed inputs are never left-shifted.
template <typename T>
inline size_t _count_leading_ones(T val)
{
    using U = typename std::make_unsigned<T>::type;

    const size_t width = sizeof(T) * 8;
    const U bits       = static_cast<U>(val);
    size_t count       = 0;
    // Walk from the most significant bit downwards until the first 0 bit.
    for (size_t pos = width; pos-- > 0;)
    {
        if (((bits >> pos) & 1u) == 0)
            break;
        ++count;
    }
    return count;
}

// Larger / smaller of two values.
// NOTE: these are function-like macros, so each argument may be evaluated
// more than once; do not pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Compute ceiling of (a/b)
// Integer division truncates toward zero, so this is the mathematical ceiling
// only when a is non-negative and b is positive.
#define DIV_CEIL(a, b) ((a) % (b) ? ((a) / (b) + 1) : ((a) / (b)))

// Returns a mask with the low SIZE bits set to 1.
// SIZE >= 64 yields all 64 bits set (the shift is not performed in that case).
#define ONES_MASK(SIZE) ((uint64_t)((SIZE) >= 64 ? UINT64_C(0xffffffffffffffff) : (UINT64_C(1) << (SIZE)) - 1))

// Returns the field of bits from HIGH_BIT down to LOW_BIT, right-shifted so
// that LOW_BIT lands at bit 0. Both ends are included; equivalent to
// VAL[HIGH_BIT:LOW_BIT] in Verilog. VAL is first converted to uint64_t.

#define BIT_FIELD(HIGH_BIT, LOW_BIT, VAL) (((uint64_t)(VAL) >> (LOW_BIT)) & ONES_MASK((HIGH_BIT) + 1 - (LOW_BIT)))

// Returns the bit at a particular position (0 or 1, as uint64_t)
#define BIT_EXTRACT(POS, VAL) (((uint64_t)(VAL) >> (POS)) & (1))

// Use Brian Kernighan's way: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan
// Does this need to support floating point type?
// Not sure if static_cast is the right thing to do, try to be type safe first
// VAL is converted to uint64_t before its set bits are counted by _count_one().
#define ONES_COUNT(VAL) (_count_one((uint64_t)(VAL)))

// Shift VAL left by SHF when SHF is positive, right by -SHF when it is
// negative, and leave it unchanged when SHF is zero.
// SHF is fully parenthesised at every use so that an expression argument
// (e.g. `a & b`) is compared as a whole.
#define SHIFT(SHF, VAL) (((SHF) > 0) ? ((VAL) << (SHF)) : (((SHF) < 0) ? ((VAL) >> (-(SHF))) : (VAL)))
// Round A up to the next multiple of B (A itself when already a multiple)
#define ROUNDTO(A, B) ((A) % (B) == 0 ? (A) : ((A) / (B) + 1) * (B))
// Round A down to a multiple of B using truncating integer division
#define ROUNDTOLOWER(A, B) (((A) / (B)) * (B))
// Shift VAL left by SHIFT when SHIFT >= 0, otherwise right by -SHIFT
#define BIDIRECTIONAL_SHIFT(VAL, SHIFT) (((SHIFT) >= 0) ? ((VAL) << (SHIFT)) : ((VAL) >> (-(SHIFT))))
// Integer log2, forwarded to _integer_log2()
#define ILOG2(VAL) (_integer_log2(VAL))

// Get negative value (2's complement): invert the bits, add one, then
// truncate to the unsigned type of the stated width.
#define NEGATIVE_8(VAL) ((uint8_t)(~(VAL) + 1))
#define NEGATIVE_16(VAL) ((uint16_t)(~(VAL) + 1))
#define NEGATIVE_32(VAL) ((uint32_t)(~(VAL) + 1))
#define NEGATIVE_64(VAL) ((uint64_t)(~(VAL) + 1))
// Convert a bit quantity to the minimum bytes required to hold those bits
// (BITS is rounded up to a multiple of 8 with ROUNDTO, then divided by 8)
#define BITS_TO_BYTES(BITS) (ROUNDTO((BITS), 8) / 8)

// Count leading zeros/ones for 8/16/32/64-bit operands
// (I don't see an obvious way to collapse this into a size-independent set)
// treated as unsigned: the sized variants cast VAL to the unsigned type of
// that width first; the unsized variants use VAL's own type.
#define LEADING_ZEROS_64(VAL) (_count_leading_zeros((uint64_t)(VAL)))
#define LEADING_ZEROS_32(VAL) (_count_leading_zeros((uint32_t)(VAL)))
#define LEADING_ZEROS_16(VAL) (_count_leading_zeros((uint16_t)(VAL)))
#define LEADING_ZEROS_8(VAL) (_count_leading_zeros((uint8_t)(VAL)))
#define LEADING_ZEROS(VAL) (_count_leading_zeros(VAL))

#define LEADING_ONES_64(VAL) (_count_leading_ones((uint64_t)(VAL)))
#define LEADING_ONES_32(VAL) (_count_leading_ones((uint32_t)(VAL)))
#define LEADING_ONES_16(VAL) (_count_leading_ones((uint16_t)(VAL)))
#define LEADING_ONES_8(VAL) (_count_leading_ones((uint8_t)(VAL)))
#define LEADING_ONES(VAL) (_count_leading_ones(VAL))
// math operation
// sign-extended for signed version
// extend different return type (8, 16, 32) + (S, U)
// Saturate a value at a certain bitwidth, signed and unsigned versions
// Format is as follows: SATURATE_VAL_{saturation_sign}_{return_type}
// for example
// SATURATE_VAL_U_8U(8,300) will return uint8_t with value of 255 (0xff)
// SATURATE_VAL_S_32S(5,-48) will return int32_t with value of -16 (0x10 as a 5-bit pattern)
// note that a negative value can be cast to an unsigned return type using the native uint(int) cast
// so SATURATE_VAL_S_8U(5,-40) will have value 0b11110000, which is in turn 240 in uint8_t

163template <typename T>
164constexpr T bitmask(const uint32_t width)
165{
166 ASSERT(width <= sizeof(T) * 8);
167 return width == sizeof(T) * 8 ? static_cast<T>(std::numeric_limits<uintmax_t>::max())
168 : (static_cast<T>(1) << width) - 1;
169}
170
171template <typename T>
172constexpr T minval(const uint32_t width)
173{
174 ASSERT(width <= sizeof(T) * 8);
175 return std::is_signed<T>::value ? -(static_cast<T>(1) << (width - 1)) : 0;
176}
177
178template <typename T>
179constexpr T maxval(const uint32_t width)
180{
181 ASSERT(width <= sizeof(T) * 8);
182 return bitmask<T>(width - std::is_signed<T>::value);
183}
184
185template <typename T>
186constexpr T saturate(const uint32_t width, const intmax_t value)
187{
188 // clang-format off
189 return static_cast<T>(
190 std::min(
191 std::max(
192 value,
193 static_cast<intmax_t>(minval<T>(width))
194 ),
195 static_cast<intmax_t>(maxval<T>(width))
196 )
197 );
198 // clang-format on
199}
200
// Clears the least significant two bytes (the low 16 bits of the 32-bit
// representation) of the float pointed to by `src`, in place.
//
// The bit pattern is moved through a uint32_t with memcpy: reading a float
// object through a uint32_t* obtained by reinterpret_cast violates the strict
// aliasing rules.
inline void float_trunc_bytes(float* src)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");

    uint32_t bits = 0;
    std::memcpy(&bits, src, sizeof(bits));
    bits &= UINT32_C(0xffff0000);
    std::memcpy(src, &bits, sizeof(bits));
}

Jerry Ge9c9c8da2023-07-19 23:08:16 +0000208inline void truncateFloatToBFloat(float* src, int64_t size)
209{
James Ward24dbc422022-10-19 12:20:31 +0100210 /* Set the least significant two bytes to zero for each float
211 value in the input src buffer. */
212 ASSERT_MEM(src);
213 ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
214 for (; size != 0; src++, size--)
215 {
216 float_trunc_bytes(src);
217 }
218}
219
// Returns true when the least significant two bytes (the low 16 bits of the
// 32-bit representation) of `src` are all zero.
//
// The bit pattern is copied out with memcpy rather than read through a
// reinterpret_cast'ed uint32_t*, which would violate strict aliasing.
inline bool checkValidBFloat(float src)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");

    uint32_t bits = 0;
    std::memcpy(&bits, &src, sizeof(bits));
    return (bits & UINT32_C(0x0000ffff)) == 0;
}

Tai Lya4d748b2023-03-28 22:06:56 +0000227template <TOSA_REF_TYPE Dtype>
James Ward24dbc422022-10-19 12:20:31 +0100228float fpTrunc(float f_in)
229{
Tai Lya4d748b2023-03-28 22:06:56 +0000230 /* Truncates a float value based on the TOSA_REF_TYPE it represents.*/
James Ward24dbc422022-10-19 12:20:31 +0100231 switch (Dtype)
232 {
Tai Lya4d748b2023-03-28 22:06:56 +0000233 case TOSA_REF_TYPE_BF16:
James Ward24dbc422022-10-19 12:20:31 +0100234 truncateFloatToBFloat(&f_in, 1);
235 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000236 case TOSA_REF_TYPE_FP16:
James Wardee256692022-11-15 11:36:47 +0000237 // Cast to temporary float16 value before casting back to float32
238 {
239 half_float::half h = half_float::half_cast<half_float::half, float>(f_in);
240 f_in = half_float::half_cast<float, half_float::half>(h);
241 break;
242 }
Tai Lya4d748b2023-03-28 22:06:56 +0000243 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100244 // No-op for fp32
245 break;
246 default:
Won Jeon2c34b462024-02-06 18:37:00 +0000247 ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-cast.", EnumNameTOSAREFTYPE(Dtype));
James Ward24dbc422022-10-19 12:20:31 +0100248 }
249 return f_in;
250}
251
TatWai Chong08fe7a52024-03-21 14:34:33 -0700252// return the maximum value when interpreting type T as a signed value.
253template <TOSA_REF_TYPE Dtype>
254int32_t getSignedMaximum()
255{
256 if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
257 return GetQMax<TOSA_REF_TYPE_INT8>::value;
258
259 if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
260 return GetQMax<TOSA_REF_TYPE_INT16>::value;
261
262 if (Dtype == TOSA_REF_TYPE_INT32)
263 return GetQMax<TOSA_REF_TYPE_INT32>::value;
264
265 FATAL_ERROR("Get maximum_s for the dtype input is not supported");
266 return 0;
267}
268
269// return the minimum value when interpreting type T as a signed value.
270template <TOSA_REF_TYPE Dtype>
271int32_t getSignedMinimum()
272{
273 if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
274 return GetQMin<TOSA_REF_TYPE_INT8>::value;
275
276 if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
277 return GetQMin<TOSA_REF_TYPE_INT16>::value;
278
279 if (Dtype == TOSA_REF_TYPE_INT32)
280 return GetQMin<TOSA_REF_TYPE_INT32>::value;
281
282 FATAL_ERROR("Get minimum_s for the dtype input is not supported");
283 return 0;
284}
285
286// return the maximum value when interpreting type T as an unsigned value.
287template <TOSA_REF_TYPE Dtype>
288int32_t getUnsignedMaximum()
289{
290 if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
291 return GetQMax<TOSA_REF_TYPE_UINT8>::value;
292
293 if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
294 return GetQMax<TOSA_REF_TYPE_UINT16>::value;
295
296 if (Dtype == TOSA_REF_TYPE_INT32)
297 return std::numeric_limits<uint32_t>::max();
298
299 FATAL_ERROR("Get maximum_u for the dtype input is not supported");
300 return 0;
301}
302
303// return the minimum value when interpreting type T as an unsigned value.
304template <TOSA_REF_TYPE Dtype>
305int32_t getUnsignedMinimum()
306{
307 if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
308 return GetQMin<TOSA_REF_TYPE_UINT8>::value;
309
310 if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
311 return GetQMin<TOSA_REF_TYPE_UINT16>::value;
312
313 if (Dtype == TOSA_REF_TYPE_INT32)
314 return std::numeric_limits<uint32_t>::min();
315
316 FATAL_ERROR("Get minimum_u for the dtype input is not supported");
317 return 0;
318}
319
#endif /* ARITH_UTIL_H */