Add framework tests for tfl.real and tfl.imag

Change-Id: I665acac9b5171efd0c5a2b68b516609048f6e187
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index f8d50a8..60e17ce 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -56,6 +56,10 @@
             return np.uint32(rng.integers(low=0, high=RAND_INT_MAX, size=shape))
         if dtype == tf.bool:
             return np.bool_(rng.choice(a=[False, True], size=shape))
+        if dtype == tf.complex64:
+            return TGen.getRand(shape, np.float32, rng) + 1j * TGen.getRand(
+                shape, np.float32, rng
+            )
 
         raise Exception("Unsupported type: {}".format(dtype))
 
@@ -305,5 +309,13 @@
         if len(shape) != 3:
             return [], []
 
-        tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
-        return tf_placeholders, []
+        return TGen.tgBasic(op, shape, dtype, rng)
+
+    @staticmethod
+    def tgComplexComponents(op, shape, dtype, rng):
+        # Temporarily require up to rank 3 shape, due to
+        # slice maximum rank limitiation.
+        if len(shape) > 3:
+            return [], []
+
+        return TGen.tgBasic(op, shape, dtype, rng)
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 6302865..f872888 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1225,3 +1225,17 @@
 
         def eval(self, a):
             return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
+
+    class Real:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.real(a, name=self.result_name)
+
+    class Imag:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.imag(a, name=self.result_name)
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
index 2d8e5d6..6a59848 100644
--- a/verif/frameworks/test_gen_utils.py
+++ b/verif/frameworks/test_gen_utils.py
@@ -30,6 +30,8 @@
         shape_name = shape_name + "_qi16"
     elif dtype == tf.quint16:
         shape_name = shape_name + "_qu16"
+    elif dtype == tf.complex64:
+        shape_name = shape_name + "_c64"
     else:
         raise Exception("Unsupported type: {}".format(dtype))
 
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index c55864a..71723ae 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -384,7 +384,13 @@
             while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
                 ifm_np = np.expand_dims(ifm_np, axis=0)
 
-            assert list(ifm_np.shape) == test_desc["ifm_shape"][i]
+            # After legalization, complex tensors are expected to be represented
+            # as a single floating point tensor of shape [?, ..., ?, 2].
+            expected_shape = test_desc["ifm_shape"][i]
+            if test.endswith("c64"):
+                expected_shape.append(2)
+
+            assert list(ifm_np.shape) == expected_shape
 
             reference_runner_ifm_name.append(ifm_tensor_name)
 
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 0741686..fffb842 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -841,6 +841,20 @@
             "tflite": TYPE_F,
         },
     },
+    "real": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Real, TGen.tgComplexComponents, ArgGen.agNone),
+        "types": {
+            "tflite": [tf.complex64],
+        },
+    },
+    "imag": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Imag, TGen.tgComplexComponents, ArgGen.agNone),
+        "types": {
+            "tflite": [tf.complex64],
+        },
+    },
 }
 
 # Shapes to be tested; default can be overwritten
@@ -1154,6 +1168,14 @@
             # 1. Saved out numpy array directly
             for idx, (name, val) in enumerate(placeholders):
                 placeholder_vals.append(tf.convert_to_tensor(val))
+
+                # Complex tensors are expected to be repsesented by a
+                # single floating point tensor of shape [?, ..., ?, 2].
+                if val.dtype == np.complex64:
+                    val_shape = val.shape + (2,)
+                    val = val.view(np.float32)
+                    val = val.reshape(val_shape)
+
                 np.save(
                     os.path.join(test_dir, placeholder_npy_filenames[idx]), val, False
                 )