[ref model] Add acc_type to Conv Ops

This patch implements changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute

Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 83487a1..ffa3683 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1990,7 +1990,12 @@
         # Shape: (OFM channels), (KD), KH, KW, IFM channels
         filter_shape = shapeList[1]
 
-        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+        if error_name == ErrorIf.WrongAccumulatorType:
+            accum_dtypes = (
+                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+            )
 
         # Op type checks
         conv3d = opName.startswith("conv3d")
@@ -2110,88 +2115,91 @@
             sparsity = 1
 
         n = 0
-        for s in sorted(list(strides)):
-            for p in sorted(list(paddings)):
-                for d in sorted(list(dilations)):
-                    if (
-                        n % sparsity == 0
-                        # the padded shape must exceed the dilation * kernel to get a positive
-                        # sized output shape
-                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
-                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
-                        and (
-                            k_rank < 3
-                            or (
-                                (ifm_shape[3] - 1 + p[4] + p[5])
-                                > d[2] * (k_shape[2] - 1)
-                            )
-                        )
-                    ):
-                        remainders = []
-                        outputs = []
-                        for index in range(k_rank):
-                            pad_offset = index * 2
-                            partial = (
-                                ifm_shape[index + 1]
-                                - 1
-                                + p[pad_offset]
-                                + p[pad_offset + 1]
-                                - (k_shape[index] - 1) * d[index]
-                            )
-                            remainders.append(partial % s[index])
-                            outputs.append((partial // s[index]) + 1)
-
+        for a in accum_dtypes:
+            for s in sorted(list(strides)):
+                for p in sorted(list(paddings)):
+                    for d in sorted(list(dilations)):
                         if (
-                            # the parameters must produce integer exact output
-                            error_name != ErrorIf.ConvOutputShapeNonInteger
-                            and max(remainders) == 0
-                        ) or (
-                            error_name == ErrorIf.ConvOutputShapeNonInteger
-                            and max(remainders) > 0
-                        ):
-                            if (
-                                max_dim_size is not None
-                                and max(outputs) >= max_dim_size
-                            ):
-                                # Test will consume too much memory - skip it
-                                continue
-
-                            # Compliance - number of dot product calculations
-                            if depthwise:
-                                # N*OH*OW*C*M
-                                dots = gtu.product(
-                                    (ifm_shape[0], *outputs, *filter_shape[2:])
-                                )
-                            else:
-                                # N*OH*OW*OC or N*OD*OH*OW*OC
-                                dots = gtu.product(
-                                    (ifm_shape[0], *outputs, filter_shape[0])
-                                )
-                            args_dict = {
-                                "acc_type": accum_dtype,
-                                "stride": s,
-                                "pad": p,
-                                "dilation": d,
-                                "kernel": k_shape,
-                                "ks": k_size,
-                                "dot_products": dots,
-                                "shape": ifm_shape,
-                            }
-
-                            # Support for larger values than 9 needs different delimiter
-                            delim = "" if max(s + p + d) <= 9 else "x"
-                            arg_list.append(
-                                (
-                                    "acc{}_st{}_pad{}_dilat{}".format(
-                                        testGen.typeStr(accum_dtype),
-                                        delim.join([str(x) for x in s]),
-                                        delim.join([str(x) for x in p]),
-                                        delim.join([str(x) for x in d]),
-                                    ),
-                                    args_dict,
+                            n % sparsity == 0
+                            # the padded shape must exceed the dilation * kernel to get a positive
+                            # sized output shape
+                            and (ifm_shape[1] - 1 + p[0] + p[1])
+                            > d[0] * (k_shape[0] - 1)
+                            and (ifm_shape[2] - 1 + p[2] + p[3])
+                            > d[1] * (k_shape[1] - 1)
+                            and (
+                                k_rank < 3
+                                or (
+                                    (ifm_shape[3] - 1 + p[4] + p[5])
+                                    > d[2] * (k_shape[2] - 1)
                                 )
                             )
-                    n += 1
+                        ):
+                            remainders = []
+                            outputs = []
+                            for index in range(k_rank):
+                                pad_offset = index * 2
+                                partial = (
+                                    ifm_shape[index + 1]
+                                    - 1
+                                    + p[pad_offset]
+                                    + p[pad_offset + 1]
+                                    - (k_shape[index] - 1) * d[index]
+                                )
+                                remainders.append(partial % s[index])
+                                outputs.append((partial // s[index]) + 1)
+
+                            if (
+                                # the parameters must produce integer exact output
+                                error_name != ErrorIf.ConvOutputShapeNonInteger
+                                and max(remainders) == 0
+                            ) or (
+                                error_name == ErrorIf.ConvOutputShapeNonInteger
+                                and max(remainders) > 0
+                            ):
+                                if (
+                                    max_dim_size is not None
+                                    and max(outputs) >= max_dim_size
+                                ):
+                                    # Test will consume too much memory - skip it
+                                    continue
+
+                                # Compliance - number of dot product calculations
+                                if depthwise:
+                                    # N*OH*OW*C*M
+                                    dots = gtu.product(
+                                        (ifm_shape[0], *outputs, *filter_shape[2:])
+                                    )
+                                else:
+                                    # N*OH*OW*OC or N*OD*OH*OW*OC
+                                    dots = gtu.product(
+                                        (ifm_shape[0], *outputs, filter_shape[0])
+                                    )
+                                args_dict = {
+                                    "acc_type": a,
+                                    "stride": s,
+                                    "pad": p,
+                                    "dilation": d,
+                                    "kernel": k_shape,
+                                    "ks": k_size,
+                                    "dot_products": dots,
+                                    "shape": ifm_shape,
+                                }
+
+                                # Support for larger values than 9 needs different delimiter
+                                delim = "" if max(s + p + d) <= 9 else "x"
+                                arg_list.append(
+                                    (
+                                        "acc{}_st{}_pad{}_dilat{}".format(
+                                            testGen.typeStr(a),
+                                            delim.join([str(x) for x in s]),
+                                            delim.join([str(x) for x in p]),
+                                            delim.join([str(x) for x in d]),
+                                        ),
+                                        args_dict,
+                                    )
+                                )
+                        n += 1
 
         arg_list = TosaArgGen._add_data_generators(
             testGen,
@@ -2216,7 +2224,7 @@
             # Pick some potentially correct output dtype if input type is incorrect
             accum_dtype = DType.INT32
         else:
-            accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+            accum_dtype = dtypes[-1]  # use output dtype as accum_dtype
 
         # Set up compliance info
         args_dict = {
@@ -2303,7 +2311,12 @@
         ifm_shape = shapeList[0]
         filter_shape = shapeList[1]
 
-        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+        if error_name == ErrorIf.WrongAccumulatorType:
+            accum_dtypes = (
+                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+            )
 
         # Must be rank 4
         if error_name != ErrorIf.WrongRank:
@@ -2400,41 +2413,42 @@
             sparsity = 1
 
         n = 0
-        for s in sorted(list(strides)):
-            for p in sorted(list(paddings)):
-                if n % sparsity == 0:
-                    # Determine the output shape
-                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
-                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
-                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
+        for a in accum_dtypes:
+            for s in sorted(list(strides)):
+                for p in sorted(list(paddings)):
+                    if n % sparsity == 0:
+                        # Determine the output shape
+                        oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
+                        ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
+                        os = [ifm_shape[0], oh, ow, filter_shape[0]]
 
-                    # N*OH*OW*OC
-                    dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
-                    args_dict = {
-                        "acc_type": accum_dtype,
-                        "stride": s,
-                        "pad": p,
-                        "kernel": k_shape,
-                        "ks": k_size,
-                        "dot_products": dots,
-                        "shape": ifm_shape,
-                        "out_shape": os,
-                    }
+                        # N*OH*OW*OC
+                        dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
+                        args_dict = {
+                            "acc_type": a,
+                            "stride": s,
+                            "pad": p,
+                            "kernel": k_shape,
+                            "ks": k_size,
+                            "dot_products": dots,
+                            "shape": ifm_shape,
+                            "out_shape": os,
+                        }
 
-                    # Support for larger values than 9 needs different delimiter
-                    delim = "" if max(s + p) <= 9 else "x"
-                    arg_list.append(
-                        (
-                            "acc{}_st{}_pad{}_os{}".format(
-                                testGen.typeStr(accum_dtype),
-                                delim.join([str(x) for x in s]),
-                                delim.join([str(x) for x in p]),
-                                "x".join([str(x) for x in os]),
-                            ),
-                            args_dict,
+                        # Support for larger values than 9 needs different delimiter
+                        delim = "" if max(s + p) <= 9 else "x"
+                        arg_list.append(
+                            (
+                                "acc{}_st{}_pad{}_os{}".format(
+                                    testGen.typeStr(a),
+                                    delim.join([str(x) for x in s]),
+                                    delim.join([str(x) for x in p]),
+                                    "x".join([str(x) for x in os]),
+                                ),
+                                args_dict,
+                            )
                         )
-                    )
-                n += 1
+                    n += 1
 
         arg_list = TosaArgGen._add_data_generators(
             testGen,
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e557f06..916b4f9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -649,9 +649,9 @@
                     or input_dtype == DType.INT16
                     and output_dtype != DType.INT48
                     or input_dtype == DType.FP16
-                    and output_dtype not in (DType.FP16, DType.FP32)
+                    and output_dtype != DType.FP16
                     or input_dtype == DType.BF16
-                    and output_dtype != DType.FP32
+                    and output_dtype != DType.BF16
                     or input_dtype == DType.FP32
                     and output_dtype != DType.FP32
                     or input_dtype == DType.FP8E4M3
@@ -2682,6 +2682,36 @@
                 ):
                     error_result = True
 
+            elif op["op"] in {
+                Op.CONV2D,
+                Op.CONV3D,
+                Op.DEPTHWISE_CONV2D,
+                Op.TRANSPOSE_CONV2D,
+            }:
+                if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
+                    error_result = True
+                elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
+                    error_result = True
+                elif (
+                    input_dtype
+                    in (
+                        DType.FP32,
+                        DType.BF16,
+                    )
+                    and accum_dtype != DType.FP32
+                ):
+                    error_result = True
+                elif input_dtype == DType.FP16 and accum_dtype not in (
+                    DType.FP16,
+                    DType.FP32,
+                ):
+                    error_result = True
+                elif (
+                    input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+                    and accum_dtype != DType.FP16
+                ):
+                    error_result = True
+
         info_dict = {
             "error_name": error_name,
             "error_result": error_result,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7702753..c867070 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -896,6 +896,7 @@
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -903,7 +904,9 @@
         local_bound = False
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+        attr.ConvAttribute(
+            padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -981,6 +984,7 @@
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -988,7 +992,9 @@
         local_bound = False
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+        attr.ConvAttribute(
+            padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -1057,6 +1063,7 @@
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -1065,7 +1072,7 @@
 
         attr = ts.TosaSerializerAttribute()
         attr.TransposeConvAttribute(
-            out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
+            out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
         )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1143,6 +1150,7 @@
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -1150,7 +1158,9 @@
         local_bound = False
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+        attr.ConvAttribute(
+            padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -3385,6 +3395,7 @@
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
                 TosaErrorValidator.evConvOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3418,6 +3429,7 @@
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
                 TosaErrorValidator.evConvOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3452,6 +3464,7 @@
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
                 TosaErrorValidator.evConvOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3564,6 +3577,7 @@
                 TosaErrorValidator.evStrideSmallerOne,
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -5290,6 +5304,18 @@
         return ser.addOutput(shape, outputDType)
 
     @staticmethod
+    def _get_conv_output_type(input_dtype):
+        if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
+            return input_dtype
+        elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
+            return DType.FP16
+        elif input_dtype in (DType.INT8, DType.INT4):
+            return DType.INT32
+        elif input_dtype in (DType.INT16,):
+            return DType.INT48
+        assert False, f"Unsupported convolution data type {input_dtype}"
+
+    @staticmethod
     def conv2dOp(
         ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
     ):
@@ -5329,7 +5355,7 @@
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
@@ -5393,7 +5419,7 @@
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
@@ -5444,7 +5470,7 @@
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
@@ -5958,7 +5984,7 @@
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index cfe7cc6..4a4f6bb 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -164,10 +164,18 @@
     return value
 
 
-def get_accum_dtype_from_tgTypes(dtypes):
-    # Get accumulate data-type from the test generator's defined types
+def get_accum_dtypes_from_tgTypes(dtypes):
+    # Get accumulate data-types from the test generator's defined types
     assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
-    return dtypes[-1]
+    input_dtype = dtypes[0]
+    output_dtype = dtypes[-1]
+    # by default, accum_dtypes contains only output_dtype
+    accum_dtypes = [output_dtype]
+    if input_dtype == DType.FP16 and output_dtype == DType.FP16:
+        accum_dtypes = [DType.FP16, DType.FP32]
+    elif output_dtype == DType.BF16:
+        accum_dtypes = [DType.FP32]
+    return accum_dtypes
 
 
 def get_wrong_output_type(op_name, rng, input_dtype):