Add framework tests for tfl.real and tfl.imag
Change-Id: I665acac9b5171efd0c5a2b68b516609048f6e187
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index f8d50a8..60e17ce 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -56,6 +56,10 @@
return np.uint32(rng.integers(low=0, high=RAND_INT_MAX, size=shape))
if dtype == tf.bool:
return np.bool_(rng.choice(a=[False, True], size=shape))
+ if dtype == tf.complex64:
+ return TGen.getRand(shape, np.float32, rng) + 1j * TGen.getRand(
+ shape, np.float32, rng
+ )
raise Exception("Unsupported type: {}".format(dtype))
@@ -305,5 +309,13 @@
if len(shape) != 3:
return [], []
- tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
- return tf_placeholders, []
+ return TGen.tgBasic(op, shape, dtype, rng)
+
+ @staticmethod
+ def tgComplexComponents(op, shape, dtype, rng):
+ # Temporarily require up to rank 3 shape, due to
+ # slice maximum rank limitiation.
+ if len(shape) > 3:
+ return [], []
+
+ return TGen.tgBasic(op, shape, dtype, rng)
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 6302865..f872888 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1225,3 +1225,17 @@
def eval(self, a):
return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
+
+ class Real:
+ def __init__(self, name):
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.math.real(a, name=self.result_name)
+
+ class Imag:
+ def __init__(self, name):
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.math.imag(a, name=self.result_name)
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
index 2d8e5d6..6a59848 100644
--- a/verif/frameworks/test_gen_utils.py
+++ b/verif/frameworks/test_gen_utils.py
@@ -30,6 +30,8 @@
shape_name = shape_name + "_qi16"
elif dtype == tf.quint16:
shape_name = shape_name + "_qu16"
+ elif dtype == tf.complex64:
+ shape_name = shape_name + "_c64"
else:
raise Exception("Unsupported type: {}".format(dtype))
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index c55864a..71723ae 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -384,7 +384,13 @@
while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
ifm_np = np.expand_dims(ifm_np, axis=0)
- assert list(ifm_np.shape) == test_desc["ifm_shape"][i]
+ # After legalization, complex tensors are expected to be represented
+ # as a single floating point tensor of shape [?, ..., ?, 2].
+ expected_shape = test_desc["ifm_shape"][i]
+ if test.endswith("c64"):
+ expected_shape.append(2)
+
+ assert list(ifm_np.shape) == expected_shape
reference_runner_ifm_name.append(ifm_tensor_name)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 0741686..fffb842 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -841,6 +841,20 @@
"tflite": TYPE_F,
},
},
+ "real": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.Real, TGen.tgComplexComponents, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.complex64],
+ },
+ },
+ "imag": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.Imag, TGen.tgComplexComponents, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.complex64],
+ },
+ },
}
# Shapes to be tested; default can be overwritten
@@ -1154,6 +1168,14 @@
# 1. Saved out numpy array directly
for idx, (name, val) in enumerate(placeholders):
placeholder_vals.append(tf.convert_to_tensor(val))
+
+ # Complex tensors are expected to be repsesented by a
+ # single floating point tensor of shape [?, ..., ?, 2].
+ if val.dtype == np.complex64:
+ val_shape = val.shape + (2,)
+ val = val.view(np.float32)
+ val = val.reshape(val_shape)
+
np.save(
os.path.join(test_dir, placeholder_npy_filenames[idx]), val, False
)