blob: f0d184ca46037381b34be07b98df3bb0be530cf8 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Won Jeon2c34b462024-02-06 18:37:00 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16/*
17 * Filename: src/arith_util.h
18 * Description:
 * arithmetic utility macros and helpers, including:
 *   fp16 (float16_t) type alias
 *   bitwise operations
 *   fixed-point arithmetic
 *   fp16 type conversion (in binary translation)
 *   fp16 arithmetic (emulated with fp32 for now)
25 */
26
27#ifndef ARITH_UTIL_H
28#define ARITH_UTIL_H
29
#include <fenv.h>
#include <math.h>
#define __STDC_LIMIT_MACROS //enable min/max of plain data type
#include "dtype.h"
#include "func_config.h"
#include "func_debug.h"
#include "half.hpp"
#include "inttypes.h"
#include <bitset>
#include <cassert>
#include <cstring>
#include <limits>
#include <stdint.h>
#include <type_traits>
#include <typeinfo>
43
James Ward24dbc422022-10-19 12:20:31 +010044using namespace tosa;
Eric Kunzee5e26762020-10-13 16:11:07 -070045using namespace std;
Tai Lya4d748b2023-03-28 22:06:56 +000046using namespace TosaReference;
Eric Kunzee5e26762020-10-13 16:11:07 -070047
// Population count: returns how many bits of val are set.
// Uses Brian Kernighan's trick: each `val &= val - 1` clears the lowest set bit,
// so the loop runs once per set bit rather than once per bit position.
inline size_t _count_one(uint64_t val)
{
    size_t ones = 0;
    while (val != 0)
    {
        val &= val - 1; // drop the lowest set bit
        ++ones;
    }
    return ones;
}
57
// Floor of log2(val): the index of the highest set bit, found by counting how many
// right shifts it takes to reach zero. Returns 0 for both val == 0 and val == 1.
// The shifting is done on the unsigned bit pattern of val. Right-shifting a negative
// signed value sign-extends, so it settles at -1 and a signed loop would never reach
// zero. A negative input is therefore measured as its two's-complement pattern
// (e.g. int32_t(-1) -> 31).
template <typename T>
inline size_t _integer_log2(T val)
{
    using U = typename std::make_unsigned<T>::type;
    U bits = static_cast<U>(val);
    size_t result = 0;
    while (bits >>= 1)
    {
        ++result;
    }
    return result;
}
68
// Number of consecutive 0 bits in val, counting down from the most significant bit of
// T's full width (sizeof(T) * 8). Returns that full width when val == 0.
// The bits are examined through T's unsigned counterpart. Left-shifting a negative
// signed value (as `val << i` did) is undefined behaviour before C++20, and a single
// probe bit walking downwards needs no shifting of val at all.
template <typename T>
inline size_t _count_leading_zeros(T val)
{
    using U = typename std::make_unsigned<T>::type;
    constexpr size_t width = sizeof(T) * 8;
    const U bits = static_cast<U>(val);
    size_t count = 0;
    for (U probe = static_cast<U>(static_cast<U>(1) << (width - 1)); probe != 0; probe = static_cast<U>(probe >> 1))
    {
        if (bits & probe)
            break;
        ++count;
    }
    return count;
}
84
// Number of consecutive 1 bits in val, counting down from the most significant bit of
// T's full width (sizeof(T) * 8). For a negative signed value this counts the run of
// sign bits.
// The bits are examined through T's unsigned counterpart. Left-shifting a negative
// signed value (as `val << i` did) is undefined behaviour before C++20, and that is
// exactly the input this function is most often given.
template <typename T>
inline size_t _count_leading_ones(T val)
{
    using U = typename std::make_unsigned<T>::type;
    constexpr size_t width = sizeof(T) * 8;
    const U bits = static_cast<U>(val);
    size_t count = 0;
    for (U probe = static_cast<U>(static_cast<U>(1) << (width - 1)); probe != 0; probe = static_cast<U>(probe >> 1))
    {
        if (!(bits & probe))
            break;
        ++count;
    }
    return count;
}
100
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Compute ceiling of (a/b)
#define DIV_CEIL(a, b) ((a) % (b) ? ((a) / (b) + 1) : ((a) / (b)))

// Returns a mask of 1's of this size
#define ONES_MASK(SIZE) ((uint64_t)((SIZE) >= 64 ? UINT64_C(0xffffffffffffffff) : (UINT64_C(1) << (SIZE)) - 1))

// Returns the field of bits from HIGH_BIT down to LOW_BIT (both inclusive), right-shifted
// to bit 0; equivalent to VAL[HIGH_BIT:LOW_BIT] in Verilog

#define BIT_FIELD(HIGH_BIT, LOW_BIT, VAL) (((uint64_t)(VAL) >> (LOW_BIT)) & ONES_MASK((HIGH_BIT) + 1 - (LOW_BIT)))

// Returns a bit at a particular position
#define BIT_EXTRACT(POS, VAL) (((uint64_t)(VAL) >> (POS)) & (1))

// Count of set bits in VAL after conversion to uint64_t, using Brian Kernighan's way:
// https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan
// Does this need to support floating point type?
// Not sure if static_cast is the right thing to do, try to be type safe first
#define ONES_COUNT(VAL) (_count_one((uint64_t)(VAL)))

// Shift VAL left by SHF when SHF > 0, right by -SHF when SHF < 0, unchanged when SHF == 0.
// Every use of SHF is parenthesized: previously the inner test was written `SHF < 0`, so an
// argument such as `a & b` was parsed as `a & (b < 0)` and selected the wrong branch.
#define SHIFT(SHF, VAL) (((SHF) > 0) ? ((VAL) << (SHF)) : (((SHF) < 0) ? ((VAL) >> (-(SHF))) : (VAL)))
// Round A up (ROUNDTO) or down (ROUNDTOLOWER) to a multiple of B, for non-negative A
#define ROUNDTO(A, B) ((A) % (B) == 0 ? (A) : ((A) / (B) + 1) * (B))
#define ROUNDTOLOWER(A, B) (((A) / (B)) * (B))
// Shift VAL left by SHIFT when SHIFT >= 0, otherwise right by -SHIFT
#define BIDIRECTIONAL_SHIFT(VAL, SHIFT) (((SHIFT) >= 0) ? ((VAL) << (SHIFT)) : ((VAL) >> (-(SHIFT))))
// Floor of log2(VAL) (0 when VAL is 0 or 1)
#define ILOG2(VAL) (_integer_log2(VAL))
127
// Get negative value (2's complement), truncated to the named unsigned width
#define NEGATIVE_8(VAL) ((uint8_t)(~(VAL) + 1))
#define NEGATIVE_16(VAL) ((uint16_t)(~(VAL) + 1))
#define NEGATIVE_32(VAL) ((uint32_t)(~(VAL) + 1))
#define NEGATIVE_64(VAL) ((uint64_t)(~(VAL) + 1))
// Convert a bit quantity to the minimum bytes required to hold those bits
#define BITS_TO_BYTES(BITS) (ROUNDTO((BITS), 8) / 8)

// Count leading zeros/ones for 8/16/32/64-bit operands
// (I don't see an obvious way to collapse this into a size-independent set)
// The sized variants reinterpret VAL as the unsigned type of that width;
// LEADING_ZEROS / LEADING_ONES count over the full width of VAL's own type.
#define LEADING_ZEROS_64(VAL) (_count_leading_zeros((uint64_t)(VAL)))
#define LEADING_ZEROS_32(VAL) (_count_leading_zeros((uint32_t)(VAL)))
#define LEADING_ZEROS_16(VAL) (_count_leading_zeros((uint16_t)(VAL)))
#define LEADING_ZEROS_8(VAL) (_count_leading_zeros((uint8_t)(VAL)))
#define LEADING_ZEROS(VAL) (_count_leading_zeros(VAL))

#define LEADING_ONES_64(VAL) (_count_leading_ones((uint64_t)(VAL)))
#define LEADING_ONES_32(VAL) (_count_leading_ones((uint32_t)(VAL)))
#define LEADING_ONES_16(VAL) (_count_leading_ones((uint16_t)(VAL)))
#define LEADING_ONES_8(VAL) (_count_leading_ones((uint8_t)(VAL)))
#define LEADING_ONES(VAL) (_count_leading_ones(VAL))
// Math operations
// Sign-extended for the signed versions.
// Extends to the different return types (8, 16, 32) x (S, U).
// Saturate a value at a certain bit width; signed and unsigned versions.
// The naming format is SATURATE_VAL_{saturation_sign}_{return_type}, for example:
//   SATURATE_VAL_U_8U(8, 300) returns a uint8_t with value 255 (0xff)
//   SATURATE_VAL_S_32S(5, -48) returns an int32_t with value -16 (5-bit pattern 0x10)
// Note that a negative result can be cast to an unsigned return type using the native integer cast,
// so SATURATE_VAL_S_8U(5, -40) saturates to -16, i.e. 0b11110000, which is 240 as a uint8_t.
// NOTE(review): no SATURATE_VAL_* macros are defined in this header; the saturate<T>() template
// below clamps a value to the range of a width-bit T.
160
161template <typename T>
162constexpr T bitmask(const uint32_t width)
163{
164 ASSERT(width <= sizeof(T) * 8);
165 return width == sizeof(T) * 8 ? static_cast<T>(std::numeric_limits<uintmax_t>::max())
166 : (static_cast<T>(1) << width) - 1;
167}
168
169template <typename T>
170constexpr T minval(const uint32_t width)
171{
172 ASSERT(width <= sizeof(T) * 8);
173 return std::is_signed<T>::value ? -(static_cast<T>(1) << (width - 1)) : 0;
174}
175
176template <typename T>
177constexpr T maxval(const uint32_t width)
178{
179 ASSERT(width <= sizeof(T) * 8);
180 return bitmask<T>(width - std::is_signed<T>::value);
181}
182
183template <typename T>
184constexpr T saturate(const uint32_t width, const intmax_t value)
185{
186 // clang-format off
187 return static_cast<T>(
188 std::min(
189 std::max(
190 value,
191 static_cast<intmax_t>(minval<T>(width))
192 ),
193 static_cast<intmax_t>(maxval<T>(width))
194 )
195 );
196 // clang-format on
197}
198
inline void float_trunc_bytes(float* src)
{
    /* Set the least significant two bytes to zero for the input float value.
       The bits are copied out to a uint32_t and back with memcpy. Reading a float through
       a uint32_t* (as the former reinterpret_cast did) violates the strict-aliasing rule. */
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");
    uint32_t bits;
    std::memcpy(&bits, src, sizeof(bits));
    bits &= UINT32_C(0xffff0000);
    std::memcpy(src, &bits, sizeof(bits));
}
205
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000206inline void truncateFloatToBFloat(float* src, int64_t size)
207{
James Ward24dbc422022-10-19 12:20:31 +0100208 /* Set the least significant two bytes to zero for each float
209 value in the input src buffer. */
210 ASSERT_MEM(src);
211 ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
212 for (; size != 0; src++, size--)
213 {
214 float_trunc_bytes(src);
215 }
216}
217
inline bool checkValidBFloat(float src)
{
    /* Checks if the least significant two bytes are zero.
       The float's bits are copied into a uint32_t with memcpy. Reading them through a
       uint32_t* (as the former reinterpret_cast did) violates the strict-aliasing rule. */
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");
    uint32_t bits;
    std::memcpy(&bits, &src, sizeof(bits));
    return (bits & UINT32_C(0x0000ffff)) == 0;
}
224
Tai Lya4d748b2023-03-28 22:06:56 +0000225template <TOSA_REF_TYPE Dtype>
James Ward24dbc422022-10-19 12:20:31 +0100226float fpTrunc(float f_in)
227{
Tai Lya4d748b2023-03-28 22:06:56 +0000228 /* Truncates a float value based on the TOSA_REF_TYPE it represents.*/
James Ward24dbc422022-10-19 12:20:31 +0100229 switch (Dtype)
230 {
Tai Lya4d748b2023-03-28 22:06:56 +0000231 case TOSA_REF_TYPE_BF16:
James Ward24dbc422022-10-19 12:20:31 +0100232 truncateFloatToBFloat(&f_in, 1);
233 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000234 case TOSA_REF_TYPE_FP16:
James Wardee256692022-11-15 11:36:47 +0000235 // Cast to temporary float16 value before casting back to float32
236 {
237 half_float::half h = half_float::half_cast<half_float::half, float>(f_in);
238 f_in = half_float::half_cast<float, half_float::half>(h);
239 break;
240 }
Tai Lya4d748b2023-03-28 22:06:56 +0000241 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100242 // No-op for fp32
243 break;
244 default:
Won Jeon2c34b462024-02-06 18:37:00 +0000245 ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-cast.", EnumNameTOSAREFTYPE(Dtype));
James Ward24dbc422022-10-19 12:20:31 +0100246 }
247 return f_in;
248}
249
Eric Kunzee5e26762020-10-13 16:11:07 -0700250#endif /* _ARITH_UTIL_H */