Add RESIZE ERROR_IF checks for mismatched batch/channel sizes

Change-Id: I7c670c5f9b97a18a6f586b16f31bc9fc301f6bc3
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 2c13172..3cd1d69 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1347,6 +1347,88 @@
         return info_dict

     @staticmethod
+    def evBatchMismatch(check=False, **kwargs):
+        """ERROR_IF validator: the input batch size must equal the output batch size.
+
+        Batch is dimension 0 of the (N, H, W, C) shapes. kwargs must always
+        contain 'op' (its 'rank' range selects which input ranks are checked);
+        when check is True it must also contain 'input_shape' and
+        'result_tensor', both objects whose .shape is read.
+
+        Returns the info dict (error_name, error_result, error_reason, param_reqs).
+        """
+        error_name = ErrorIf.BatchMismatch
+        param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input batch size not equal to output batch size"
+
+        assert 'op' in kwargs
+        op = kwargs['op']
+        rmin, rmax = op['rank']
+        rank_range = range(rmin, rmax + 1)
+
+        if check:
+            input_shape = kwargs['input_shape'].shape
+            output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
+
+            # Require equal ranks too, so a wrong-rank output cannot raise IndexError.
+            if (
+                len(input_shape) in rank_range
+                and len(output_shape) == len(input_shape)
+                and input_shape[0] != output_shape[0]
+            ):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
+    def evChannelMismatch(check=False, **kwargs):
+        """ERROR_IF validator: the input channel size must equal the output channel size.
+
+        Channel is dimension 3 of the (N, H, W, C) shapes. kwargs must always
+        contain 'op' (its 'rank' range selects which input ranks are checked);
+        when check is True it must also contain 'input_shape' and
+        'result_tensor', both objects whose .shape is read.
+
+        Returns the info dict (error_name, error_result, error_reason, param_reqs).
+        """
+        error_name = ErrorIf.ChannelMismatch
+        param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input channel size not equal to output channel size"
+
+        assert 'op' in kwargs
+        op = kwargs['op']
+        rmin, rmax = op['rank']
+        rank_range = range(rmin, rmax + 1)
+
+        if check:
+            input_shape = kwargs['input_shape'].shape
+            output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
+
+            # Require equal ranks too, so a wrong-rank output cannot raise IndexError.
+            if (
+                len(input_shape) in rank_range
+                and len(output_shape) == len(input_shape)
+                and input_shape[3] != output_shape[3]
+            ):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
     def evStrideSmallerEqualZero(check=False, **kwargs):
         error_name = ErrorIf.StrideSmallerEqualZero
         param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2195,6 +2248,7 @@
     ):
         result_tens = OutputShaper.resizeOp(
             self.ser,
+            self.rng,  # resizeOp draws from it to offset batch/channel for the mismatch ERROR_IF tests
             input,
             mode,
             stride,
@@ -2232,6 +2286,7 @@
             stride_fp=stride_fp,
             input_list=input_list,
             output_list=output_list,
+            result_tensor=result_tens,  # evBatchMismatch/evChannelMismatch compare against result_tensor.shape
             num_operands=num_operands,
         )

@@ -3457,7 +3512,8 @@
             "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
             TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
             TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
-            TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
+            TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)  # both read kwargs['result_tensor'] when checking
         },
         # Type conversion
         "cast": {
@@ -3862,7 +3918,8 @@

     @staticmethod
     def resizeOp(
         ser,
+        rng,
         input,
         mode,
         stride,
@@ -3878,9 +3935,18 @@
         if error_name == ErrorIf.WrongRank:
             output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
         else:
-            output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
+            # Start from the input's batch (dim 0) and channels (dim 3); a
+            # BatchMismatch/ChannelMismatch ERROR_IF test adds rng.integers(1, 10)
+            # to exactly one of them so the matching validator sees a mismatch.
+            batch = input.shape[0]
+            channels = input.shape[3]
+            if error_name == ErrorIf.BatchMismatch:
+                batch += rng.integers(1, 10)
+            elif error_name == ErrorIf.ChannelMismatch:
+                channels += rng.integers(1, 10)
+            output_dims = [batch, output_dims[0], output_dims[1], channels]

         return ser.addOutput(output_dims, output_dtype)

     @staticmethod
     def typeConversionOp(ser, val, out_dtype):