[ref model] Change RescaleOp attrs to inputs

This patch implements changes required for RescaleOp's
multiplier and shift changing from attributes to inputs

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I178919727e3220c749dad0ebce141e695868fee0
diff --git a/examples/test_add_1x4x4x4_f32/model.pb b/examples/test_add_1x4x4x4_f32/model.pb
index f1af89b..026dcea 100644
--- a/examples/test_add_1x4x4x4_f32/model.pb
+++ b/examples/test_add_1x4x4x4_f32/model.pb
@@ -92,5 +92,5 @@
   }
 }
 versions {
-  producer: 1597
+  producer: 1581
 }
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index e3bc565..20e1333 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index 90e8e86..d55d5d6 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
index bf7ee3a..ce59e78 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/model.pb
@@ -137,5 +137,5 @@
   }
 }
 versions {
-  producer: 1597
+  producer: 1581
 }
diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
index e9fb643..87bafd1 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa b/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
index f9606a1..deaca6e 100644
--- a/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
+++ b/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
Binary files differ
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 5dbc7bd..d58cfeb 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -35,7 +35,7 @@
 OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_RESCALE, id_)
 {
-    setRequiredOperands(1, 1);
+    setRequiredOperands(3, 1);
     INIT_ATTRIBUTE(Rescale);
 }
 
@@ -68,6 +68,20 @@
 
     ASSERT_MEM(in && out);
 
+    multiplierI32 = dynamic_cast<TosaReference::TensorTemplate<TMultiplierI32>*>(inputs[1]);
+    multiplierI16 = dynamic_cast<TosaReference::TensorTemplate<TMultiplierI16>*>(inputs[1]);
+    shift         = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(inputs[2]);
+    ASSERT_MEM(shift);
+
+    if (attribute->scale32())
+    {
+        ASSERT_MEM(multiplierI32);
+    }
+    else
+    {
+        ASSERT_MEM(multiplierI16);
+    }
+
     if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
         (attribute->input_zp() != 0))
     {
@@ -124,15 +138,15 @@
 template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpRescale<Rank, InDtype, OutDtype>::eval()
 {
-    int32_t input_zp                = attribute->input_zp();
-    int32_t output_zp               = attribute->output_zp();
-    std::vector<int32_t> multiplier = attribute->multiplier();
-    std::vector<int32_t> shift      = attribute->shift();
-    bool scale32                    = attribute->scale32();
-    bool double_round               = attribute->double_round();
-    bool per_channel                = attribute->per_channel();
-    bool input_unsigned             = attribute->input_unsigned();
-    bool output_unsigned            = attribute->output_unsigned();
+    int32_t input_zp  = attribute->input_zp();
+    int32_t output_zp = attribute->output_zp();
+    std::vector<int32_t> multiplier;
+    std::vector<int32_t> shift;
+    bool scale32         = attribute->scale32();
+    bool double_round    = attribute->double_round();
+    bool per_channel     = attribute->per_channel();
+    bool input_unsigned  = attribute->input_unsigned();
+    bool output_unsigned = attribute->output_unsigned();
 
     // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
     Eigen::array<Eigen::Index, 2> shape_2d;
@@ -153,6 +167,28 @@
 
     ETensor2<OutEigenType> output_2d(shape_2d);
 
+    if (scale32)
+    {
+        auto multiplier_val = this->multiplierI32->getTensor();
+        for (int i = 0; i < multiplier_val.size(); i++)
+        {
+            multiplier.push_back(static_cast<int32_t>(multiplier_val(i)));
+        }
+    }
+    else
+    {
+        auto multiplier_val = this->multiplierI16->getTensor();
+        for (int i = 0; i < multiplier_val.size(); i++)
+        {
+            multiplier.push_back(static_cast<int32_t>(multiplier_val(i)));
+        }
+    }
+    auto shift_val = this->shift->getTensor();
+    for (int i = 0; i < shift_val.size(); i++)
+    {
+        shift.push_back(static_cast<int32_t>(shift_val(i)));
+    }
+
     if (per_channel)
     {
         ETensor2<InEigenType> curr_channel_slice_prescaled;
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 75f244d..a06dccc 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -32,10 +32,16 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    using InEigenType  = typename GetEigenType<InDtype>::type;
-    using OutEigenType = typename GetEigenType<OutDtype>::type;
-    using TIn          = Eigen::Tensor<InEigenType, Rank>;
-    using TOut         = Eigen::Tensor<OutEigenType, Rank>;
+    using InEigenType    = typename GetEigenType<InDtype>::type;
+    using OutEigenType   = typename GetEigenType<OutDtype>::type;
+    using TIn            = Eigen::Tensor<InEigenType, Rank>;
+    using TOut           = Eigen::Tensor<OutEigenType, Rank>;
+    using I8EigenType    = typename GetEigenType<TOSA_REF_TYPE::TOSA_REF_TYPE_INT8>::type;
+    using I16EigenType   = typename GetEigenType<TOSA_REF_TYPE::TOSA_REF_TYPE_INT16>::type;
+    using I32EigenType   = typename GetEigenType<TOSA_REF_TYPE::TOSA_REF_TYPE_INT32>::type;
+    using TMultiplierI16 = Eigen::Tensor<I16EigenType, 1>;
+    using TMultiplierI32 = Eigen::Tensor<I32EigenType, 1>;
+    using TShift         = Eigen::Tensor<I8EigenType, 1>;
 
     static constexpr int32_t QMin = GetQMin<OutDtype>::value;
     static constexpr int32_t QMax = GetQMax<OutDtype>::value;
@@ -44,6 +50,9 @@
     TosaRescaleAttribute* attribute;
     TosaReference::TensorTemplate<TIn>* in;
     TosaReference::TensorTemplate<TOut>* out;
+    TosaReference::TensorTemplate<TMultiplierI16>* multiplierI16;
+    TosaReference::TensorTemplate<TMultiplierI32>* multiplierI32;
+    TosaReference::TensorTemplate<TShift>* shift;
 };
 
 template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 592c491..cbfffae 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -744,6 +744,10 @@
                         arr = np.int64(argsDict["fixed_data"][idx])
                     elif dtype == DType.INT8:
                         arr = np.int8(argsDict["fixed_data"][idx])
+                    elif dtype == DType.INT16:
+                        arr = np.int16(argsDict["fixed_data"][idx])
+                    elif dtype == DType.INT32:
+                        arr = np.int32(argsDict["fixed_data"][idx])
                     else:
                         assert False, "Unsupported fixed_data type"
                 else:
@@ -1060,6 +1064,26 @@
         )
 
     @staticmethod
+    def tvgRescale(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+        scale32 = argsDict["scale"]
+        multiplier_arr = argsDict["multiplier"]
+        shift_arr = argsDict["shift"]
+
+        if scale32:
+            dtypeList[1] = DType.INT32
+        else:
+            dtypeList[1] = DType.INT16
+        shapeList[1] = [len(multiplier_arr)]
+        dtypeList[2] = DType.INT8
+        shapeList[2] = [len(shift_arr)]
+        # Create a new list for the pre-generated data in argsDict["fixed_data"]
+        argsDict["fixed_data"] = [None, multiplier_arr, shift_arr]
+
+        return TosaTensorValuesGen.tvgLazyGenDefault(
+            testGen, opName, dtypeList, shapeList, argsDict, error_name
+        )
+
+    @staticmethod
     def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
         # argsDict["pad"] is 2D array, need to flatten it to get list of values
         pad_values = argsDict["pad"].flatten()
@@ -2842,6 +2866,43 @@
                             # Illegal condition.  ERROR_IF(!scale32 && double_round)
                             continue
 
+                        if per_channel:
+                            nc = shapeList[0][-1]
+                        else:
+                            nc = 1
+
+                        in_type_width = gtu.dtypeWidth(inDtype)
+                        out_type_width = gtu.dtypeWidth(outDtype)
+
+                        # Calculate scale based on:
+                        # scale = a *(2^output_width)/(2^input_width))
+
+                        a = np.float32(testGen.rng.random(size=[nc]))
+                        scale_arr = a * np.float32(
+                            (1 << out_type_width) / (1 << in_type_width)
+                        )
+
+                        if scale32:
+                            # Cap the scaling at 2^31 - 1 for scale32
+                            scale_arr = np.clip(
+                                scale_arr, 1.0 / (1 << 31), (1 << 31) - 1
+                            )
+                        else:
+                            # Cap the scaling at 2^15 - 1 for scale16
+                            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
+
+                        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
+
+                        multiplier_arr = np.int32(np.zeros(shape=[nc]))
+                        shift_arr = np.int32(np.zeros(shape=[nc]))
+                        for i in range(nc):
+                            (
+                                multiplier_arr[i],
+                                shift_arr[i],
+                            ) = TosaQuantGen.computeMultiplierAndShift(
+                                scale_arr[i], scale32
+                            )
+
                         arg_list.append(
                             (
                                 "out{}_sc{}_dr{}_pc{}".format(
@@ -2855,6 +2916,8 @@
                                     "scale": scale32,
                                     "double_round": double_round,
                                     "per_channel": per_channel,
+                                    "multiplier": multiplier_arr,
+                                    "shift": shift_arr,
                                 },
                             )
                         )
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 978e735..415858c 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -317,13 +317,6 @@
                     "Unknown dtype, cannot convert to string: {}".format(dtype)
                 )
 
-    def typeWidth(self, dtype):
-        """Get the datatype width for data types"""
-        if dtype in gtu.DTYPE_ATTRIBUTES:
-            return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
-        else:
-            raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
-
     def constrictBatchSize(self, shape):
         # Limit the batch size unless an explicit target shape set
         if self.args.max_batch_size and not self.args.target_shapes:
@@ -2130,12 +2123,15 @@
         error_name=None,
         qinfo=None,
     ):
-        assert len(inputs) == 1
+        assert len(inputs) == 3
         val = inputs[0]
+        multiplier_val = inputs[1]
+        shift_val = inputs[2]
         out_dtype = args_dict["output_dtype"]
         scale32 = args_dict["scale"]
         double_round = args_dict["double_round"]
         per_channel = args_dict["per_channel"]
+        shift_arr = args_dict["shift"]
 
         result_tensor = OutputShaper.typeConversionOp(
             self.ser, self.rng, val, out_dtype, error_name
@@ -2146,8 +2142,8 @@
         else:
             nc = 1
 
-        in_type_width = self.typeWidth(val.dtype)
-        out_type_width = self.typeWidth(out_dtype)
+        in_type_width = gtu.dtypeWidth(val.dtype)
+        out_type_width = gtu.dtypeWidth(out_dtype)
 
         input_unsigned = False
         output_unsigned = False
@@ -2198,31 +2194,10 @@
         else:
             output_zp = 0
 
-        # Calculate scale based on:
-        # scale = a *(2^output_width)/(2^input_width))
-
-        a = np.float32(self.rng.random(size=[nc]))
-        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
-
-        if scale32:
-            pass
-            # Cap the scaling at 2^31 - 1 for scale32
-            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
-        else:
-            # Cap the scaling at 2^15 - 1 for scale16
-            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
-
-        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
-
-        multiplier_arr = np.int32(np.zeros(shape=[nc]))
-        shift_arr = np.int32(np.zeros(shape=[nc]))
         min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
         max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
 
         for i in range(nc):
-            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
-                scale_arr[i], scale32
-            )
             min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
             max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
 
@@ -2256,7 +2231,7 @@
                 )
 
         # Invalidate Input/Output list for error if checks.
-        input_list = [val.name]
+        input_list = [val.name, multiplier_val.name, shift_val.name]
         output_list = [result_tensor.name]
         pCount, cCount = op["operands"]
         num_operands = pCount + cCount
@@ -2287,8 +2262,8 @@
         attr.RescaleAttribute(
             input_zp,
             output_zp,
-            multiplier_arr,
-            shift_arr,
+            [],
+            [],
             scale32,
             double_round,
             per_channel,
@@ -4809,11 +4784,11 @@
         },
         "rescale": {
             "op": Op.RESCALE,
-            "operands": (1, 0),
+            "operands": (3, 0),
             "build_fcn": (
                 build_rescale,
                 TosaTensorGen.tgBasic,
-                TosaTensorValuesGen.tvgLazyGenDefault,
+                TosaTensorValuesGen.tvgRescale,
                 TosaArgGen.agRescale,
             ),
             "types": [
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 31a0ff0..384463f 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -55,6 +55,14 @@
     FIXED_DATA = 5
 
 
+def dtypeWidth(dtype):
+    """Get the datatype width for data types"""
+    if dtype in DTYPE_ATTRIBUTES:
+        return DTYPE_ATTRIBUTES[dtype]["width"]
+    else:
+        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
+
+
 def dtypeIsSupportedByCompliance(dtype):
     """Types supported by the new data generation and compliance flow."""
     if isinstance(dtype, list) or isinstance(dtype, tuple):