Add a `fuzzed` parameter to tgBFuzz to disable or record broadcast fuzzing
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iff31b33b818a181371904915d5477a169513aa2e
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 170e5d8..d50bc74 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -91,7 +91,7 @@
return tf_placeholders, tf_consts
@staticmethod
- def tgBFuzz(op, shape, dtype, rng):
+ def tgBFuzz(op, shape, dtype, rng, fuzzed=[]):
# Build random tensor placeholder node args of a given shape, optionally
# fuzzing the arguments with random 1's to force broadcasting
@@ -105,12 +105,14 @@
tf_placeholders = []
tf_consts = []
for i in range(pl):
- if i == fuzz_arg:
+ if not fuzzed and i == fuzz_arg:
# Insert the broadcast in one dimension index
s_fuzz = list(shape)
s_fuzz[fuzz_idx] = 1
s_fuzz = tuple(s_fuzz)
i_shape = s_fuzz
+ # Record the fuzzed index.
+ fuzzed.append(i)
else:
i_shape = shape
tf_placeholders.append(