Add a parameter to tensor generation function to disable fuzzing

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iff31b33b818a181371904915d5477a169513aa2e
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 170e5d8..d50bc74 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -91,7 +91,7 @@
         return tf_placeholders, tf_consts
 
     @staticmethod
-    def tgBFuzz(op, shape, dtype, rng):
+    def tgBFuzz(op, shape, dtype, rng, fuzzed=[]):
         # Build random tensor placeholder node args of a given shape, optionally
         # fuzzing the arguments with random 1's to force broadcasting
 
@@ -105,12 +105,14 @@
         tf_placeholders = []
         tf_consts = []
         for i in range(pl):
-            if i == fuzz_arg:
+            if not fuzzed and i == fuzz_arg:
                 # Insert the broadcast in one dimension index
                 s_fuzz = list(shape)
                 s_fuzz[fuzz_idx] = 1
                 s_fuzz = tuple(s_fuzz)
                 i_shape = s_fuzz
+                # Record the fuzzed index.
+                fuzzed.append(i)
             else:
                 i_shape = shape
             tf_placeholders.append(
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 02ab8aa..12fff68 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -22,7 +22,6 @@
 from frameworks.write_test_json import write_test_json  # noqa: E402
 from frameworks.arg_gen import ArgGen  # noqa: E402
 from frameworks.tensor_gen import TGen  # noqa: E402
-from frameworks.tensor_gen import ElemSignedness  # noqa: E402
 from frameworks.test_builder import TBuilder  # noqa: E402
 from frameworks.test_gen_utils import (  # noqa: E402
     QuantType,
@@ -958,9 +957,16 @@
         # Get and seed a random number generator for this test
         rng = np.random.default_rng(seed)
 
+        # For broadcast fuzzing: tgBFuzz appends the fuzzed input index to this
+        fuzzed = []
+
         # return placeholders=(str: name, np.array: value)
         # consts=(str: name, np.array: value)
-        placeholders, consts = tensor_gen_fcn(op, curr_shape, dtype, rng)
+        placeholders, consts = (
+            tensor_gen_fcn(op, curr_shape, dtype, rng, fuzzed)
+            if tensor_gen_fcn.__name__ == "tgBFuzz"
+            else tensor_gen_fcn(op, curr_shape, dtype, rng)
+        )
 
         # if test doesn't have any placeholders/consts, terminated
         if len(placeholders) == 0 and len(consts) == 0:
@@ -1157,18 +1163,19 @@
             if tflite_inference_dtype == tf.int16:
                 converter.target_spec.supported_ops = [flag]
 
+            # Generator function for integer quantization of TFLiteConverter
+            # which generates a few hundred input samples with the same order, type, and shape as the inputs,
+            # to calibrate/estimate the range of the floating-point inputs.
+            # For broadcast fuzzing tests, fuzzing needs to be disabled, otherwise, it causes a mismatch of
+            # tensor shapes of inputs.
             def input_stats():
-                ## Rsqrt can only handle positive numbers
-                elem_signedness = ElemSignedness.ALL_RANGE
-                if op_name == "rsqrt":
-                    elem_signedness = ElemSignedness.POSITIVE
-
                 for i in range(0, args.num_samples):
-                    a = [
-                        TGen.getRand(shape, tf.float32, rng, elem_signedness)
-                        for shape in placeholder_shapes
-                    ]
-                    yield a
+                    placeholders, _ = (
+                        tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng, fuzzed)
+                        if tensor_gen_fcn.__name__ == "tgBFuzz"
+                        else tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng)
+                    )
+                    yield [s[1] for s in placeholders]
 
             converter.representative_dataset = input_stats
             converter.inference_input_type = tflite_inference_dtype