Add negative testing for avg_pool2d, max_pool2d

 * Negative tests for ERROR_IFs given in spec
 * Constrict dimension size of latter ranks if
rank is larger than 4

Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Change-Id: Iffea1874e876dba83c8a7c63049283bf7b3ba74b
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 8710885..2daeb9d 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -35,5 +35,10 @@
     AxisSmallerZero = "AxisSmallerZero"
     AxisLargerRank = "AxisLargerRank"
     ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
+    KernelSmallerOne = "KernelSmallerOne"
+    StrideSmallerOne = "StrideSmallerOne"
+    PadSmallerZero = "PadSmallerZero"
+    PadLargerEqualKernel = "PadLargerEqualKernel"
+    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
 
 
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 43b188d..928ac0e 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -193,6 +193,9 @@
         # Constrict the batch size?
         if testGen.args.max_batch_size:
             shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
+        # Constrict dimension size for large ranks
+        if rank > 4:
+            shape[4] = 1
 
         shape_list = []
         for i in range(pl + const):
@@ -655,7 +658,8 @@
         arg_list = []
 
         shape = shapeList[0]
-        assert len(shape) == 4
+        if error_name != ErrorIf.WrongRank:
+            assert len(shape) == 4
 
         # Generate comprehensive argument lists
         p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
@@ -681,11 +685,31 @@
         for s in sorted(list(strides)):
             for p in sorted(list(paddings)):
                 for k in sorted(list(kernels)):
-                    if (n % sparsity == 0
+                    # Calculate output height to test for error_if conditions
+                    oh = (shape[1] + p[0] + p[1] + s[0] - k[0]) // s[0]
+                    ow = (shape[2] + p[2] + p[3] + s[1] - k[1]) // s[1]
+                    y = (oh * s[0]) - p[0] - p[1] - s[0] + k[0]
+                    x = (ow * s[1]) - p[2] - p[3] - s[1] + k[1]
+
+                    if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
+                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
+                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
+                            arg_list.append(
+                                (
+                                    "st{}_kern{}_pad{}".format(
+                                        "".join([str(x) for x in sNew]),
+                                        "".join([str(x) for x in kNew]),
+                                        "".join([str(x) for x in pNew]),
+                                    ),
+                                    [sNew, pNew, kNew],
+                                )
+                            )
+                    elif (n % sparsity == 0
                         # padding must not exceed the kernel size
                         and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
                         # the padded shape must exceed the kernel size
                         and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
+                        and y < shape[1] and x < shape[2]
                     ):
                         arg_list.append(
                             (
@@ -1181,6 +1205,32 @@
         return shift, stride, stride_fp, offset, offset_fp, outputDType
 
     @staticmethod
+    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
+        if (error_name == ErrorIf.StrideSmallerOne
+            # padding must not exceed the kernel size
+            and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
+            wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
+            return wrongStride, pad, kernel
+        elif error_name == ErrorIf.PadSmallerZero:
+            wrongPad = (testGen.rng.choice([-1, -2, -3]),
+                        testGen.rng.choice([-1, -2, -3]),
+                        testGen.rng.choice([-1, -2, -3]),
+                        testGen.rng.choice([-1, -2, -3]))
+            return stride, wrongPad, kernel
+        elif error_name == ErrorIf.KernelSmallerOne:
+            wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
+            return stride, pad, wrongKernel
+        elif error_name == ErrorIf.PadLargerEqualKernel:
+            wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
+                        testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
+                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
+                        testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
+            return stride, wrongPad, kernel
+        else:
+            return None, None, None
+
+
+    @staticmethod
     def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
         # Mess up input/output tensors for ERROR_IF checks
         if error_name == "WrongInputList":
@@ -1294,8 +1344,10 @@
         rank_range = range(rmin, rmax + 1)
         incorrect_ranks = list(set(all_ranks) - set(rank_range))
         # Set minimum incorrect rank to 3 to avoid index error
-        if op['op'] == Op.RESIZE:
+        if op['op'] in [Op.RESIZE]:
             incorrect_ranks = [3, 5]
+        elif op['op'] in [Op.AVG_POOL2D, Op.MAX_POOL2D]:
+            incorrect_ranks = [5]
 
         error_name = ErrorIf.WrongRank
         param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -1304,7 +1356,7 @@
 
         if check:
             input_shape = kwargs['input_shape']
-            if op['op'] == Op.RESIZE and len(input_shape.shape) != 4:
+            if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
                 error_result = True
 
         info_dict = {
@@ -1370,7 +1422,7 @@
         error_reason = "At least one maximum dimension is larger than 16384"
 
         if check:
-            input_shape = kwargs['input_shape'].shape
+            input_shape = kwargs['input_shape']
             output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
             if ((input_shape[1] > 16384) or
                 (input_shape[2] > 16384) or
@@ -1399,7 +1451,7 @@
         rank_range = range(rmin, rmax + 1)
 
         if check:
-            input_shape = kwargs['input_shape'].shape
+            input_shape = kwargs['input_shape']
             output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
 
             if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
@@ -1426,7 +1478,7 @@
         rank_range = range(rmin, rmax + 1)
 
         if check:
-            input_shape = kwargs['input_shape'].shape
+            input_shape = kwargs['input_shape']
             output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
             if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
                 error_result = True
@@ -1503,7 +1555,7 @@
         error_reason = "Stride value larger than or equal to H/W dimension"
 
         if check:
-            shape = kwargs['input_shape'].shape
+            shape = kwargs['input_shape']
             input_dtype = kwargs['input_dtype']
             stride = kwargs['stride_fp']
 
@@ -1669,10 +1721,17 @@
 
     @staticmethod
     def evInputZeroPointNotZero(check=False, **kwargs):
+        op = kwargs['op']
+        inputDtypes = op['types'].copy()
+        if DType.INT8 in inputDtypes:
+            inputDtypes.remove(DType.INT8)
+        if DType.UINT8 in inputDtypes:
+            inputDtypes.remove(DType.UINT8)
+
         error_name = ErrorIf.InputZeroPointNotZero
         param_reqs = {
             "rank": None,
-            "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
+            "dtype": inputDtypes,
             "shape": None
             }
         error_result = False
@@ -1697,21 +1756,28 @@
 
     @staticmethod
     def evOutputZeroPointNotZero(check=False, **kwargs):
+        op = kwargs['op']
+        inputDtypes = op['types'].copy()
+        if DType.INT8 in inputDtypes:
+            inputDtypes.remove(DType.INT8)
+        if DType.UINT8 in inputDtypes:
+            inputDtypes.remove(DType.UINT8)
+
         error_name = ErrorIf.OutputZeroPointNotZero
         param_reqs = {
             "rank": None,
-            "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
+            "dtype": inputDtypes,
             "shape": None
             }
         error_result = False
         error_reason = "Output DType not INT8 and zero point not 0"
 
         if check:
-            output_dtype = kwargs['output_dtype']
+            input_dtype = kwargs['input_dtype']
             # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
             qinfo = kwargs['qinfo'].ints
             output_zero_point = qinfo[1][1]
-            if output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
+            if input_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
                 error_result = True
 
         info_dict = {
@@ -1787,6 +1853,135 @@
         return info_dict
 
 
+    @staticmethod
+    def evPadSmallerZero(check=False, **kwargs):
+        error_name = ErrorIf.PadSmallerZero
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "At least one pad is smaller than zero"
+
+        if check:
+            pad = kwargs['pad']
+            if min(pad) < 0:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evPadLargerEqualKernel(check=False, **kwargs):
+        error_name = ErrorIf.PadLargerEqualKernel
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "At least one pad is larger than kernel dimension"
+
+        if check:
+            pad = kwargs['pad']
+            kernel = kwargs['kernel']
+            if min(pad) > 0 and min(kernel) > 1:
+                if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
+                    error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
+    def evPoolingOutputShapeMismatch(check=False, **kwargs):
+        error_name = ErrorIf.PoolingOutputShapeMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Mismatch between output shape provided and expected output shape"
+
+        if check:
+            pad = kwargs['pad']
+            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
+
+            kernel = kwargs['kernel']
+            kernel_y, kernel_x = kernel[0], kernel[1]
+
+            input_shape = kwargs['input_shape']
+            IH, IW = input_shape[1], input_shape[2]
+
+            output_shape = kwargs['output_shape']
+            OH, OW = output_shape[1], output_shape[2]
+
+            stride = kwargs['stride']
+            stride_y, stride_x = stride[0], stride[1]
+
+            # calculate correct height, width dimensions
+            if stride_x != 0 and stride_y != 0:
+                y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
+                x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x
+
+            # ensure parameters are valid
+            params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
+                and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))
+
+            if params_valid and (OH != y_correct or OW != x_correct):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evKernelSmallerOne(check=False, **kwargs):
+        error_name = ErrorIf.KernelSmallerOne
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "At least one kernel dimension is smaller than zero"
+
+        if check:
+            kernel = kwargs['kernel']
+            if min(kernel) < 1:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
+    def evStrideSmallerOne(check=False, **kwargs):
+        error_name = ErrorIf.StrideSmallerOne
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "At least one stride dimension is smaller than zero"
+
+        if check:
+            stride = kwargs['stride']
+            if min(stride) < 1:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+
 class TosaInvalidValidator:
 
     @staticmethod
@@ -2225,13 +2420,47 @@
         self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
         return result_tens
 
-    def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
-        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
+    def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
+        result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
+
+        # Ensure new output type has correct qinfo
+        if error_name == ErrorIf.WrongInputType:
+            if input.dtype not in [DType.INT8, DType.UINT8]:
+                qinfo = ts.TosaSerializerQuantInfo()
+                qinfo.UnaryQuantInfo(
+                TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+                )
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [input.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input_shape=input.shape,
+            input_dtype=input.dtype,
+            output_shape=result_tens.shape,
+            output_dtype=result_tens.dtype,
+            kernel=kernel,
+            stride=stride,
+            pad=pad,
+            qinfo = qinfo,
+            result_tensor = result_tens,
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        )
 
         attr = ts.TosaSerializerAttribute()
         attr.PoolAttribute(kernel, stride, pad)
 
-        self.ser.addOperator(op['op'], [input.name], [result_tens.name], attr, qinfo)
+        self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
         return result_tens
 
     def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
@@ -2541,7 +2770,7 @@
             shift=shift,
             input_dtype=input_dtype,
             output_dtype=output_dtype,
-            input_shape=input,
+            input_shape=input.shape,
             output_shape=output_dims,
             offset=offset,
             offset_fp=offset_fp,
@@ -2796,6 +3025,7 @@
             cleanRankFilter = range(rmin, rmax + 1)
 
         dtypes = op["types"]
+
         if dtypeFilter is not None:
             cleanDtypeFilter = []
             # Create list of operator dtypes filtered by requested dtypes
@@ -3342,7 +3572,11 @@
             "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
             "qgen": TosaQuantGen.qgUnary,
             "types": TYPE_NARROW_INT_FP,
-            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
+            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
+            "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
+            TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+            TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
+            TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
         },
         # Templated operator.  Filled in by createDynamicOpLists
         "conv2d_TEMPLATE": {
@@ -3403,7 +3637,10 @@
             "rank": (4, 4),
             "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
             "types": TYPE_NARROW_INT_FP,
-            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
+            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
+            "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
+            TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+            TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
         },
         # Templated operator.  Filled in by createDynamicOpLists
         "transpose_conv2d_TEMPLATE": {
@@ -4119,13 +4356,31 @@
         return ser.addOutput(ofm_shape, out_dtype)
 
     @staticmethod
-    def pool2dOp(ser, ifm, kernel, stride, pad):
+    def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
         # input: NHWC
-        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
-        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
+        if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
+            # If an incorrect stride is used set dimensions to 0, test is invalid anyway.
+            h = 1
+            w = 1
+        else:
+            h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
+            w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
+
+        if error_name == ErrorIf.PoolingOutputShapeMismatch:
+            choices = [1, 2, 3, 4, 5]
+            h = h + rng.choice(choices)
+            w = w + rng.choice(choices)
 
         ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
-        return ser.addOutput(ofm_shape, ifm.dtype)
+
+        if error_name == ErrorIf.WrongOutputType:
+            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+            wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
+            outputDType = rng.choice(wrong_dtypes)
+        else:
+            outputDType = ifm.dtype
+
+        return ser.addOutput(ofm_shape, outputDType)
 
     @staticmethod
     def fullyConnectedOp(ser, input, filter):