Add conformance testing for shape operators

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Ie80570146601c470a3be7c04a9d6e1016a7c547c
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index cebdf62..55eef58 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, ARM Limited.
+# Copyright (c) 2021-2024, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 """Select generated tests."""
 import argparse
@@ -437,6 +437,12 @@
     name = "add"
 
 
+class AddShapeOperator(Operator):
+    """Test selector for the ADD_SHAPE operator."""
+
+    name = "add_shape"
+
+
 class ArgmaxOperator(Operator):
     """Test selector for the ARGMAX operator."""
 
@@ -507,6 +513,12 @@
     param_names = ["shape", "type", "axis"]
 
 
+class ConcatShapeOperator(Operator):
+    """Test selector for the CONCAT_SHAPE operator."""
+
+    name = "concat_shape"
+
+
 class CondIfOperator(Operator):
     """Test selector for the COND_IF operator."""
 
@@ -520,6 +532,12 @@
     name = "const"
 
 
+class ConstShapeOperator(Operator):
+    """Test selector for the CONST_SHAPE operator."""
+
+    name = "const_shape"
+
+
 class Conv2dOperator(Operator):
     """Test selector for the CONV2D operator."""
 
@@ -548,6 +566,12 @@
     param_names = ["shape", "type", "axis"]
 
 
+class DivShapeOperator(Operator):
+    """Test selector for the DIV_SHAPE operator."""
+
+    name = "div_shape"
+
+
 class EqualOperator(Operator):
     """Test selector for the EQUAL operator."""
 
@@ -696,6 +720,12 @@
     param_names = ["shape", "type", "perm", "shift"]
 
 
+class MulShapeOperator(Operator):
+    """Test selector for the MUL_SHAPE operator."""
+
+    name = "mul_shape"
+
+
 class NegateOperator(Operator):
     """Test selector for the Negate operator."""
 
@@ -849,6 +879,12 @@
     name = "sub"
 
 
+class SubShapeOperator(Operator):
+    """Test selector for the SUB_SHAPE operator."""
+
+    name = "sub_shape"
+
+
 class TableOperator(Operator):
     """Test selector for the TABLE operator."""
 
diff --git a/verif/conformance/tosa_base_profile_ops_info.json b/verif/conformance/tosa_base_profile_ops_info.json
index b186b06..ec51324 100644
--- a/verif/conformance/tosa_base_profile_ops_info.json
+++ b/verif/conformance/tosa_base_profile_ops_info.json
@@ -129,6 +129,35 @@
             }
         }
     },
+    "add_shape": {
+        "group": "shape",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "shape",
+                        "--tensor-dim-range",
+                        "1,16",
+                        "--target-rank",
+                        "1"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                    "shape"
+                ]
+            }
+        }
+    },
     "argmax": {
         "group": "tensor",
         "profile": [
@@ -974,6 +1003,36 @@
             }
         }
     },
+    "concat_shape": {
+        "group": "shape",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "shape",
+                        "--target-rank",
+                        "1",
+                        "--target-shape",
+                        "1",
+                        "--num-const-inputs-concat",
+                        "2"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                ]
+            }
+        }
+    },
     "cond_if": {
         "group": "control_flow",
         "profile": [
@@ -1080,6 +1139,35 @@
             }
         }
     },
+    "const_shape": {
+        "group": "shape",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "no_negative_tests": "true",
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "shape",
+                        "--target-rank",
+                        "1",
+                        "--target-shape",
+                        "1"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                ]
+            }
+        }
+    },
     "conv2d": {
         "group": "tensor",
         "profile": [
@@ -1374,6 +1462,35 @@
             }
         }
     },
+    "div_shape": {
+        "group": "shape",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "shape",
+                        "--tensor-dim-range",
+                        "1,16",
+                        "--target-rank",
+                        "1"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                    "shape"
+                ]
+            }
+        }
+    },
     "equal": {
         "group": "comparison",
         "profile": [
@@ -2542,6 +2659,35 @@
             }
         }
     },
+    "mul_shape": {
+        "group": "shape",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "shape",
+                        "--tensor-dim-range",
+                        "1,16",
+                        "--target-rank",
+                        "1"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                    "shape"
+                ]
+            }
+        }
+    },
     "negate": {
         "group": "ew_unary",
         "profile": [
@@ -3502,6 +3648,35 @@
             }
         }
     },
+    "sub_shape": {
+        "group": "shape",
+        "profile": [
+            "tosa-bi",
+            "tosa-mi"
+        ],
+        "generation": {
+            "standard": {
+                "generator_args": [
+                    [
+                        "--target-dtype",
+                        "shape",
+                        "--tensor-dim-range",
+                        "1,16",
+                        "--target-rank",
+                        "1"
+                    ]
+                ]
+            }
+        },
+        "selection": {
+            "default": {
+                "params": {},
+                "permutes": [
+                    "shape"
+                ]
+            }
+        }
+    },
     "table": {
         "group": "ew_binary",
         "profile": [
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a655a50..f598377 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -622,6 +622,34 @@
 
         return new_shapeList
 
+    @staticmethod
+    def tgShape(testGen, opName, rank, error_name=None):
+        """Create the input shapes for a shape operator.
+
+        Each operand starts as ``[rank]``, i.e. a 1-D tensor holding ``rank``
+        values; ERROR_IF tests may restrict or alter these shapes. ``opName``
+        is indexed for ``operands`` (presumably the TOSA_OP_LIST entry).
+        """
+        pl, const = opName["operands"]
+        shape = [rank]
+
+        # Restrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+
+        shape_list = []
+        for i in range(pl + const):
+            shape_list.append(shape.copy())
+
+            # Generates an input rank mismatch for operators with more than one input
+            if error_name == ErrorIf.RankMismatch:
+                if rank == 1 and i != 1:
+                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
+                elif i != 1:
+                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
+
+        return shape_list
+
 
 class TosaTensorValuesGen:
     """Tensor Value generators create the random data for each tensor in each test."""
@@ -891,7 +913,7 @@
 
     @staticmethod
     def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
-        if dtypeList[0] == DType.INT32 and error_name is None:
+        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
             # Make sure the integer operation does not cause value saturation - where
             # the number wraps due to limited number of bits to store the answer
             op = testGen.TOSA_OP_LIST[opName]
@@ -900,9 +922,16 @@
                 pCount == 2 and cCount == 0
             ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
             tens_ser_list = []
-            add = op["op"] == Op.ADD
-            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
-            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
+            if dtypeList[0] == DType.SHAPE:
+                # SHAPE values are limited to the tensor dimension range
+                data_range = testGen.args.tensor_shape_range
+                a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
+                b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
+            else:
+                # INT32 keeps getRandTensor's default range, as before this change
+                a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
+                b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
             if add:
                 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
             else:
@@ -1138,12 +1161,15 @@
             tens_ser_list = []
 
             # Make sure multiply result in int32 range
-            shift = argsDict["shift"]
+            if dtypeList[0] == DType.SHAPE:
+                shift = 0
+            else:
+                shift = argsDict["shift"]
             if dtypeList[0] == DType.INT8:
                 num_bits = 8
             elif dtypeList[0] == DType.INT16:
                 num_bits = 16
-            elif dtypeList[0] == DType.INT32:
+            elif dtypeList[0] in (DType.INT32, DType.SHAPE):
                 num_bits = 32
             elif error_name == ErrorIf.WrongInputType:
                 num_bits = 8
@@ -1151,8 +1177,12 @@
                 raise Exception("OpMul: invalid input dtype")
 
             for idx, shape in enumerate(shapeList[:]):
-                low = -(2 ** (num_bits - 1))
-                high = (2 ** (num_bits - 1)) - 1
+                if dtypeList[idx] == DType.SHAPE:
+                    low = testGen.args.tensor_shape_range[0]
+                    high = testGen.args.tensor_shape_range[1]
+                else:
+                    low = -(2 ** (num_bits - 1))
+                    high = (2 ** (num_bits - 1)) - 1
 
                 a_arr = np.int32(
                     testGen.rng.integers(low=low, high=high, size=shapeList[0])
@@ -1182,12 +1212,20 @@
                 a_arr = a_arr // 2
                 b_arr = b_arr // 2
 
-            tens_ser_list.append(
-                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
-            )
-            tens_ser_list.append(
-                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
-            )
+            if dtypeList[0] == DType.SHAPE:
+                tens_ser_list.append(
+                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
+                )
+                tens_ser_list.append(
+                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
+                )
+            else:
+                tens_ser_list.append(
+                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+                )
+                tens_ser_list.append(
+                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+                )
 
             return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
 
@@ -1199,9 +1237,16 @@
         if testGen.args.num_const_inputs_concat == 0:
             count = len(shapeList)
 
-        shapeList = TosaTensorGen.tgConcatConstInput(
-            testGen, shapeList, argsDict["axis"], error_name
-        )
+        op = testGen.TOSA_OP_LIST[opName]
+        if op["op"] == Op.CONCAT_SHAPE:
+            # Set the axis to 0
+            shapeList = TosaTensorGen.tgConcatConstInput(
+                testGen, shapeList, 0, error_name
+            )
+        else:
+            shapeList = TosaTensorGen.tgConcatConstInput(
+                testGen, shapeList, argsDict["axis"], error_name
+            )
 
         # Override default pCount/cCount for operator
         argsDict["p_count"] = count
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7f719ee..5874123 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, ARM Limited.
+# Copyright (c) 2021-2024, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import math
 
@@ -595,6 +595,10 @@
                     error_result = True
                 # invalid input types are ignored, to avoid reporting multiple errors
 
+            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
+                if output_dtype != DType.SHAPE:
+                    error_result = True
+
             else:
                 if output_dtype != input_dtype:
                     error_result = True
@@ -1109,7 +1113,13 @@
                 kwargs["input3"].shape if "input3" in kwargs else input2_shape
             )
 
-            if len(input1_shape) == len(input2_shape) == len(input3_shape):
+            op = kwargs["op"]
+            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
+                output_shape = kwargs["result_tensors"][0].shape
+                if input1_shape != output_shape:
+                    error_result = True
+
+            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                     input3_shape,
                     TosaErrorValidator.calculateBroadcastShape(
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 159ee83..b9352ac 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -167,9 +167,10 @@
             rng = (-128, 128)
         elif dtype == DType.INT16:
             rng = (-32768, 32768)
-        elif dtype in (DType.INT32, DType.SHAPE):
-            # restricting too large value for SHAPE
+        elif dtype == DType.INT32:
             rng = (-(1 << 31), (1 << 31))
+        elif dtype == DType.SHAPE:
+            rng = tuple(self.args.tensor_shape_range[0:2])
         elif dtype == DType.INT48:
             rng = (-(1 << 47), (1 << 47))
         else:
@@ -190,7 +191,7 @@
 
         if dtype == DType.BOOL:
             return np.bool_(self.rng.choice(a=[False, True], size=shape))
-        elif dtype == DType.INT48:
+        elif dtype in (DType.INT48, DType.SHAPE):
             return np.int64(self.rng.integers(low=low, high=high, size=shape))
         elif dtype in (DType.FP16, DType.BF16, DType.FP32):
             f_tensor = self.rng.uniform(low=low, high=high, size=shape)
@@ -1399,7 +1400,10 @@
     def build_concat(
         self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
     ):
-        axis = args_dict["axis"]
+        if op["op"] == Op.CONCAT_SHAPE:
+            axis = 0
+        else:
+            axis = args_dict["axis"]
         if error_name != ErrorIf.WrongInputType:
             assert type(axis) == int
 
@@ -1438,9 +1442,12 @@
         ):
             return None
 
-        attr = ts.TosaSerializerAttribute()
-        attr.AxisAttribute(axis)
-
+        if op["op"] == Op.CONCAT:
+            attr = ts.TosaSerializerAttribute()
+            attr.AxisAttribute(axis)
+        else:
+            assert op["op"] == Op.CONCAT_SHAPE
+            attr = None
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
         compliance = self.tensorComplianceMetaData(
@@ -2512,6 +2519,59 @@
         self.ser.addOperator(op["op"], input_names, output_names, attr)
         return results
 
+    def build_shape_op(
+        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+    ):
+        """Build a two-input shape operator such as ADD_SHAPE or MUL_SHAPE.
+
+        Creates the output tensor, runs the ERROR_IF validation and, when it
+        passes, adds the operator to the serializer. Returns a BuildInfo with
+        the result tensor and its compliance metadata, or None if the ERROR_IF
+        validation fails.
+        """
+        assert len(inputs) == 2
+        a, b = inputs
+
+        result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name)
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [a.name, b.name]
+        output_list = [result_tensor.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_list, output_list
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input1=a,
+            input2=b,
+            input_shape=a.shape,
+            input_dtype=a.dtype,
+            output_shape=result_tensor.shape,
+            output_dtype=result_tensor.dtype,
+            result_tensors=[result_tensor],
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        ):
+            return None
+
+        self.ser.addOperator(
+            op["op"],
+            input_list,
+            output_list,
+        )
+        compliance = self.tensorComplianceMetaData(
+            op, a.dtype, args_dict, result_tensor, error_name
+        )
+
+        return TosaTestGen.BuildInfo(result_tensor, compliance)
+
     def create_filter_lists(
         self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
     ):
@@ -2725,12 +2778,12 @@
 
         if isinstance(dtype_or_dtypeList, list):
             dtypeList = dtype_or_dtypeList
-        elif op["op"] == Op.CONCAT:
+        elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE):
             dtypeList = [dtype_or_dtypeList] * len(shapeList)
         else:
             dtypeList = [dtype_or_dtypeList] * (num_operands)
 
-        if op["op"] != Op.CONCAT:
+        if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE):
             assert (
                 len(shapeList) == num_operands
             ), "shapeList length {} must match number of operands {}".format(
@@ -4605,6 +4658,78 @@
                 TosaErrorValidator.evFFTOutputShapeMismatch,
             ),
         },
+        # Shape
+        "add_shape": {
+            "op": Op.ADD_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgAddSub,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "sub_shape": {
+            "op": Op.SUB_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgAddSub,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "mul_shape": {
+            "op": Op.MUL_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgMul,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "div_shape": {
+            "op": Op.DIV_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_shape_op,
+                TosaTensorGen.tgShape,
+                TosaTensorValuesGen.tvgIntDiv,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+        },
+        "concat_shape": {
+            "op": Op.CONCAT_SHAPE,
+            "operands": (2, 0),
+            "build_fcn": (
+                build_concat,
+                TosaTensorGen.tgConcat,
+                TosaTensorValuesGen.tvgConcat,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.SHAPE],
+            "error_if_validators": (),
+        },
+        "const_shape": {
+            "op": Op.CONST_SHAPE,
+            "operands": (0, 1),
+            "build_fcn": (
+                build_const,
+                TosaTensorGen.tgBasic,
+                TosaTensorValuesGen.tvgDefault,
+                None,
+            ),
+            "types": [DType.SHAPE],
+        },
     }
 
 
@@ -5524,3 +5649,31 @@
         outputs.append(serializer.addOutput(output_shape, output_dtype))
         outputs.append(serializer.addOutput(output_shape, output_dtype))
         return outputs
+
+    @staticmethod
+    def addShapeOp(ser, rng, a, b, error_name=None):
+        """Create the output tensor for a two-input shape operator.
+
+        The output copies the shape of ``a`` and has dtype SHAPE. For ERROR_IF
+        tests, DimensionMismatch grows one randomly chosen dimension by 1 and
+        WrongOutputType picks the dtype from ``gtu.get_wrong_output_type``.
+        RankMismatch skips the equal-rank check on ``a`` and ``b``.
+        """
+        if error_name != ErrorIf.RankMismatch:
+            assert len(a.shape) == len(b.shape)
+        assert a.dtype == b.dtype
+
+        shape = list(a.shape)
+
+        # Drawn on every call, not only for DimensionMismatch; moving it into
+        # that branch would change the random stream seen by later draws.
+        fuzz_idx = rng.integers(0, len(a.shape))
+        if error_name == ErrorIf.DimensionMismatch:
+            shape[fuzz_idx] += 1
+
+        if error_name == ErrorIf.WrongOutputType:
+            wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
+            outputDType = rng.choice(wrong_dtypes)
+        else:
+            outputDType = DType.SHAPE
+        return ser.addOutput(shape, outputDType)