Add 0-rank tensor support for concat in framework test

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iff77091e4a57f487431ffbf7ac1c89301a153c8b
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index a25c205..c385274 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -45,6 +45,10 @@
     @staticmethod
     def agAxes(op, shapes, rng):
         axes = []
+        if shapes == ():
+            axes.append(["_axis_0", [0]])
+            return axes
+
         for i in range(-len(shapes), len(shapes), 1):
             if i >= 0:
                 axes.append(["_axis_{}".format(i), [i]])
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index d0c0a0b..4370215 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -41,8 +41,12 @@
             RAND_SHIFT_FACTOR = 0.5
 
         if dtype == tf.float32:
-            return np.float32(
-                (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+            return (
+                np.float32(
+                    (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+                )
+                if shape != ()
+                else np.float32(rng.random())
             )
         if dtype == tf.float16:
             return np.float16(
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 3554e40..7b20cef 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -136,7 +136,11 @@
             self.result_name = name
 
         def eval(self, a, b):
-            return tf.concat([a, b], self.axis, name=self.result_name)
+            return (
+                tf.concat([a, b], self.axis, name=self.result_name)
+                if a.shape != ()
+                else tf.stack([a, b], name=self.result_name)
+            )
 
     class BitwiseAnd:
         def __init__(self, name):
@@ -767,7 +771,11 @@
             self.result_name = name
 
         def eval(self, a, b, c, d):
-            return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+            return (
+                tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+                if a.shape != ()
+                else tf.stack([a, b, c, d], name=self.result_name)
+            )
 
     class Stack:
         def __init__(self, axis, name):
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
index 6a59848..f31ac63 100644
--- a/verif/frameworks/test_gen_utils.py
+++ b/verif/frameworks/test_gen_utils.py
@@ -9,6 +9,9 @@
 # Get a string name for a given shape
 def get_shape_str(shape, dtype):
     shape_name = None
+    if len(shape) == 0:
+        shape_name = "0"
+
     for dim in shape:
         shape_name = (shape_name + "x" + str(dim)) if shape_name else str(dim)
 
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ffe373b..9d666ab 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -174,6 +174,11 @@
         "operands": (2, 0),
         "build_fcn": (TBuilder.Concat, TGen.tgBasic, ArgGen.agAxes),
         "types": TYPE_FI,
+        "rank": (0, 4),
+        "custom_shapes": {
+            "custom_shape_only": False,
+            "shape_list": [()],
+        },
     },
     "bitwise_and": {
         "operands": (2, 0),
@@ -635,6 +640,11 @@
         "operands": (4, 0),
         "build_fcn": (TBuilder.Concatv2, TGen.tgBasic, ArgGen.agAxes),
         "types": TYPE_FI,
+        "rank": (0, 4),
+        "custom_shapes": {
+            "custom_shape_only": False,
+            "shape_list": [()],
+        },
     },
     "stack": {
         "operands": (4, 0),
@@ -1473,7 +1483,7 @@
             shape_list = custom_shapes["shape_list"]
         else:
             shape_list = shape_list.copy()
-            shape_list.append(custom_shapes["shape_list"])
+            shape_list.extend(custom_shapes["shape_list"])
     except KeyError:
         pass