Add conformance testing for shape operators

Extend the test generator with the shape operators ADD_SHAPE, SUB_SHAPE,
MUL_SHAPE, DIV_SHAPE, CONCAT_SHAPE and CONST_SHAPE: register them in the
operator list, add a build_shape_op builder and an OutputShaper.addShapeOp
helper, generate SHAPE data as int64 values within the configured tensor
shape range, and teach build_concat to handle CONCAT_SHAPE (axis fixed at
0, no AxisAttribute).

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Ie80570146601c470a3be7c04a9d6e1016a7c547c
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 159ee83..b9352ac 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -167,9 +167,10 @@
             rng = (-128, 128)
         elif dtype == DType.INT16:
             rng = (-32768, 32768)
-        elif dtype in (DType.INT32, DType.SHAPE):
-            # restricting too large value for SHAPE
+        elif dtype == DType.INT32:
             rng = (-(1 << 31), (1 << 31))
+        elif dtype == DType.SHAPE:
+            rng = tuple(self.args.tensor_shape_range[0:2])  # limit SHAPE values to the configured tensor shape range
         elif dtype == DType.INT48:
             rng = (-(1 << 47), (1 << 47))
         else:
@@ -190,7 +191,7 @@
 
         if dtype == DType.BOOL:
             return np.bool_(self.rng.choice(a=[False, True], size=shape))
-        elif dtype == DType.INT48:
+        elif dtype in (DType.INT48, DType.SHAPE):
             return np.int64(self.rng.integers(low=low, high=high, size=shape))
         elif dtype in (DType.FP16, DType.BF16, DType.FP32):
             f_tensor = self.rng.uniform(low=low, high=high, size=shape)
@@ -1399,7 +1400,10 @@
     def build_concat(
         self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
     ):
-        axis = args_dict["axis"]
+        if op["op"] == Op.CONCAT_SHAPE:
+            axis = 0  # CONCAT_SHAPE takes no axis attribute; concatenate along dimension 0
+        else:
+            axis = args_dict["axis"]
         if error_name != ErrorIf.WrongInputType:
             assert type(axis) == int
 
@@ -1438,9 +1442,12 @@
         ):
             return None
 
-        attr = ts.TosaSerializerAttribute()
-        attr.AxisAttribute(axis)
-
+        if op["op"] == Op.CONCAT:
+            attr = ts.TosaSerializerAttribute()
+            attr.AxisAttribute(axis)
+        else:
+            assert op["op"] == Op.CONCAT_SHAPE
+            attr = None
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
         compliance = self.tensorComplianceMetaData(
@@ -2512,6 +2519,52 @@
         self.ser.addOperator(op["op"], input_names, output_names, attr)
         return results
 
+    def build_shape_op(
+        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+    ):
+        assert len(inputs) == 2
+        a, b = inputs
+
+        result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name)
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [a.name, b.name]
+        output_list = [result_tensor.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_list, output_list
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input1=a,
+            input2=b,
+            input_shape=a.shape,
+            input_dtype=a.dtype,
+            output_shape=result_tensor.shape,
+            output_dtype=result_tensor.dtype,
+            result_tensors=[result_tensor],
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        ):
+            return None
+
+        self.ser.addOperator(
+            op["op"],
+            input_list,
+            output_list,
+        )
+        compliance = self.tensorComplianceMetaData(
+            op, a.dtype, args_dict, result_tensor, error_name
+        )
+
+        return TosaTestGen.BuildInfo(result_tensor, compliance)
+
     def create_filter_lists(
         self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
     ):
@@ -2725,12 +2778,12 @@
 
         if isinstance(dtype_or_dtypeList, list):
             dtypeList = dtype_or_dtypeList
-        elif op["op"] == Op.CONCAT:
+        elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE):
             dtypeList = [dtype_or_dtypeList] * len(shapeList)
         else:
             dtypeList = [dtype_or_dtypeList] * (num_operands)
 
-        if op["op"] != Op.CONCAT:
+        if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE):
             assert (
                 len(shapeList) == num_operands
             ), "shapeList length {} must match number of operands {}".format(
@@ -4605,6 +4658,78 @@
                 TosaErrorValidator.evFFTOutputShapeMismatch,
             ),
         },
+        # Shape
+        "add_shape": {
+            "op": Op.ADD_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgAddSub,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "sub_shape": {
+            "op": Op.SUB_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgAddSub,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "mul_shape": {
+            "op": Op.MUL_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgMul,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "div_shape": {
+            "op": Op.DIV_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgIntDiv,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "concat_shape": {
+            "op": Op.CONCAT_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_concat,
+                TosaTensorGen.tgConcat,
+                TosaTensorValuesGen.tvgConcat,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (),
+        },
+        "const_shape": {
+            "op": Op.CONST_SHAPE,
+            "operands": (0, 1),
+            "build_fcn": (
+                build_const,
+                TosaTensorGen.tgBasic,
+                TosaTensorValuesGen.tvgDefault,
+                None,
+            ),
+            "types": [DType.SHAPE],
+        },
     }
 
 
@@ -5524,3 +5649,24 @@
         outputs.append(serializer.addOutput(output_shape, output_dtype))
         outputs.append(serializer.addOutput(output_shape, output_dtype))
         return outputs
+
+    @staticmethod
+    def addShapeOp(ser, rng, a, b, error_name=None):
+        if error_name != ErrorIf.RankMismatch:
+            assert len(a.shape) == len(b.shape)
+        assert a.dtype == b.dtype
+
+        shape = []
+        for i in range(len(a.shape)):
+            shape.append(a.shape[i])
+
+        fuzz_idx = rng.integers(0, len(a.shape))
+        if error_name == ErrorIf.DimensionMismatch:
+            shape[fuzz_idx] += 1
+
+        if error_name == ErrorIf.WrongOutputType:
+            wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
+            outputDType = rng.choice(wrong_dtypes)
+        else:
+            outputDType = DType.SHAPE
+        return ser.addOutput(shape, outputDType)
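
A minimal standalone sketch (not part of the patch) of the ERROR_IF fuzzing
that the new OutputShaper.addShapeOp performs: the output shape starts as a
copy of the first operand's shape and, only for ErrorIf.DimensionMismatch,
one randomly chosen dimension is bumped so the result no longer matches.
The helper name and the use of numpy's default_rng below are illustrative
stand-ins, not code from tosa_test_gen.py.

    import numpy as np

    def fuzz_output_shape(a_shape, rng, dimension_mismatch=False):
        # Start from a copy of the first operand's shape.
        shape = list(a_shape)
        # Pick one dimension index; bump it only when the ERROR_IF case
        # is being generated, mirroring OutputShaper.addShapeOp.
        fuzz_idx = rng.integers(0, len(a_shape))
        if dimension_mismatch:
            shape[fuzz_idx] += 1
        return shape

    rng = np.random.default_rng(seed=0)
    print(fuzz_output_shape([4], rng))                           # -> [4]
    print(fuzz_output_shape([4], rng, dimension_mismatch=True))  # -> [5]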