Remove invalid tests from test generator

 * Implemented TosaInvalidValidator to remove existing invalid tests.
 * Removed invalid tests for resize, rescale, conv2d, depthwise_conv2d,
transpose_conv2d, avg_pool2d, and max_pool2d (note: avg_pool2d/max_pool2d
never produced such tests with the default generator settings, but
theoretically could).
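 * Changed the behaviour of computeMultiplierAndShift so that it only
produces shift values in the allowed range (2 to 62), replacing the
rescale UNPREDICTABLE shift check.
 * Reordered the pooling arguments to [stride, pad, kernel].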

Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Change-Id: I5e7b11030deb5322e2ca08fd4f4467fb02b7740d
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 777c059..760ed06 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -123,10 +123,21 @@
             shift = shift + 1
 
         shift = (-shift) + scaleBits
         # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))
+
+        # Adjust multiplier such that shift is in the allowed range [2, 62].
+        if shift == 0:
+            multiplier = multiplier // 4  # effective scale becomes 1/16 of requested
+            shift = shift + 2
+        elif shift == 1:
+            multiplier = multiplier // 2  # effective scale becomes 1/4 of requested
+            shift = shift + 1
+        elif shift == 63:
+            multiplier = multiplier // 2  # keeps multiplier / 2**shift (minus low bit)
+            shift = shift - 1
 
         assert multiplier <= (1 << scaleBits)
-        assert shift >= 0 and shift <= 63
+        assert shift >= 2 and shift <= 62
 
         return multiplier, shift
 
@@ -566,7 +577,7 @@
                             "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                 s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                             ),
-                            [k, s, p],
+                            [s, p, k],
                         )
                     )
         return arg_list
@@ -946,6 +957,100 @@
 
         return arg_list
 
+
+class TosaInvalidValidator:
+    """Predicates that flag generated tests which should be dropped.
+
+    Each validator is called with the opName, input_dtype, shapeList and
+    args keyword arguments of a generated test and returns True when that
+    test would fail for a reason not covered by an ERROR_IF statement.
+    """
+
+    @staticmethod
+    def ivWrongDataTypeOrModeResize(**kwargs):
+        """Return True for an unsupported resize mode or dtype combination."""
+        input_dtype = kwargs["input_dtype"]
+        args = kwargs["args"]
+        mode = args[0]
+        output_dtype = args[8]
+
+        if mode == ResizeMode.BILINEAR:
+            # Valid pairs: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.
+            return not (
+                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
+                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
+                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
+            )
+        elif mode == ResizeMode.NEAREST:
+            # Output dtype must equal the input dtype (INT8, INT16 or FLOAT).
+            return (
+                input_dtype != output_dtype
+                or input_dtype not in (DType.INT8, DType.INT16, DType.FLOAT)
+            )
+        else:
+            # Invalid resize mode
+            return True
+
+    @staticmethod
+    def ivBadStride(**kwargs):
+        """Return True when the resize stride is zero or negative.
+
+        FLOAT inputs are checked on the floating-point stride (args[4]);
+        all other inputs on the integer stride (args[1]).
+        """
+        input_dtype = kwargs["input_dtype"]
+        args = kwargs["args"]
+        stride = args[4] if input_dtype == DType.FLOAT else args[1]
+        # Negative or zero stride
+        return stride[0] <= 0 or stride[1] <= 0
+
+    @staticmethod
+    def ivHeightWidthSmallerZero(**kwargs):
+        """Return True when the computed output height or width is <= 0.
+
+        Handles *pool2d ops (args = [stride, pad, kernel]) and conv2d /
+        depthwise_conv2d ops (args[0:3] = stride, pad, dilation); any other
+        op raises AssertionError.
+        """
+        opName = kwargs["opName"]
+        shapeList = kwargs["shapeList"]
+        ifm_shape = shapeList[0]
+        args = kwargs["args"]
+        strides = args[0]
+        padding = args[1]
+
+        if opName.endswith("pool2d"):
+            kernel = args[2]
+            h = (ifm_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
+            w = (ifm_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
+            return h <= 0 or w <= 0
+
+        dilations = args[2]
+        if opName.startswith("conv2d"):
+            # conv2d filter dims 1 and 2 are the kernel height and width.
+            kh, kw = shapeList[1][1], shapeList[1][2]
+        elif opName.startswith("depthwise_conv2d"):
+            # depthwise_conv2d filter dims 0 and 1 are the kernel height and width.
+            kh, kw = shapeList[1][0], shapeList[1][1]
+        else:
+            raise AssertionError("Unrecognized Op")
+
+        h = (
+            ifm_shape[1] - kh - (kh - 1) * (dilations[0] - 1) + padding[0] + padding[1]
+        ) // strides[0] + 1
+        w = (
+            ifm_shape[2] - kw - (kw - 1) * (dilations[1] - 1) + padding[2] + padding[3]
+        ) // strides[1] + 1
+        # Invalid parameter combination
+        return h <= 0 or w <= 0
+
+    @staticmethod
+    def ivNonPositiveOutputShape(**kwargs):
+        """Return True when the output shape in args[3] has H or W <= 0."""
+        output_shape = kwargs["args"][3]
+        # Negative output shape
+        return output_shape[1] <= 0 or output_shape[2] <= 0
+
 
 class TosaTestGen:
     # Maximum rank of tensor supported by test generator.
@@ -1204,7 +1309,7 @@
         self.ser.addOperator(op, [a.name], [result_tens.name], attr)
         return result_tens
 
-    def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
+    def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
         result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
 
         attr = ts.TosaSerializerAttribute()
@@ -1538,7 +1643,7 @@
 
         if scale32:
             pass
-            # Cap the scaling at 2^15 - 1 for scale16
+            # Cap the scaling at 2^31 - 1 for scale32
             scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
         else:
             # Cap the scaling at 2^15 - 1 for scale16
@@ -1553,10 +1658,6 @@
             multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                 scale_arr[i], scale32
             )
-            if shift_arr[i] < 2 or shift_arr[i] > 62:
-                self.ser.setExpectedReturnCode(
-                    TosaReturnCode.UNPREDICTABLE, "OpRescale: invalid shift value"
-                )
 
         # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
 
@@ -1780,6 +1881,19 @@
 
                             testList.append((opName, testStr, t, shapeList, args))
 
+        # Remove tests which are expected to fail but don't correlate to an
+        # ERROR_IF statement; a test is dropped if any validator flags it.
+        if "invalid_test_validators" in op:
+            validators = op["invalid_test_validators"]
+            testList = [
+                test
+                for test in testList
+                if not any(
+                    fcn(opName=test[0], input_dtype=test[2], shapeList=test[3], args=test[4])
+                    for fcn in validators
+                )
+            ]
+
         # Reset RNG so both positive and negative tests are reproducible
         self.resetRNG()
         # Negative test loop
@@ -2112,6 +2252,7 @@
             "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
             "qgen": TosaQuantGen.qgUnary,
             "types": TYPE_NARROW_INT_FP,
+            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
         },
         # Templated operator.  Filled in by createDynamicOpLists
         "conv2d_TEMPLATE": {
@@ -2121,6 +2262,7 @@
             "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
             "qgen": TosaQuantGen.qgConv,
             "types": TYPE_CONV2D,
+            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
             "template": True,
         },
         # Conv3d TBD
@@ -2137,6 +2279,7 @@
             ),
             "qgen": TosaQuantGen.qgConv,
             "types": TYPE_CONV2D,
+            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
             "template": True,
         },
         "fully_connected": {
@@ -2161,6 +2304,7 @@
             "rank": (4, 4),
             "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
             "types": TYPE_NARROW_INT_FP,
+            "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
         },
         # Templated operator.  Filled in by createDynamicOpLists
         "transpose_conv2d_TEMPLATE": {
@@ -2174,6 +2318,7 @@
             ),
             "qgen": TosaQuantGen.qgConv,
             "types": TYPE_CONV2D,
+            "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
             "template": True,
         },
         # Activation functions
@@ -2529,6 +2674,10 @@
             "rank": (4, 4),
             "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
             "types": [DType.INT8, DType.INT16, DType.FLOAT],
+            "invalid_test_validators": (
+                TosaInvalidValidator.ivWrongDataTypeOrModeResize,
+                TosaInvalidValidator.ivBadStride,
+            ),
         },
         # Type conversion
         "cast": {
@@ -2691,14 +2840,6 @@
             + padding[3]
         ) // strides[1] + 1
 
-        if h <= 0 or w <= 0:
-            # Invalid test parameters?
-            h = 0
-            w = 0
-            ser.setExpectedReturnCode(
-                TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
-            )
-
         ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
 
         if ifm.dtype == DType.INT8:
@@ -2733,14 +2874,6 @@
             + padding[3]
         ) // strides[1] + 1
 
-        if h <= 0 or w <= 0:
-            # Invalid test parameters?
-            h = 0
-            w = 0
-            ser.setExpectedReturnCode(
-                TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
-            )
-
         ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
 
         if ifm.dtype == DType.INT8:
@@ -2760,14 +2893,6 @@
         h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
         w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
 
-        if h <= 0 or w <= 0:
-            # Invalid test parameters?
-            h = 0
-            w = 0
-            ser.setExpectedReturnCode(
-                TosaReturnCode.UNPREDICTABLE, "Invalid combination of pool2d parameters"
-            )
-
         ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
         return ser.addOutput(ofm_shape, ifm.dtype)
 
@@ -2928,62 +3053,6 @@
 
         output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
 
-        if input_dtype == DType.FLOAT:
-            if stride_fp[0] <= 0 or stride_fp[1] <= 0:
-                ser.setExpectedReturnCode(
-                    TosaReturnCode.ERROR, "Negative or zero stride"
-                )
-        else:
-            if stride[0] <= 0 or stride[1] <= 0:
-                ser.setExpectedReturnCode(
-                    TosaReturnCode.ERROR, "Negative or zero stride"
-                )
-
-        if mode == ResizeMode.BILINEAR:
-            if input_dtype == DType.INT8:
-                if output_dtype != DType.INT32:
-                    ser.setExpectedReturnCode(
-                        TosaReturnCode.ERROR, "Invalid output data type"
-                    )
-            elif input_dtype == DType.INT16:
-                if output_dtype != DType.INT48:
-                    ser.setExpectedReturnCode(
-                        TosaReturnCode.ERROR, "Invalid output data type"
-                    )
-            elif input_dtype == DType.FLOAT:
-                if output_dtype != DType.FLOAT:
-                    ser.setExpectedReturnCode(
-                        TosaReturnCode.ERROR, "Invalid output data type"
-                    )
-            else:
-                ser.setExpectedReturnCode(
-                    TosaReturnCode.ERROR, "Invalid input data type"
-                )
-
-        elif mode == ResizeMode.NEAREST:
-            if input_dtype == DType.INT8:
-                if output_dtype != DType.INT8:
-                    ser.setExpectedReturnCode(
-                        TosaReturnCode.ERROR, "Invalid output data type"
-                    )
-            elif input_dtype == DType.INT16:
-                if output_dtype != DType.INT16:
-                    ser.setExpectedReturnCode(
-                        TosaReturnCode.ERROR, "Invalid output data type"
-                    )
-            elif input_dtype == DType.FLOAT:
-                if output_dtype != DType.FLOAT:
-                    ser.setExpectedReturnCode(
-                        TosaReturnCode.ERROR, "Invalid output data type"
-                    )
-            else:
-                ser.setExpectedReturnCode(
-                    TosaReturnCode.ERROR, "Invalid input data type"
-                )
-
-        else:
-            ser.setExpectedReturnCode(TosaReturnCode.ERROR, "Invalid resize mode")
-
         return ser.addOutput(output_dims, output_dtype)
 
     @staticmethod
@@ -3001,9 +3070,4 @@
         else:
             raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
 
-        if output_shape[1] <= 0 or output_shape[2] <= 0:
-            ser.setExpectedReturnCode(
-                TosaReturnCode.UNPREDICTABLE, "Negative output shape"
-            )
-
         return ser.addOutput(output_shape, out_dtype)