Improve Avg_Pool2D ErrorIf Testing

* Add test for invalid accumulator dtype

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I506e2047623372670b82db6e9c0010fa89802851
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 4630f35..33e74b5 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2485,6 +2485,9 @@
             # incorrect input data-type
             accum_dtypes = [DType.INT32]
 
+        if error_name == ErrorIf.WrongAccumulatorType:
+            accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
+
         if not test_level8k:
             if testGen.args.oversize:
                 # add some oversize argument values
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 5fd647a..9a88acb 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -84,6 +84,7 @@
     ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
     ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
     BroadcastShapesMismatch = "BroadcastShapesMismatch"
+    WrongAccumulatorType = "WrongAccumulatorType"
 
 
 class TosaErrorIfArgGen:
@@ -2580,6 +2581,49 @@
         }
         return info_dict
 
+    def evWrongAccumulatorType(check=False, **kwargs):
+        error_name = ErrorIf.WrongAccumulatorType
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "An unsupported accumulator data type was requested"
+
+        if check:
+            op = kwargs["op"]
+            input_dtype = kwargs["input_dtype"]
+            accum_dtype = kwargs["accum_dtype"]
+            if op["op"] == Op.AVG_POOL2D:
+                if (
+                    input_dtype
+                    in (
+                        DType.INT8,
+                        DType.INT16,
+                    )
+                    and accum_dtype != DType.INT32
+                ):
+                    error_result = True
+                elif (
+                    input_dtype
+                    in (
+                        DType.FP32,
+                        DType.BF16,
+                    )
+                    and accum_dtype != DType.FP32
+                ):
+                    error_result = True
+                elif input_dtype == DType.FP16 and accum_dtype not in (
+                    DType.FP16,
+                    DType.FP32,
+                ):
+                    error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
 
 class TosaInvalidValidator:
     @staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index f5eca18..2d471c0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -835,6 +835,7 @@
             input_dtype=input.dtype,
             output_shape=result_tensor.shape,
             output_dtype=result_tensor.dtype,
+            accum_dtype=accum_dtype,
             kernel=kernel,
             stride=stride,
             pad=pad,
@@ -3218,6 +3219,7 @@
                 TosaErrorValidator.evPadLargerEqualKernel,
                 TosaErrorValidator.evPoolingOutputShapeMismatch,
                 TosaErrorValidator.evPoolingOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),