Produce Concat tests with multiple input tensors

 * Concat tests now contain between 2 and 5 input tensors concatenated
together
 * Both placeholder and const tensors can be used as inputs to the
operator
 * Option to add const tensor inputs (this is slow); defaults to the
original behaviour

Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Change-Id: I2a0cc622d31aceab8d24521668d0aae040ba73b1
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 5c25f8e..5138e3f 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -330,6 +330,46 @@
 
         return [a_shape, b_shape]
 
+    @staticmethod
+    def tgConcat(testGen, opName, rank):
+        # Every concat input shares one random shape. On top of the op's
+        # declared (placeholder, const) operand counts, randInt(0, 4)
+        # extra inputs are added so tests use more than two operands.
+        placeholder_count, const_count = opName["operands"]
+        base_shape = testGen.makeShape(rank)
+        extra_count = testGen.randInt(0, 4)
+        total = placeholder_count + const_count + extra_count
+
+        # Independent copies, so later per-input edits (e.g. splitting
+        # along the concat axis) never alias another entry.
+        return [base_shape.copy() for _ in range(total)]
+
+    @staticmethod
+    def tgConcatConstInput(testGen, shapeList, axis):
+        # Split the concat shape along axis so that several (possibly const)
+        # inputs can be used without creating many large tensors.
+        # Work on a copy: the caller's shapeList[0] must not be modified.
+        shape = shapeList[0].copy()
+        if len(shapeList) <= 2 or shape[axis] < len(shapeList):
+            return shapeList
+
+        # Input 0 keeps the full axis length; each middle input takes half
+        # of what remains and the last input takes the remainder, e.g.
+        # axis length 8 with 4 inputs -> 8, 4, 2, 2.
+        # NOTE: with 6+ inputs a short axis can still give a 0-length split.
+        new_shapeList = [shape.copy()]
+        remaining_length = shape[axis]
+        for _ in range(len(shapeList) - 2):
+            split_shape_val = remaining_length // 2
+            remaining_length -= split_shape_val
+            shape[axis] = split_shape_val
+            new_shapeList.append(shape.copy())
+        shape[axis] = remaining_length
+        new_shapeList.append(shape.copy())
+
+        return new_shapeList
+
+
 
 class TosaArgGen:
     """Argument generators create exhaustive or random lists of attributes for operators that take
@@ -1263,13 +1303,23 @@
         self.ser.addOperator(op, [a.name], [result_tens.name])
         return result_tens
 
-    def build_concat(self, op, a, b, axis):
-        result_tens = OutputShaper.concatOp(self.ser, a, b, axis)
+    def build_concat(self, op, *a):
+        # The trailing positional argument is the concat axis; the ones
+        # before it are the input tensors (their count varies per test).
+        assert isinstance(a[-1], int)
+        axis = a[-1]
+        a = a[:-1]
+
+        result_tens = OutputShaper.concatOp(self.ser, axis, *a)
 
         attr = ts.TosaSerializerAttribute()
         attr.AxisAttribute(axis)
 
-        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
+        input_tensor_names = [tensor.name for tensor in a]
+
+        self.ser.addOperator(op, input_tensor_names, [result_tens.name], attr)
+        # Return the output tensor, as the other build_* helpers do.
+        return result_tens
 
     def build_pad(self, op, a, padding, qinfo):
         result_tens = OutputShaper.padOp(self.ser, a, padding)
@@ -1708,19 +1758,22 @@
 
         if isinstance(dtype_or_dtypeList, list):
             dtypeList = dtype_or_dtypeList
+        elif op['op'] == Op.CONCAT:
+            dtypeList = [dtype_or_dtypeList] * len(shapeList)
         else:
             dtypeList = [dtype_or_dtypeList] * (num_operands)
 
-        assert (
-            len(shapeList) == num_operands
-        ), "shapeList length {} must match number of operands {}".format(
-            len(shapeList), num_operands
-        )
-        assert (
-            len(dtypeList) == num_operands
-        ), "dtypeList length {} must match number of operands {}".format(
-            len(dtypeList), num_operands
-        )
+        if op['op'] != Op.CONCAT:
+            assert (
+                len(shapeList) == num_operands
+            ), "shapeList length {} must match number of operands {}".format(
+                len(shapeList), num_operands
+            )
+            assert (
+                len(dtypeList) == num_operands
+            ), "dtypeList length {} must match number of operands {}".format(
+                len(dtypeList), num_operands
+            )
 
         try:
             qgen = op["qgen"]
@@ -1850,6 +1903,18 @@
                 )
 
                 tens.extend(placeholders)
+        elif op["op"] == Op.CONCAT:
+            count = len(shapeList) - self.args.num_const_inputs_concat
+            if count < 1:
+                count = 1
+            if self.args.num_const_inputs_concat == 0:
+                count = len(shapeList)
+
+            shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
+            tens.extend(
+                self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
+            )
+            tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
         else:
             tens.extend(
                 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
@@ -2336,7 +2401,7 @@
         "concat": {
             "op": Op.CONCAT,
             "operands": (2, 0),
-            "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+            "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
             "types": TYPE_FIB,
         },
         "pad": {
@@ -2694,12 +2759,18 @@
         return ser.addOutput(output_shape, out_dtype)
 
     @staticmethod
-    def concatOp(ser, a, b, axis):
+    def concatOp(ser, axis, *a):
+        # Output shape is the first input's shape with the axis dimension
+        # replaced by the sum of every input's length along that axis.
 
-        output_shape = a.shape.copy()
-        output_shape[axis] = a.shape[axis] + b.shape[axis]
+        output_shape = a[0].shape.copy()
 
-        return ser.addOutput(output_shape, a.dtype)
+        output_shape[axis] = sum(tensor.shape[axis] for tensor in a)
+
+        # Inputs are assumed to share one dtype (TODO confirm for every
+        # caller); the output uses the first input's dtype.
+        input_dtype = a[0].dtype
+        return ser.addOutput(output_shape, input_dtype)
 
     @staticmethod
     def padOp(ser, a, padding):