Take into account of `output_unsigned` in rescale operation

Set QMin and QMax based on the value of attribute `output_unsigned`.

Change-Id: I7f21f3edd7311295285fb3988b3c800de114777a
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
index 3e8bdf5..a283f39 100644
--- a/reference_model/include/dtype.h
+++ b/reference_model/include/dtype.h
@@ -145,6 +145,35 @@
     return TOSA_REF_TYPE_UNKNOWN;
 }
 
+template <TOSA_REF_TYPE Dtype>
+bool IsSignedInt()
+{
+    switch (Dtype)
+    {
+        case TOSA_REF_TYPE_INT4:
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
+        case TOSA_REF_TYPE_INT32:
+        case TOSA_REF_TYPE_INT48:
+            return true;
+
+        case TOSA_REF_TYPE_UINT8:
+        case TOSA_REF_TYPE_UINT16:
+            return false;
+
+        case TOSA_REF_TYPE_BOOL:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_SHAPE:
+        case TOSA_REF_TYPE_FP8E4M3:
+        case TOSA_REF_TYPE_FP8E5M2:
+        default:
+            FATAL_ERROR("dtype is not an integer type");
+            break;
+    }
+}
+
 };    // namespace TosaReference
 
 #endif
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index f0d184c..fee9fef 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -22,6 +22,7 @@
  *      fix point arithmetic
  *      fp16 type conversion(in binary translation)
  *      fp16 arithmetic (disguised with fp32 now)
+ *    and include the arithmetic helpers listed in Section 4.3.1. of the spec
  */
 
 #ifndef ARITH_UTIL_H
@@ -35,6 +36,7 @@
 #include "func_debug.h"
 #include "half.hpp"
 #include "inttypes.h"
+#include "ops/template_types.h"
 #include <bitset>
 #include <cassert>
 #include <limits>
@@ -247,4 +249,72 @@
     return f_in;
 }
 
+// return the maximum value when interpreting type T as a signed value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getSignedMaximum()
+{
+    if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+        return GetQMax<TOSA_REF_TYPE_INT8>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+        return GetQMax<TOSA_REF_TYPE_INT16>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT32)
+        return GetQMax<TOSA_REF_TYPE_INT32>::value;
+
+    FATAL_ERROR("Get maximum_s for the dtype input is not supported");
+    return 0;
+}
+
+// return the minimum value when interpreting type T as a signed value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getSignedMinimum()
+{
+    if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+        return GetQMin<TOSA_REF_TYPE_INT8>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+        return GetQMin<TOSA_REF_TYPE_INT16>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT32)
+        return GetQMin<TOSA_REF_TYPE_INT32>::value;
+
+    FATAL_ERROR("Get minimum_s for the dtype input is not supported");
+    return 0;
+}
+
+// return the maximum value when interpreting type T as an unsigned value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getUnsignedMaximum()
+{
+    if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+        return GetQMax<TOSA_REF_TYPE_UINT8>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+        return GetQMax<TOSA_REF_TYPE_UINT16>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT32)
+        return std::numeric_limits<uint32_t>::max();
+
+    FATAL_ERROR("Get maximum_u for the dtype input is not supported");
+    return 0;
+}
+
+// return the minimum value when interpreting type T as an unsigned value.
+template <TOSA_REF_TYPE Dtype>
+int32_t getUnsignedMinimum()
+{
+    if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8)
+        return GetQMin<TOSA_REF_TYPE_UINT8>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16)
+        return GetQMin<TOSA_REF_TYPE_UINT16>::value;
+
+    if (Dtype == TOSA_REF_TYPE_INT32)
+        return std::numeric_limits<uint32_t>::min();
+
+    FATAL_ERROR("Get minimum_u for the dtype input is not supported");
+    return 0;
+}
+
 #endif /* _ARITH_UTIL_H */
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index d58cfeb..835b656 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -148,6 +148,9 @@
     bool input_unsigned  = attribute->input_unsigned();
     bool output_unsigned = attribute->output_unsigned();
 
+    int32_t QMin = output_unsigned ? getUnsignedMinimum<OutDtype>() : getSignedMinimum<OutDtype>();
+    int32_t QMax = output_unsigned ? getUnsignedMaximum<OutDtype>() : getSignedMaximum<OutDtype>();
+
     // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
     Eigen::array<Eigen::Index, 2> shape_2d;
     shape_2d[0] = 1;
@@ -200,13 +203,12 @@
         {
             for (int32_t i = 0; i < shape_2d[1]; i++)
             {
-                begin                         = Eigen::array<Eigen::Index, 2>({ 0, i });
-                curr_channel_slice_prescaled  = input_reshaped.slice(begin, size);
-                channel_multiplier            = multiplier[i];
-                channel_shift                 = shift[i];
-                curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr(
-                    [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned,
-                     output_unsigned](InEigenType in_val) -> OutEigenType {
+                begin                        = Eigen::array<Eigen::Index, 2>({ 0, i });
+                curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+                channel_multiplier           = multiplier[i];
+                channel_shift                = shift[i];
+                curr_channel_slice_postscaled =
+                    curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType {
                         int64_t input_zp_shifted;
                         if (input_unsigned)
                         {
@@ -293,78 +295,79 @@
         int32_t tensor_shift      = shift[0];
         try
         {
-            output_2d =
-                input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32,
-                                          input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType {
-                    int64_t input_zp_shifted;
-                    if (input_unsigned)
+            output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType {
+                int64_t input_zp_shifted;
+                if (input_unsigned)
+                {
+                    int64_t in_val64;
+                    int64_t in_zp64;
+                    switch (GetNumBits<InDtype>::value)
                     {
-                        int64_t in_val64;
-                        int64_t in_zp64;
-                        switch (GetNumBits<InDtype>::value)
-                        {
-                            case 8:
-                                in_val64 = zero_extend(static_cast<int8_t>(in_val));
-                                in_zp64  = zero_extend(static_cast<int8_t>(input_zp));
-                                break;
-                            case 16:
-                                in_val64 = zero_extend(static_cast<int16_t>(in_val));
-                                in_zp64  = zero_extend(static_cast<int16_t>(input_zp));
-                                break;
-                            default:
-                                in_val64 = static_cast<int64_t>(in_val);
-                                in_zp64  = static_cast<int64_t>(input_zp);
-                                break;
-                        }
-                        input_zp_shifted = in_val64 - in_zp64;
+                        case 8:
+                            in_val64 = zero_extend(static_cast<int8_t>(in_val));
+                            in_zp64  = zero_extend(static_cast<int8_t>(input_zp));
+                            break;
+                        case 16:
+                            in_val64 = zero_extend(static_cast<int16_t>(in_val));
+                            in_zp64  = zero_extend(static_cast<int16_t>(input_zp));
+                            break;
+                        default:
+                            in_val64 = static_cast<int64_t>(in_val);
+                            in_zp64  = static_cast<int64_t>(input_zp);
+                            break;
                     }
-                    else
-                    {
-                        input_zp_shifted = in_val - input_zp;
-                    }
-                    int32_t scaled;
-                    if (scale32)
-                        scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
-                                                                          tensor_shift, double_round);
-                    else
-                        scaled =
-                            TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+                    input_zp_shifted = in_val64 - in_zp64;
+                }
+                else
+                {
+                    input_zp_shifted = in_val - input_zp;
+                }
+                int32_t scaled;
+                if (scale32)
+                    scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+                                                                      double_round);
+                else
+                    scaled =
+                        TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
 
-                    int64_t output_zp_extended;
-                    if (output_unsigned)
+                int64_t output_zp_extended;
+                if (output_unsigned)
+                {
+                    switch (GetNumBits<OutDtype>::value)
                     {
-                        switch (GetNumBits<OutDtype>::value)
-                        {
-                            case 8:
-                                output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
-                                break;
-                            case 16:
-                                output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
-                                break;
-                            default:
-                                output_zp_extended = static_cast<int64_t>(output_zp);
-                                break;
-                        }
+                        case 8:
+                            output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
+                            break;
+                        case 16:
+                            output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
+                            break;
+                        default:
+                            output_zp_extended = static_cast<int64_t>(output_zp);
+                            break;
                     }
-                    else
-                    {
-                        output_zp_extended = static_cast<int64_t>(output_zp);
-                    }
-                    int64_t res_in_64     = static_cast<int64_t>(scaled) + output_zp_extended;
-                    int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
-                    int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
-                    if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
-                    {
-                        std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
-                                           std::to_string(output_zp) + "] not in i32 range";
-                        throw desc;
-                    }
+                }
+                else
+                {
+                    output_zp_extended = static_cast<int64_t>(output_zp);
+                }
+                int64_t res_in_64     = static_cast<int64_t>(scaled) + output_zp_extended;
+                int64_t i32_max_in_64 = IsSignedInt<OutDtype>()
+                                            ? static_cast<int64_t>(std::numeric_limits<int32_t>::max())
+                                            : static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
+                int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
 
-                    OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
-                    out_val              = std::max<OutEigenType>(out_val, QMin);
-                    out_val              = std::min<OutEigenType>(out_val, QMax);
-                    return out_val;
-                });
+                if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
+                {
+                    std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
+                                       std::to_string(output_zp) + "] not in i32 range";
+                    throw desc;
+                }
+
+                OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
+                out_val              = std::max<OutEigenType>(out_val, QMin);
+                out_val              = std::min<OutEigenType>(out_val, QMax);
+                return out_val;
+            });
         }
         catch (std::string desc)
         {
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index a06dccc..da5537e 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -43,9 +43,6 @@
     using TMultiplierI32 = Eigen::Tensor<I32EigenType, 1>;
     using TShift         = Eigen::Tensor<I8EigenType, 1>;
 
-    static constexpr int32_t QMin = GetQMin<OutDtype>::value;
-    static constexpr int32_t QMax = GetQMax<OutDtype>::value;
-
 protected:
     TosaRescaleAttribute* attribute;
     TosaReference::TensorTemplate<TIn>* in;