Add a parameter to tensor generation function to disable fuzzing
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iff31b33b818a181371904915d5477a169513aa2e
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 02ab8aa..12fff68 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -22,7 +22,6 @@
from frameworks.write_test_json import write_test_json # noqa: E402
from frameworks.arg_gen import ArgGen # noqa: E402
from frameworks.tensor_gen import TGen # noqa: E402
-from frameworks.tensor_gen import ElemSignedness # noqa: E402
from frameworks.test_builder import TBuilder # noqa: E402
from frameworks.test_gen_utils import ( # noqa: E402
QuantType,
@@ -958,9 +957,16 @@
# Get and seed a random number generator for this test
rng = np.random.default_rng(seed)
+ # For broadcast fuzzing, record the fuzzed index if fuzzing is already done.
+ fuzzed = []
+
# return placeholders=(str: name, np.array: value)
# consts=(str: name, np.array: value)
- placeholders, consts = tensor_gen_fcn(op, curr_shape, dtype, rng)
+ placeholders, consts = (
+ tensor_gen_fcn(op, curr_shape, dtype, rng, fuzzed)
+ if tensor_gen_fcn.__name__ == "tgBFuzz"
+ else tensor_gen_fcn(op, curr_shape, dtype, rng)
+ )
# if test doesn't have any placeholders/consts, terminated
if len(placeholders) == 0 and len(consts) == 0:
@@ -1157,18 +1163,19 @@
if tflite_inference_dtype == tf.int16:
converter.target_spec.supported_ops = [flag]
+ # Generator function for integer quantization of TFLiteConverter
+ # which generates a few hundred input samples with the same order, type, and shape as the inputs,
+ # to calibrate/estimate the range of the floating-point inputs.
+ # For broadcast fuzzing tests, fuzzing needs to be disabled here; otherwise it causes a mismatch in
+ # the tensor shapes of the inputs.
def input_stats():
- ## Rsqrt can only handle positive numbers
- elem_signedness = ElemSignedness.ALL_RANGE
- if op_name == "rsqrt":
- elem_signedness = ElemSignedness.POSITIVE
-
for i in range(0, args.num_samples):
- a = [
- TGen.getRand(shape, tf.float32, rng, elem_signedness)
- for shape in placeholder_shapes
- ]
- yield a
+ placeholders, _ = (
+ tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng, fuzzed)
+ if tensor_gen_fcn.__name__ == "tgBFuzz"
+ else tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng)
+ )
+ yield [s[1] for s in placeholders]
converter.representative_dataset = input_stats
converter.inference_input_type = tflite_inference_dtype