Testing support for MUL with shift as input

Always create the shift as a tensor for all types in testing.
In the reference model, the shift operand is made available for
all types, but the shift tensor is only read for i32.

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py
index 743475c..c63a2d5 100644
--- a/verif/generator/datagenerator.py
+++ b/verif/generator/datagenerator.py
@@ -82,6 +82,10 @@
             # Create buffer and initialize to zero
             buffer = (ct.c_int32 * size)(0)
             size_bytes = size * 4
+        elif dtype == "INT8":
+            size_bytes = size
+            # Create buffer of bytes and initialize to zero
+            buffer = (ct.c_ubyte * size_bytes)(0)
         else:
             raise GenerateError(f"Unsupported data type {dtype}")
 
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0851aca..592c491 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -254,19 +254,16 @@
         return shape_list
 
     @staticmethod
-    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
+    def _get_broadcast_shapes(testGen, num_shapes, rank, error_name=None):
         shape = testGen.makeShape(rank)
-
-        pl, const = op["operands"]
-
         shape_list = []
 
         # Choose one of the inputs to broadcast
         # Note: Simplifies OutputShaper code if we don't change first shape for errors
-        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
+        bcast_idx = testGen.randInt(0 if error_name is None else 1, num_shapes)
         fuzz_idx = testGen.randInt(0, rank)
 
-        for i in range(pl + const):
+        for i in range(num_shapes):
             shape_bcast = shape.copy()
 
             # To test broadcasting, the chosen fuzz index dimension should not be 1
@@ -295,6 +292,22 @@
         return shape_list
 
     @staticmethod
+    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
+        pl, const = op["operands"]
+        num_shapes = pl + const
+        return TosaTensorGen._get_broadcast_shapes(
+            testGen, num_shapes, rank, error_name
+        )
+
+    @staticmethod
+    def tgMul(testGen, op, rank, error_name=None):
+        # Get broadcast shapes for the first 2 inputs as the 3rd is shift
+        shape_list = TosaTensorGen._get_broadcast_shapes(testGen, 2, rank, error_name)
+        # Add a single dimension tensor for shift
+        shape_list.append([1])
+        return shape_list
+
+    @staticmethod
     def tgConv2D(testGen, op, rank, error_name=None):
         pl, const = op["operands"]
 
@@ -727,7 +740,12 @@
                 # Ignore lazy data gen option and create data array using any range limits
 
                 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
-                    arr = np.int64(argsDict["fixed_data"][idx])
+                    if dtype == DType.SHAPE:
+                        arr = np.int64(argsDict["fixed_data"][idx])
+                    elif dtype == DType.INT8:
+                        arr = np.int8(argsDict["fixed_data"][idx])
+                    else:
+                        assert False, "Unsupported fixed_data type"
                 else:
                     arr = testGen.getRandTensor(shape, dtype, data_range)
                 if roundMode:
@@ -1147,6 +1165,13 @@
             if data_range:
                 argsDict["data_range"] = data_range
 
+            if dtypeList[0] != DType.SHAPE:
+                # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
+                dtypeList[2] = DType.INT8
+                shapeList[2] = [1]
+                # Create a new list for the pre-generated data in argsDict["fixed_data"]
+                argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
+
             return TosaTensorValuesGen.tvgLazyGenDefault(
                 testGen, opName, dtypeList, shapeList, argsDict, error_name
             )
@@ -1154,9 +1179,6 @@
             # Integer test
             op = testGen.TOSA_OP_LIST[opName]
             pCount, cCount = op["operands"]
-            assert (
-                pCount == 2 and cCount == 0
-            ), "Op.MUL must have 2 placeholders, 0 consts"
 
             tens_ser_list = []
 
@@ -1213,6 +1235,7 @@
                 b_arr = b_arr // 2
 
             if dtypeList[0] == DType.SHAPE:
+                # MUL_SHAPE with 2 inputs
                 tens_ser_list.append(
                     testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
                 )
@@ -1220,12 +1243,16 @@
                     testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
                 )
             else:
+                # MUL with 3 inputs (3rd is shift)
                 tens_ser_list.append(
                     testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                 )
                 tens_ser_list.append(
                     testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                 )
+                tens_ser_list.append(
+                    testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
+                )
 
             return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
 
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index ee45f0e..b472087 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -587,9 +587,9 @@
     def build_mul(
         self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
     ):
-        assert len(inputs) == 2
-        a, b = inputs
-        shift = args_dict["shift"]
+        # Note that mul is binary operator but it has a shift value tensor
+        assert len(inputs) == 3
+        a, b, s = inputs
 
         result_tensor = OutputShaper.binaryBroadcastOp(
             self.ser, self.rng, a, b, error_name
@@ -605,7 +605,7 @@
             result_tensor.setDtype(outputDType)
 
         # Invalidate Input/Output list for error if checks.
-        input_list = [a.name, b.name]
+        input_list = [a.name, b.name, s.name]
         output_list = [result_tensor.name]
         pCount, cCount = op["operands"]
         num_operands = pCount + cCount
@@ -629,10 +629,7 @@
         ):
             return None
 
-        attr = ts.TosaSerializerAttribute()
-        attr.MulAttribute(shift)
-
-        self.ser.addOperator(op["op"], input_list, output_list, attr)
+        self.ser.addOperator(op["op"], input_list, output_list)
 
         compliance = self.tensorComplianceMetaData(
             op, a.dtype, args_dict, result_tensor, error_name
@@ -3874,10 +3871,10 @@
         },
         "mul": {
             "op": Op.MUL,
-            "operands": (2, 0),
+            "operands": (3, 0),
             "build_fcn": (
                 build_mul,
-                TosaTensorGen.tgBroadcastFuzz,
+                TosaTensorGen.tgMul,
                 TosaTensorValuesGen.tvgMul,
                 TosaArgGen.agMul,
             ),