Support 16-bit Rescale

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ifc80b83c1abcd08e1b7f8e50f647b74c861bc933
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 3a610ea..d988c57 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -71,9 +71,9 @@
     int32_t output_zp               = attribute->output_zp();
     std::vector<int32_t> multiplier = attribute->multiplier();
     std::vector<int32_t> shift      = attribute->shift();
-    //bool scale32 = attribute->scale32();
-    bool double_round = attribute->double_round();
-    bool per_channel  = attribute->per_channel();
+    bool scale32                    = attribute->scale32();
+    bool double_round               = attribute->double_round();
+    bool per_channel                = attribute->per_channel();
 
     // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
     Eigen::array<Eigen::Index, 2> shape_2d;
@@ -94,7 +94,6 @@
 
     ETensor2<OutEigenType> output_2d(shape_2d);
 
-    // TODO: pass scale32 in when 16-bit mode implemented
     if (per_channel)
     {
         ETensor2<InEigenType> curr_channel_slice_prescaled;
@@ -110,10 +109,15 @@
             channel_shift                = shift[i];
             curr_channel_slice_postscaled =
                 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
-                                                        double_round](InEigenType in_val) -> OutEigenType {
+                                                        double_round, scale32](InEigenType in_val) -> OutEigenType {
                     InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
-                    int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
-                                                                              channel_shift, double_round);
+                    int32_t scaled;
+                    if (scale32)
+                        scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
+                                                                          channel_shift, double_round);
+                    else
+                        scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
+                                                                          channel_shift);
                     OutEigenType out_val = (OutEigenType)(scaled + output_zp);
                     out_val              = std::max<OutEigenType>(out_val, QMin);
                     out_val              = std::min<OutEigenType>(out_val, QMax);
@@ -130,16 +134,20 @@
     {
         int32_t tensor_multiplier = multiplier[0];
         int32_t tensor_shift      = shift[0];
-        output_2d                 = input_reshaped.unaryExpr(
-            [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
-                InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
-                int32_t scaled       = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
-                                                                          tensor_shift, double_round);
-                OutEigenType out_val = (OutEigenType)(scaled + output_zp);
-                out_val              = std::max<OutEigenType>(out_val, QMin);
-                out_val              = std::min<OutEigenType>(out_val, QMax);
-                return out_val;
-            });
+        output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
+                                              scale32](InEigenType in_val) -> OutEigenType {
+            InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+            int32_t scaled;
+            if (scale32)
+                scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+                                                                  double_round);
+            else
+                scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+            OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+            out_val              = std::max<OutEigenType>(out_val, QMin);
+            out_val              = std::min<OutEigenType>(out_val, QMax);
+            return out_val;
+        });
     }
 
     // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 1784493..f07dd10 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -61,6 +61,19 @@
                    "apply_scale_32() error: scaled result exceed int32 numeric range");
         return static_cast<int32_t>(result);
     }
+
+    static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift)
+    {
+        ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier);
+        ASSERT_MSG(value >= -(static_cast<int64_t>(1) << 47) && value < (static_cast<int64_t>(1) << 47),
+                   "apply_scale_16() error: value should be within [-(1^47), 1^47]");
+        int64_t round  = 1L << (shift - 1);
+        int64_t result = value * (int64_t)multiplier + round;
+        result         = result >> shift;
+        ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
+                   "apply_scale_16() error: scaled result exceed int32 numeric range");
+        return static_cast<int32_t>(result);
+    }
 };
 
 class TypeChecker
@@ -68,8 +81,8 @@
 public:
     static bool is_integer(DType dtype)
     {
-        if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 ||
-            dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48)
+        if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 ||
+            dtype == DType_INT32 || dtype == DType_INT48)
         {
             return true;
         }