Add Broadcast DimensionMismatch errors

Add RankMismatch and DimensionMismatch support for SELECT
Update RankMismatch ops to also support DimensionMismatch
Update POW op to have proper broadcast testing
A few other broadcastable ops missing Rank/Dimension testing

Change-Id: I6566f45a7a0db4f9f008456ea7a8e23d4192f4f9
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index c3a9068..eb67ea8 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -30,6 +30,7 @@
     BatchMismatch = "BatchMismatch"
     ChannelMismatch = "ChannelMismatch"
     RankMismatch = "RankMismatch"
+    DimensionMismatch = "DimensionMismatch"
     InputZeroPointNotZero = "InputZeroPointNotZero"
     WeightZeroPointNotZero = "WeightZeroPointNotZero"
     OutputZeroPointNotZero = "OutputZeroPointNotZero"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 80ccff3..db44328 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -270,21 +270,26 @@
         shape_list = []
 
         # Choose one of the inputs to broadcast
-        bcast_idx = testGen.randInt(0, pl + const)
+        # Note: Simplifies OutputShaper code if we don't change first shape for errors
+        bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const)
         for i in range(pl + const):
             shape_bcast = shape.copy()
 
-            if error_name == ErrorIf.RankMismatch:
-                bcast_idx = -1 # Turn off broadcast because we are not testing it
-                if rank == 1 and i != 1:
-                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
-                elif i != 1:
-                    shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
-
             # If the chosen input, pick a random index to broadcast
             if i == bcast_idx:
                 fuzz_idx = testGen.randInt(0, rank)
-                shape_bcast[fuzz_idx] = 1
+                if error_name == ErrorIf.DimensionMismatch:
+                    shape_bcast[fuzz_idx] += 1
+                elif error_name == ErrorIf.RankMismatch:
+                    # Add one rank to the shape (or more for rank of 1)
+                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
+                    shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks)))
+                    if rank != 1:
+                        # Either keep the extra rank, or remove it
+                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
+                        shape_bcast = shape_bcast[:new_len]
+                else:
+                    shape_bcast[fuzz_idx] = 1
 
             shape_list.append(shape_bcast)
 
@@ -2001,8 +2006,14 @@
         if check:
             input1_shape = kwargs['input1'].shape
             input2_shape = kwargs['input2'].shape
+            # In case of SELECT op
+            input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
             output_shape = kwargs['result_tensor'].shape
-            if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
+            if (
+                (len(input1_shape) != len(output_shape)) or
+                (len(input2_shape) != len(output_shape)) or
+                (len(input3_shape) != len(output_shape))
+                ):
                 error_result = True
 
         info_dict = {
@@ -2014,6 +2025,35 @@
         return info_dict
 
     @staticmethod
+    def evDimensionMismatch(check=False, **kwargs):
+        error_name = ErrorIf.DimensionMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input Dimensions do not match output"
+
+        if check:
+            input1_shape = kwargs['input1'].shape
+            input2_shape = kwargs['input2'].shape
+            # In case of SELECT op
+            input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
+            output_shape = kwargs['result_tensor'].shape
+            for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))):
+                if (
+                    (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or
+                    (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or
+                    (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+                    ):
+                    error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
     def evInputZeroPointNotZero(check=False, **kwargs):
         op = kwargs['op']
         inputDtypes = op['types'].copy()
@@ -3492,6 +3532,9 @@
             validator_fcns,
             error_name,
             op=op,
+            input1 = cond,
+            input2 = a,
+            input3 = b,
             input_shape = a.shape,
             input_dtype = a.dtype,
             output_dtype = result_tens.dtype,
@@ -3519,6 +3562,8 @@
             validator_fcns,
             error_name,
             op=op,
+            input1 = a,
+            input2 = b,
             input_shape = a.shape,
             input_dtype = a.dtype,
             output_shape = result_tens.shape,
@@ -5019,7 +5064,7 @@
             )
 
             tens.extend(placeholders)
-        elif op["op"] == Op.MUL:
+        elif op["op"] == Op.MUL and error_name == None:
             assert (
                 pCount == 2 and cCount == 0
             ), "Op.MUL must have 2 placeholders, 0 consts"
@@ -5363,7 +5408,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "arithmetic_right_shift": {
             "op": Op.ARITHMETIC_RIGHT_SHIFT,
@@ -5374,8 +5419,8 @@
                 TosaArgGen.agArithmeticRightShift,
             ),
             "types": TYPE_INT,
-            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
-            TosaErrorValidator.evWrongOutputList)
+            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "bitwise_and": {
             "op": Op.BITWISE_AND,
@@ -5383,7 +5428,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_INT,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "bitwise_or": {
             "op": Op.BITWISE_OR,
@@ -5391,7 +5436,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_INT,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "bitwise_xor": {
             "op": Op.BITWISE_XOR,
@@ -5399,7 +5444,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_INT,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "intdiv": {
             "op": Op.INTDIV,
@@ -5407,7 +5452,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": [DType.INT32],
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "logical_and": {
             "op": Op.LOGICAL_AND,
@@ -5415,7 +5460,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_BOOL,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "logical_left_shift": {
             "op": Op.LOGICAL_LEFT_SHIFT,
@@ -5423,7 +5468,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_INT,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "logical_right_shift": {
             "op": Op.LOGICAL_RIGHT_SHIFT,
@@ -5431,7 +5476,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_INT,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "logical_or": {
             "op": Op.LOGICAL_OR,
@@ -5439,7 +5484,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_BOOL,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "logical_xor": {
             "op": Op.LOGICAL_XOR,
@@ -5447,7 +5492,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_BOOL,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "maximum": {
             "op": Op.MAXIMUM,
@@ -5455,7 +5500,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "minimum": {
             "op": Op.MINIMUM,
@@ -5463,7 +5508,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "mul": {
             "op": Op.MUL,
@@ -5471,15 +5516,15 @@
             "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
             "types": TYPE_INT_FP,
             "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
-            TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
         },
         "pow": {
             "op": Op.POW,
             "operands": (2, 0),
-            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
+            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FP,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "sub": {
             "op": Op.SUB,
@@ -5487,7 +5532,7 @@
             "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
             "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "table": {
             "op": Op.TABLE,
@@ -5597,8 +5642,8 @@
             "operands": (3, 0),
             "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FIB,
-            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         # Comparison operators
         "equal": {
@@ -5606,24 +5651,24 @@
             "operands": (2, 0),
             "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
-            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "greater_equal": {
             "op": Op.GREATER_EQUAL,
             "operands": (2, 0),
             "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
-            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         "greater": {
             "op": Op.GREATER,
             "operands": (2, 0),
             "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
             "types": TYPE_FI32,
-            "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
-            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+            "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+            TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
         },
         # Reduction operators
         "reduce_all": {
@@ -5916,12 +5961,16 @@
 
     @staticmethod
     def selectOp(ser, rng, cond, a, b, error_name=None):
-        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
+        if error_name != ErrorIf.RankMismatch:
+            assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
         assert a.dtype == b.dtype
 
         shape = []
-        for i in range(len(a.shape)):
-            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
+        for i in range(len(cond.shape)):
+            if cond.shape[i] == 1 and error_name == None:
+                shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
+            else:
+                shape.append(cond.shape[i])
 
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
@@ -5934,7 +5983,8 @@
 
     @staticmethod
     def binaryComparisonOp(ser, rng, a, b , error_name=None):
-        assert len(a.shape) == len(b.shape)
+        if error_name != ErrorIf.RankMismatch:
+            assert len(a.shape) == len(b.shape)
         assert a.dtype == b.dtype
 
         # Do broadcast