blob: cab43ba00a6bd91177f41c6c091f34d070117d1f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Gecf305db2023-03-06 13:07:36 -08002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
/*
 * Filename: src/arith_util.h
 * Description:
 *   Arithmetic utility macros and helpers, including:
 *   - fp16 (float16_t) type alias
 *   - bitwise operations
 *   - fixed-point arithmetic
 *   - fp16 type conversion (in binary translation)
 *   - fp16 arithmetic (emulated with fp32 for now)
 */
26
27#ifndef ARITH_UTIL_H
28#define ARITH_UTIL_H
29
#include <fenv.h>
#include <math.h>
#define __STDC_LIMIT_MACROS //enable min/max of plain data type
#include "dtype.h"
#include "func_config.h"
#include "func_debug.h"
#include "half.hpp"
#include "inttypes.h"
#include <Eigen/Core>
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstring>
#include <iostream>
#include <limits>
#include <stdint.h>
#include <type_traits>
#include <typeinfo>
45
James Ward24dbc422022-10-19 12:20:31 +010046using namespace tosa;
Eric Kunzee5e26762020-10-13 16:11:07 -070047using namespace std;
Tai Lya4d748b2023-03-28 22:06:56 +000048using namespace TosaReference;
Eric Kunzee5e26762020-10-13 16:11:07 -070049
// Returns the number of set bits in val (population count).
// Each pass clears the lowest set bit (Kernighan's method), so the loop runs
// once per set bit rather than once per bit position.
inline size_t _count_one(uint64_t val)
{
    size_t bits_set = 0;
    while (val != 0)
    {
        val &= (val - 1);    // drop the lowest set bit
        ++bits_set;
    }
    return bits_set;
}
58}
59
// Returns floor(log2(val)) for val > 0, i.e. the index of the most significant
// set bit. Returns 0 for val == 0 as well as for val == 1.
//
// The value is shifted through its unsigned counterpart. Shifting a negative
// signed value right is (on the usual arithmetic-shift implementations) never
// able to reach zero, so the previous loop never terminated for negative input.
// A negative input is now measured by its two's-complement bit pattern,
// e.g. int32_t(-1) -> 31.
template <typename T>
inline size_t _integer_log2(T val)
{
    using UnsignedT = typename std::make_unsigned<T>::type;
    UnsignedT bits  = static_cast<UnsignedT>(val);
    size_t result   = 0;
    while (bits >>= 1)
    {
        ++result;
    }
    return result;
}
69}
70
// Counts the leading (most-significant) zero bits of val within the bit width
// of T. Returns the full bit width of T when val == 0.
//
// Bits are read from the unsigned counterpart of T, so negative signed inputs
// are handled by their bit pattern. The previous version left-shifted val
// itself, which is undefined behaviour for negative signed values before C++20.
template <typename T>
inline size_t _count_leading_zeros(T val)
{
    using UnsignedT      = typename std::make_unsigned<T>::type;
    const size_t width   = sizeof(T) * 8;
    const UnsignedT bits = static_cast<UnsignedT>(val);
    size_t count         = 0;
    // Walk from the MSB downwards and stop at the first set bit.
    for (size_t pos = width; pos-- > 0;)
    {
        if ((bits >> pos) & 1u)
            break;
        ++count;
    }
    return count;
}
85}
86
// Counts the leading (most-significant) one bits of val within the bit width
// of T. Returns 0 when the most significant bit is clear.
//
// Bits are read from the unsigned counterpart of T, so negative signed inputs
// (the main source of leading ones) are handled by their bit pattern. The
// previous version left-shifted val itself, which is undefined behaviour for
// negative signed values before C++20.
template <typename T>
inline size_t _count_leading_ones(T val)
{
    using UnsignedT      = typename std::make_unsigned<T>::type;
    const size_t width   = sizeof(T) * 8;
    const UnsignedT bits = static_cast<UnsignedT>(val);
    size_t count         = 0;
    // Walk from the MSB downwards and stop at the first clear bit.
    for (size_t pos = width; pos-- > 0;)
    {
        if (!((bits >> pos) & 1u))
            break;
        ++count;
    }
    return count;
}
101}
102
// Larger / smaller of two values. These are macros: the selected argument is
// evaluated twice, so do not pass expressions with side effects.
#define MAX(a, b) (!((a) > (b)) ? (b) : (a))
#define MIN(a, b) (!((a) < (b)) ? (b) : (a))
// Ceiling of (a / b) for non-negative a and positive b. With C++'s truncating
// division a negative quotient is not rounded up correctly
// (e.g. DIV_CEIL(-7, 2) == -2).
#define DIV_CEIL(a, b) ((((a) % (b)) == 0) ? ((a) / (b)) : ((a) / (b) + 1))
107
// Mask with the low SIZE bits set; SIZE >= 64 yields all 64 bits set.
#define ONES_MASK(SIZE) ((uint64_t)((SIZE) < 64 ? (UINT64_C(1) << (SIZE)) - 1 : UINT64_C(0xffffffffffffffff)))

// Bits HIGH_BIT down to LOW_BIT of VAL (both ends inclusive), shifted down so
// that LOW_BIT lands at bit 0 -- the Verilog slice VAL[HIGH_BIT:LOW_BIT].
#define BIT_FIELD(HIGH_BIT, LOW_BIT, VAL) ((((uint64_t)(VAL)) >> (LOW_BIT)) & ONES_MASK((HIGH_BIT) + 1 - (LOW_BIT)))

// The single bit at position POS of VAL (0 or 1).
#define BIT_EXTRACT(POS, VAL) ((((uint64_t)(VAL)) >> (POS)) & (1))
118
// Number of set bits in VAL after converting it to uint64_t.
// Uses Brian Kernighan's method:
// https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan
// TODO: decide whether floating-point VAL should be supported -- the cast
// converts the value, it does not reinterpret the float's bits.
#define ONES_COUNT(VAL) (_count_one((uint64_t)(VAL)))

// Shift VAL left by SHF when SHF > 0, right by -SHF when SHF < 0, and leave it
// unchanged when SHF == 0. Every use of SHF is parenthesised so that arguments
// containing low-precedence operators (e.g. `a & b`) are compared as a whole.
#define SHIFT(SHF, VAL) (((SHF) > 0) ? ((VAL) << (SHF)) : (((SHF) < 0) ? ((VAL) >> (-(SHF))) : (VAL)))
// Round A up / down to a multiple of B (intended for non-negative A, positive B).
#define ROUNDTO(A, B) ((A) % (B) == 0 ? (A) : ((A) / (B) + 1) * (B))
#define ROUNDTOLOWER(A, B) (((A) / (B)) * (B))
// Shift left when SHIFT >= 0, otherwise shift right by -SHIFT.
#define BIDIRECTIONAL_SHIFT(VAL, SHIFT) (((SHIFT) >= 0) ? ((VAL) << (SHIFT)) : ((VAL) >> (-(SHIFT))))
// floor(log2(VAL)); see _integer_log2.
#define ILOG2(VAL) (_integer_log2(VAL))
129
// Two's-complement negation of VAL (invert, then add one), truncated to an
// unsigned integer of the named width.
#define NEGATIVE_8(VAL) ((uint8_t)(1 + ~(VAL)))
#define NEGATIVE_16(VAL) ((uint16_t)(1 + ~(VAL)))
#define NEGATIVE_32(VAL) ((uint32_t)(1 + ~(VAL)))
#define NEGATIVE_64(VAL) ((uint64_t)(1 + ~(VAL)))
// Convert a bit quantity to the minimum number of bytes required to hold those bits.
#define BITS_TO_BYTES(BITS) (ROUNDTO((BITS), 8) / 8)
137
// Count the leading (most-significant) zero / one bits of VAL.
// The _8/_16/_32/_64 variants first convert VAL to the unsigned integer of that
// width, so the count is taken over exactly that many bits. The unsuffixed
// forms count over the bit width of VAL's own type.
#define LEADING_ZEROS_8(VAL) (_count_leading_zeros((uint8_t)(VAL)))
#define LEADING_ZEROS_16(VAL) (_count_leading_zeros((uint16_t)(VAL)))
#define LEADING_ZEROS_32(VAL) (_count_leading_zeros((uint32_t)(VAL)))
#define LEADING_ZEROS_64(VAL) (_count_leading_zeros((uint64_t)(VAL)))
#define LEADING_ZEROS(VAL) (_count_leading_zeros(VAL))

#define LEADING_ONES_8(VAL) (_count_leading_ones((uint8_t)(VAL)))
#define LEADING_ONES_16(VAL) (_count_leading_ones((uint16_t)(VAL)))
#define LEADING_ONES_32(VAL) (_count_leading_ones((uint32_t)(VAL)))
#define LEADING_ONES_64(VAL) (_count_leading_ones((uint64_t)(VAL)))
#define LEADING_ONES(VAL) (_count_leading_ones(VAL))
// Math operations
// Sign-extended for the signed versions.
// Extended to different return types (8, 16, 32) + (S, U).
// Saturate a value at a certain bit width, with signed and unsigned versions.
// The naming format is SATURATE_VAL_{saturation_sign}_{return_type}, for example:
//   SATURATE_VAL_U_8U(8, 300) returns uint8_t 255 (0xff)
//   SATURATE_VAL_S_32S(5, -48) returns int32_t -16 (0x10 as a 5-bit two's-complement pattern)
// Note that a negative result is converted to an unsigned return type with the native integer cast,
// so SATURATE_VAL_S_8U(5, -40) saturates to -16 = 0b11110000, which is 240 as a uint8_t.
// NOTE: the SATURATE_VAL_* macros are not defined in this header; saturate<T>() below implements the clamping.
162
163template <typename T>
164constexpr T bitmask(const uint32_t width)
165{
166 ASSERT(width <= sizeof(T) * 8);
167 return width == sizeof(T) * 8 ? static_cast<T>(std::numeric_limits<uintmax_t>::max())
168 : (static_cast<T>(1) << width) - 1;
169}
170
171template <typename T>
172constexpr T minval(const uint32_t width)
173{
174 ASSERT(width <= sizeof(T) * 8);
175 return std::is_signed<T>::value ? -(static_cast<T>(1) << (width - 1)) : 0;
176}
177
178template <typename T>
179constexpr T maxval(const uint32_t width)
180{
181 ASSERT(width <= sizeof(T) * 8);
182 return bitmask<T>(width - std::is_signed<T>::value);
183}
184
185template <typename T>
186constexpr T saturate(const uint32_t width, const intmax_t value)
187{
188 // clang-format off
189 return static_cast<T>(
190 std::min(
191 std::max(
192 value,
193 static_cast<intmax_t>(minval<T>(width))
194 ),
195 static_cast<intmax_t>(maxval<T>(width))
196 )
197 );
198 // clang-format on
199}
200
James Ward24dbc422022-10-19 12:20:31 +0100201inline void float_trunc_bytes(float* src)
202{
203 /* Set the least significant two bytes to zero for the input float value.*/
204 char src_as_bytes[sizeof(float)];
205 memcpy(src_as_bytes, src, sizeof(float));
206
207 if (g_func_config.float_is_big_endian)
208 {
209 src_as_bytes[2] = '\000';
210 src_as_bytes[3] = '\000';
211 }
212 else
213 {
214 src_as_bytes[0] = '\000';
215 src_as_bytes[1] = '\000';
216 }
217
218 memcpy(src, &src_as_bytes, sizeof(float));
219}
220
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000221inline void truncateFloatToBFloat(float* src, int64_t size)
222{
James Ward24dbc422022-10-19 12:20:31 +0100223 /* Set the least significant two bytes to zero for each float
224 value in the input src buffer. */
225 ASSERT_MEM(src);
226 ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
227 for (; size != 0; src++, size--)
228 {
229 float_trunc_bytes(src);
230 }
231}
232
233inline bool checkValidBFloat(float src)
234{
235 /* Checks if the least significant two bytes are zero. */
James Ward24dbc422022-10-19 12:20:31 +0100236 char src_as_bytes[sizeof(float)];
237 memcpy(src_as_bytes, &src, sizeof(float));
238
239 if (g_func_config.float_is_big_endian)
240 {
241 return (src_as_bytes[2] == '\000' && src_as_bytes[3] == '\000');
242 }
243 else
244 {
245 return (src_as_bytes[0] == '\000' && src_as_bytes[1] == '\000');
246 }
247}
248
inline bool float_is_big_endian()
{
    /* Detects the in-memory byte order of float at run time. Negating 1.0 flips
       only the sign bit, which lives in the most-significant byte. If byte 0
       is the one that changes, the most-significant byte is stored first,
       i.e. the representation is big-endian. */
    const float positive = 1.0f;
    const float negative = -positive;
    unsigned char positive_bytes[sizeof(float)];
    unsigned char negative_bytes[sizeof(float)];
    memcpy(positive_bytes, &positive, sizeof(float));
    memcpy(negative_bytes, &negative, sizeof(float));
    return positive_bytes[0] != negative_bytes[0];
}
263
Tai Lya4d748b2023-03-28 22:06:56 +0000264template <TOSA_REF_TYPE Dtype>
James Ward24dbc422022-10-19 12:20:31 +0100265float fpTrunc(float f_in)
266{
Tai Lya4d748b2023-03-28 22:06:56 +0000267 /* Truncates a float value based on the TOSA_REF_TYPE it represents.*/
James Ward24dbc422022-10-19 12:20:31 +0100268 switch (Dtype)
269 {
Tai Lya4d748b2023-03-28 22:06:56 +0000270 case TOSA_REF_TYPE_BF16:
James Ward24dbc422022-10-19 12:20:31 +0100271 truncateFloatToBFloat(&f_in, 1);
272 break;
Tai Lya4d748b2023-03-28 22:06:56 +0000273 case TOSA_REF_TYPE_FP16:
James Wardee256692022-11-15 11:36:47 +0000274 // Cast to temporary float16 value before casting back to float32
275 {
276 half_float::half h = half_float::half_cast<half_float::half, float>(f_in);
277 f_in = half_float::half_cast<float, half_float::half>(h);
278 break;
279 }
Tai Lya4d748b2023-03-28 22:06:56 +0000280 case TOSA_REF_TYPE_FP32:
James Ward24dbc422022-10-19 12:20:31 +0100281 // No-op for fp32
282 break;
283 default:
Tai Lya4d748b2023-03-28 22:06:56 +0000284 ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-truncated.", EnumNameTOSAREFTYPE(Dtype));
James Ward24dbc422022-10-19 12:20:31 +0100285 }
286 return f_in;
287}
288
Eric Kunzee5e26762020-10-13 16:11:07 -0700289#endif /* _ARITH_UTIL_H */