blob: a50ab4362f6f596014fee852ae0ecf5e246b895a [file] [log] [blame]
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
// (C) COPYRIGHT 2020-2024 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
// by a licensing agreement from ARM Limited.
// Signed addition of two in_t values.
// Floating-point types add directly. Integer types add in a 64-bit signed
// accumulator, and the sum must lie within the signed range of in_t.
in_t apply_add_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        int64_t sum = sign_extend<int64_t>(a) + sign_extend<int64_t>(b);
        // The widened sum must be representable in signed in_t.
        REQUIRE(sum >= minimum_s<in_t>() && sum <= maximum_s<in_t>());
        return static_cast<in_t>(sum);
    }
    return a + b;
}
// Unsigned addition of two in_t values.
// Floating-point types add directly. Integer types are zero-extended to
// 64 bits, and the sum must fit the unsigned range checked via in_u_t.
// NOTE(review): in_u_t is presumably the unsigned counterpart of in_t,
// defined outside this section. Confirm.
in_t apply_add_u<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        uint64_t sum = zero_extend<uint64_t>(a) + zero_extend<uint64_t>(b);
        REQUIRE(sum >= minimum_u<in_u_t>() && sum <= maximum_u<in_u_t>());
        return truncate<in_t>(sum);
    }
    return a + b;
}
// Arithmetic (sign-propagating) right shift of a by b bit positions.
// Both operands are sign-extended to int32_t before shifting.
// The shift count must lie in [0, 31]. Shifting an int32_t by a negative
// count or by 32 or more is undefined in C/C++, so the count is checked here
// rather than relying only on callers.
in_t apply_arith_rshift<in_t>(in_t a, in_t b) {
    int32_t shift = sign_extend<int32_t>(b);
    REQUIRE(shift >= 0 && shift <= 31);
    int32_t c = sign_extend<int32_t>(a) >> shift;
    return static_cast<in_t>(c);
}
// Signed integer division a / b, computed in 64 bits.
// C++ integer division truncates toward zero.
// The divisor must be non-zero: division by zero is undefined, so this
// helper states the precondition itself instead of relying on callers.
// The quotient must fit in signed in_t. The one overflowing case for
// integer division is minimum_s<in_t>() / -1.
in_t apply_intdiv_s<in_t>(in_t a, in_t b) {
    int64_t divisor = sign_extend<int64_t>(b);
    REQUIRE(divisor != 0);
    int64_t c = sign_extend<int64_t>(a) / divisor;
    REQUIRE(c >= minimum_s<in_t>() && c <= maximum_s<in_t>());
    return static_cast<in_t>(c);
}
// Return the input value rounded up to the nearest integer.
// Declaration only; no body is given in this section.
in_t apply_ceil<in_t>(in_t input);
// Clamp value to the closed range [min_val, max_val] using signed ordering.
// The bounds must be ordered (min_val <= max_val). Integer bounds are compared
// after sign-extension to 64 bits. The clamp is apply_max_s against the
// lower bound, then apply_min_s against the upper bound.
in_t apply_clip_s<in_t>(in_t value, in_t min_val, in_t max_val) {
    bool bounds_ordered = is_floating_point<in_t>()
        ? (min_val <= max_val)
        : (sign_extend<int64_t>(min_val) <= sign_extend<int64_t>(max_val));
    REQUIRE(bounds_ordered);
    in_t raised = apply_max_s<in_t>(value, min_val);
    return apply_min_s<in_t>(raised, max_val);
}
// Clamp value to the closed range [min_val, max_val] using unsigned ordering.
// The bounds are compared after zero-extension to int64_t and must be ordered.
in_t apply_clip_u<in_t>(in_t value, in_t min_val, in_t max_val) {
    int64_t lower = zero_extend<int64_t>(min_val);
    int64_t upper = zero_extend<int64_t>(max_val);
    REQUIRE(lower <= upper);
    // Clamp from below first, then from above.
    return apply_min_u<in_t>(apply_max_u<in_t>(value, min_val), max_val);
}
// Return e (Euler's number) raised to the power input.
// Declaration only; no body is given in this section.
in_t apply_exp<in_t>(in_t input);
// Return the input value rounded down to the nearest integer.
// Declaration only; no body is given in this section.
in_t apply_floor<in_t>(in_t input);
// Return the natural logarithm of input.
// Intended for positive inputs only: apply_log handles zero and negative
// inputs itself before delegating here. Declaration only; no body is given
// in this section.
in_t apply_log_positive_input<in_t>(in_t input);
// Natural logarithm of input, with special inputs handled before the core
// computation:
//   NaN      -> NaN (propagated, matching apply_max_s / apply_min_s)
//   zero     -> -INFINITY
//   negative -> NaN
// Only positive inputs reach apply_log_positive_input. A NaN fails both
// "== 0" and "< 0", so it must be tested explicitly to keep it out.
in_t apply_log<in_t>(in_t input) {
    if (isNaN(input)) {
        return NaN;
    }
    if (input == 0) {
        return -INFINITY;
    }
    if (input < 0) {
        return NaN;
    }
    return apply_log_positive_input<in_t>(input);
}
// Logical (zero-filling) right shift of a by b bit positions.
// Both operands are zero-extended to uint32_t before shifting, so the
// result fits in uint32_t. The temporary uses that type rather than a
// wider one.
// The shift count must be at most 31. Shifting a uint32_t by 32 or more is
// undefined in C/C++. The count is unsigned after zero-extension, so it
// cannot be negative.
in_t apply_logical_rshift<in_t>(in_t a, in_t b) {
    uint32_t shift = zero_extend<uint32_t>(b);
    REQUIRE(shift <= 31);
    uint32_t c = zero_extend<uint32_t>(a) >> shift;
    return static_cast<in_t>(c);
}
// Signed maximum of a and b.
// Floating point: NaN in either operand yields NaN; otherwise a is chosen
// when a >= b, else b. Integers: compared after sign-extension to 64 bits,
// and a is chosen on ties.
in_t apply_max_s<in_t>(in_t a, in_t b) {
    bool choose_a;
    if (is_floating_point<in_t>()) {
        if (isNaN(a) || isNaN(b)) {
            return NaN;
        }
        choose_a = (a >= b);
    } else {
        choose_a = (sign_extend<int64_t>(a) >= sign_extend<int64_t>(b));
    }
    return choose_a ? a : b;
}
// Unsigned maximum of a and b; a is chosen on ties.
// Both operands are zero-extended to the same type, int64_t, matching
// apply_min_u and apply_clip_u. Previously one side was widened to uint64_t
// and the other to int64_t, which made a mixed signed/unsigned comparison.
in_t apply_max_u<in_t>(in_t a, in_t b) {
    if (zero_extend<int64_t>(a) >= zero_extend<int64_t>(b)) return a; else return b;
}
// Signed minimum of a and b.
// Floating point: NaN in either operand yields NaN; otherwise a is chosen
// when a < b, else b. Integers: compared after sign-extension to 64 bits,
// and b is chosen on ties.
in_t apply_min_s<in_t>(in_t a, in_t b) {
    bool choose_a;
    if (is_floating_point<in_t>()) {
        if (isNaN(a) || isNaN(b)) {
            return NaN;
        }
        choose_a = (a < b);
    } else {
        choose_a = (sign_extend<int64_t>(a) < sign_extend<int64_t>(b));
    }
    return choose_a ? a : b;
}
// Unsigned minimum of a and b, compared after zero-extension to int64_t.
// b is chosen on ties.
in_t apply_min_u<in_t>(in_t a, in_t b) {
    int64_t wide_a = zero_extend<int64_t>(a);
    int64_t wide_b = zero_extend<int64_t>(b);
    return (wide_a < wide_b) ? a : b;
}
// Signed multiplication of two in_t values.
// Floating-point types multiply directly. Integer types multiply in a 64-bit
// signed accumulator, and the product is narrowed back to in_t via
// static_cast. Unlike apply_add_s, there is no range REQUIRE here.
in_t apply_mul_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        int64_t product = sign_extend<int64_t>(a) * sign_extend<int64_t>(b);
        return static_cast<in_t>(product);
    }
    return a * b;
}
// Return a raised to the power b.
// "**" is this pseudocode's exponentiation operator, not C++ syntax.
in_t apply_pow<in_t>(in_t a, in_t b) {
    in_t result = a ** b;
    return result;
}
// Return the square root of input.
// Declaration only; no body is given in this section.
in_t apply_sqrt<in_t>(in_t input);
// Signed subtraction a - b.
// Floating-point types subtract directly. Integer types subtract in a 64-bit
// signed accumulator, and the difference must lie within the signed range
// of in_t.
in_t apply_sub_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        int64_t diff = sign_extend<int64_t>(a) - sign_extend<int64_t>(b);
        REQUIRE(diff >= minimum_s<in_t>() && diff <= maximum_s<in_t>());
        return static_cast<in_t>(diff);
    }
    return a - b;
}
// Unsigned subtraction a - b on zero-extended 64-bit operands.
// When b > a, the uint64_t subtraction wraps modulo 2^64. The resulting
// large value is what the range REQUIRE below is presumably meant to reject
// for narrow in_t. NOTE(review): in_u_t is presumably the unsigned
// counterpart of in_t, defined outside this section.
in_t apply_sub_u<in_t>(in_t a, in_t b) {
    uint64_t lhs = zero_extend<uint64_t>(a);
    uint64_t rhs = zero_extend<uint64_t>(b);
    uint64_t diff = lhs - rhs;
    REQUIRE(diff >= minimum_u<in_u_t>() && diff <= maximum_u<in_u_t>());
    return truncate<in_t>(diff);
}
// Count the zero bits above the most significant set bit of a, viewed as a
// 32-bit pattern. Returns 32 for a == 0 and 0 for any negative a (top bit
// set).
int32_t count_leading_zeros(int32_t a) {
    int32_t acc = 32;  // every bit is zero when a == 0
    if (a != 0) {
        // Start the probe at the top bit. The unsigned literal keeps the
        // shift from moving a 1 into the sign bit of a signed int, which is
        // undefined in C and before C++14.
        uint32_t mask = 1u << (32 - 1); // width of int32_t - 1
        const uint32_t bits = static_cast<uint32_t>(a);
        acc = 0;
        // Terminates: a != 0, so some bit of `bits` is set.
        while ((mask & bits) == 0) {
            mask = mask >> 1;
            acc = acc + 1;
        }
    }
    return acc;
}