Replace templated apply_scale() with int32-only apply_scale_32()

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ida8e3a17d74e5d6379b2244896ddf9e295d0ecc9
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 3638b3b..3b58b66 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -26,18 +26,18 @@
 namespace TosaReference
 {
 
-template <DType AccDType>
 class QuantUtil
 {
 public:
-    using T = typename GetEigenType<AccDType>::type;
-
+    // Computes a fixed-point reciprocal of `value` (the AvgPool2d element count) as a
+    // (multiplier, shift) pair. Requires value > 0; shift = 30 + ceil(log2(value)).
     static void reciprocal_scale(int32_t value,
                                  // Output
                                  int32_t& multiplier,
                                  int32_t& shift)
     {
-        ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value);
+        ASSERT_MSG(value > 0,
+                   "AvgPool2d reciprocal_scale() error: # of elements should be > 0 but is %d", value);
         uint32_t value_u32 = (uint32_t)value;
         int32_t k          = 32 - LEADING_ZEROS_32(value_u32 - 1);    // (1<<k)/2 < value <= (1<<k)
         int64_t numerator  = ((1L << 30) + 1) << k;
@@ -45,33 +45,30 @@
         shift              = 30 + k;
     }
 
-    static int32_t apply_scale(T value, int32_t multiplier, int32_t shift, bool enabled_adjusted_rounding = true)
+    // Scales `value` by multiplier / 2^shift, rounding to nearest:
+    //   result = (value * multiplier + round) >> shift, with round = 1 << (shift - 1).
+    // When double_round is set and shift > 31, round is further increased by 1 << 30
+    // for value >= 0 and decreased by 1 << 30 for value < 0.
+    // Requires multiplier >= 0 and 2 <= shift <= 62; asserts the result fits in int32.
+    // Constants are written with 1LL because `long` is only 32 bits on LLP64 targets
+    // (e.g. Windows), where 1L << (shift - 1) would overflow for shift > 31.
+    static int32_t apply_scale_32(int32_t value, int32_t multiplier, int32_t shift, bool double_round = true)
     {
-        if (AccDType == DType_FLOAT)
+        ASSERT_MSG(multiplier >= 0, "apply_scale_32() error: multiplier should >= 0 but is %d", multiplier);
+        ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_32() error: shift should be within [2, 62] but is %d",
+                   shift);
+        int64_t round = 1LL << (shift - 1);
+        if (double_round)
         {
-            return value;
-        }
-        ASSERT_MSG(multiplier >= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier);
-        int64_t round = (shift > 0) ? (1L << (shift - 1)) : 0;
-        if (enabled_adjusted_rounding)
-        {
-            if (AccDType != DType_INT48)
-            {
-                if (shift > 31 && value >= 0)
-                    round += (1L << 30);
-                if (shift > 31 && value < 0)
-                    round -= (1L << 30);
-            }
-            else
-            {    // input data could be int16, which leads to 48 bits accumulator
-                ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode",
-                           (1 << 15));
-            }
+            if (shift > 31 && value >= 0)
+                round += (1LL << 30);
+            if (shift > 31 && value < 0)
+                round -= (1LL << 30);
         }
         int64_t result = (int64_t)value * multiplier + round;
         result         = result >> shift;
-        ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
-                   "apply_scale() error: scaled result exceed int32 numeric range");
+        ASSERT_MSG(result >= -(1LL << 31) && result < (1LL << 31),
+                   "apply_scale_32() error: scaled result exceed int32 numeric range");
         return static_cast<int32_t>(result);
     }
 };