Add FP_SPECIAL data gen mode for FP16 and FP32

Signed-off-by: evacha01 <evan.chandler@arm.com>
Change-Id: I5a9a1c63345bd83ca04bc6c2a99b0ef3612971ee
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 8d6c8d7..5957a33 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -264,6 +264,9 @@
             return [[]] * num_shapes
 
         shape = testGen.makeShape(rng, rank)
+        # Do not broadcast for some tests
+        if error_name is None and rng.randInt(high=100) < 10:
+            return [shape] * num_shapes
         shape_list = []
 
         # Choose any one of the inputs to broadcast
@@ -785,6 +788,10 @@
             "tensors": {},
         }
         dg_tens_meta = tens_data["tensors"]
+
+        fp_special_info = {}
+        fp_special_info["start_idx"] = int(rng.randInt())
+
         for idx, shape in enumerate(shapeList):
 
             tens_meta = {}
@@ -858,6 +865,8 @@
                     rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
                 )
                 tens_meta["full_range_info"] = info
+            elif dg_type == gtu.DataGenType.FP_SPECIAL:
+                tens_meta["fp_special_info"] = fp_special_info
             else:
                 # TODO - other data gen type
                 assert False, "TODO: support other data gen types"
@@ -1862,16 +1871,12 @@
         for dg_type in dataGenTypesList:
             for arg_str, args_dict in arg_list:
                 gen_args_dict = args_dict.copy()
+                # Only create one test by default - no sets of tests
+                num_test_sets = 0
+
                 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                     if error_name is None:
-                        num_test_sets = (
-                            args_dict["num_test_sets"]
-                            if "num_test_sets" in args_dict
-                            else 0
-                        )
-                    else:
-                        # Add single test for pseudo random
-                        num_test_sets = 0
+                        num_test_sets = args_dict.get("num_test_sets", 0)
 
                 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                     # Extra tests for each dot product test set
@@ -1900,13 +1905,23 @@
                             f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}"
                         )
                         continue
-                    # Large enough tensor data size for full range, add a single test
-                    num_test_sets = 0
+                    # Large enough tensor data size for full range, add full test
                     arg_str = f"{arg_str}_full" if arg_str else "full"
                     gen_args_dict["tags"] = args_dict.get("tags", []) + [
                         "non_finite_fp_data"
                     ]
 
+                elif dg_type == gtu.DataGenType.FP_SPECIAL:
+                    shapes_set = {tuple(x) for x in shapeList}
+                    if len(shapes_set) != 1:
+                        logger.info(
+                            f"Changing {opName} input shapes {shapes_set} - broadcasting incompatable with special test"
+                        )
+                        shapeList = [np.int32(np.broadcast_shapes(*shapeList))] * len(
+                            shapeList
+                        )
+                    arg_str = f"{arg_str}_fs" if arg_str else "fs"
+
                 gen_args_dict["dg_type"] = dg_type
                 if num_test_sets > 0:
                     for s in range(0, num_test_sets):
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 38ab3f4..40788a2 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -268,7 +268,7 @@
                 if "ksb" in argsDict
                 else int(argsDict["ks"]),
             }
-        elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
+        elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
             mode = gtu.ComplianceMode.FP_SPECIAL
         elif "compliance" in op and "ulp" in op["compliance"]:
             mode = gtu.ComplianceMode.ULP
@@ -3352,7 +3352,11 @@
         DType.FP32: (gtu.DataGenType.DOT_PRODUCT,),
     }
     EW_UNARY_DATAGEN = {
-        DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE)
+        DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE),
+    }
+    PR_FS_DATAGEN = {
+        DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
+        DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
     }
 
     TOSA_OP_LIST = {
@@ -3716,7 +3720,7 @@
                 TosaErrorValidator.evDimensionMismatch,
                 TosaErrorValidator.evBroadcastShapesMismatch,
             ),
-            "data_gen": PSEUDO_RANDOM_DATAGEN,
+            "data_gen": PR_FS_DATAGEN,
             "compliance": {"ulp": 0.5},
         },
         "arithmetic_right_shift": {
@@ -3938,7 +3942,7 @@
                 TosaErrorValidator.evDimensionMismatch,
                 TosaErrorValidator.evBroadcastShapesMismatch,
             ),
-            "data_gen": PSEUDO_RANDOM_DATAGEN,
+            "data_gen": PR_FS_DATAGEN,
         },
         "minimum": {
             "op": Op.MINIMUM,
@@ -4330,7 +4334,7 @@
                 TosaErrorValidator.evDimensionMismatch,
                 TosaErrorValidator.evBroadcastShapesMismatch,
             ),
-            "data_gen": PSEUDO_RANDOM_DATAGEN,
+            "data_gen": PR_FS_DATAGEN,
         },
         "greater_equal": {
             "op": Op.GREATER_EQUAL,
@@ -4351,7 +4355,7 @@
                 TosaErrorValidator.evDimensionMismatch,
                 TosaErrorValidator.evBroadcastShapesMismatch,
             ),
-            "data_gen": PSEUDO_RANDOM_DATAGEN,
+            "data_gen": PR_FS_DATAGEN,
         },
         "greater": {
             "op": Op.GREATER,
@@ -4372,7 +4376,7 @@
                 TosaErrorValidator.evDimensionMismatch,
                 TosaErrorValidator.evBroadcastShapesMismatch,
             ),
-            "data_gen": PSEUDO_RANDOM_DATAGEN,
+            "data_gen": PR_FS_DATAGEN,
         },
         # Reduction operators
         "reduce_all": {
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index a8e321e..478190d 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -55,7 +55,7 @@
     DOT_PRODUCT = 1
     BOUNDARY = 2
     FULL_RANGE = 3
-    SPECIAL = 4
+    FP_SPECIAL = 4
     FIXED_DATA = 5