Add DIM operator to reference model

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iea11ee5d3d98773e9c5e9b827593c05afb41ce3b
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b5e71ac..8c18e67 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -88,7 +88,9 @@
             return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
         elif dtype == DType.UINT16:
             return np.int32(self.rng.integers(low=0, high=65536, size=shape))
-        elif dtype == DType.INT32:
+        elif (
+            dtype == DType.INT32 or dtype == DType.SHAPE
+        ):  # restricting too large value for SHAPE
             return np.int32(
                 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
             )
@@ -181,7 +183,9 @@
             low, high = (-128, 128)
         elif dtype == DType.INT16:
             low, high = (-32768, 32768)
-        elif dtype == DType.INT32:
+        elif (
+            dtype == DType.INT32 or dtype == DType.SHAPE
+        ):  # restricting too large value for SHAPE
             low, high = (-(1 << 31), (1 << 31))
         elif dtype == DType.INT48:
             low, high = (-(1 << 47), (1 << 47))
@@ -1310,6 +1314,49 @@
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
+    def build_dim(
+        self,
+        op,
+        a,
+        axis,
+        validator_fcns=None,
+        error_name=None,
+        qinfo=None,
+    ):
+        result_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [a.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_list, output_list
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            axis=axis,
+            input_shape=a.shape,
+            input_dtype=a.dtype,
+            output_shape=result_tens.shape,
+            output_dtype=result_tens.dtype,
+            result_tensors=[result_tens],
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        ):
+            return None
+
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(axis)
+
+        self.ser.addOperator(op["op"], input_list, output_list, attr)
+        return result_tens
+
     def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
         result_tens = OutputShaper.reshapeOp(
             self.ser, self.rng, a, newShape, error_name
@@ -3749,6 +3796,25 @@
                 TosaErrorValidator.evWrongRank,
             ),
         },
+        "dim": {
+            "op": Op.DIM,
+            "operands": (1, 0),
+            "build_fcn": (
+                build_dim,
+                TosaTensorGen.tgBasic,
+                TosaTensorValuesGen.tvgDefault,
+                TosaArgGen.agAxis,
+            ),
+            "types": TYPE_FIB,
+            "error_if_validators": (
+                TosaErrorValidator.evAxisLargerRank,
+                TosaErrorValidator.evAxisSmallerZero,
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+                TosaErrorValidator.evWrongRank,
+            ),
+        },
         "reshape": {
             "op": Op.RESHAPE,
             "operands": (1, 0),
@@ -4665,6 +4731,27 @@
         return ser.addOutput(output_shape, outputDType)
 
     @staticmethod
+    def dimOp(ser, rng, a, axis, error_name=None):
+        output_shape = [1]
+
+        if error_name == ErrorIf.WrongOutputType:
+            all_dtypes = [
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+                DType.FP32,
+                DType.FP16,
+                DType.BF16,
+            ]
+            wrong_dtypes = list(set(all_dtypes))
+            outputDType = rng.choice(wrong_dtypes)
+        else:
+            outputDType = DType.SHAPE
+
+        return ser.addOutput(output_shape, outputDType)
+
+    @staticmethod
     def reshapeOp(ser, rng, a, shape, error_name=None):
         output_shape = shape.copy()