Fix Transpose WrongRank test and add new test for Concat

 * Transpose WrongRank tests now use ranks 7, 8
 * Concat ERROR_IF checks now test for inaccurate summation
of output shape tensor dimension

Change-Id: If32f43a4dbd872d0ef7625fa3d4969c863a11b8c
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Signed-off-by: Les Bell <les.bell@arm.com>
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index f0e752f..c3a9068 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -56,6 +56,7 @@
     MaxSmallerMin = "MaxSmallerMin"
     ConcatInputRankMismatch = "ConcatInputRankMismatch"
     ConcatInputDimMismatch = "ConcatInputDimMismatch"
+    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
     CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
     CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
     CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 4e944ea..80ccff3 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1629,6 +1629,8 @@
         # Set minimum incorrect rank to 3 to avoid index error
         if op['op'] in [Op.RESIZE]:
             incorrect_ranks = [3, 5]
+        if op['op'] in [Op.TRANSPOSE]:
+            incorrect_ranks = [7, 8]
 
         error_name = ErrorIf.WrongRank
         param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -2714,6 +2716,44 @@
         return info_dict
 
     @staticmethod
+    def evConcatShapeSumMismatch(check=False, **kwargs):
+        error_name = ErrorIf.ConcatShapeSumMismatch
+        param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Sum of dimensions on axis not equal to output dimension"
+
+        if check:
+            inputs = kwargs['inputs']
+            input_shape = kwargs['input_shape']
+            output_shape = kwargs['output_shape']
+            axis = kwargs['axis']
+
+            # Ensure rank is valid before checking dims.
+            valid_params = True
+            for input in inputs:
+                if len(input.shape) != len(input_shape):
+                    valid_params = False
+            if axis < 0 or axis > len(input_shape):
+                valid_params = False
+
+            if valid_params:
+                axis_dim_sum = 0
+                for input in inputs:
+                    axis_dim_sum += input.shape[axis]
+
+                if axis_dim_sum != output_shape[axis]:
+                    error_result = True
+
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
     def evInputListThenGraphMismatch(check=False, **kwargs):
         error_name = ErrorIf.CondIfInputListThenGraphMismatch
         param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -5647,7 +5687,8 @@
             "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
             "types": TYPE_FIB,
             "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
-            TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
+            TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
+            TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
         },
         "pad": {
             "op": Op.PAD,
@@ -6173,12 +6214,13 @@
             error_name == ErrorIf.ConcatInputRankMismatch
             # unable to concat tensors along an invalid axis
             or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
-            # unable to concat tensors of different dimensions
-            or error_name == ErrorIf.ConcatInputDimMismatch
         ):
             for tensor in remaining_inputs:
                 output_shape[axis] += tensor.shape[axis]
 
+        if error_name == ErrorIf.ConcatShapeSumMismatch:
+            output_shape[axis] += rng.integers(5, 10)
+
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
             wrong_dtypes = list(all_dtypes - set([input1.dtype]))