Add FFT2d to the reference model

Includes:
* FFT2d reference implementation
* Basic TOSA tests
* Output-shape-mismatch ERROR_IF tests for RFFT2d

Change-Id: Ie79fcb713542345d550ec013646810c1e890e388
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 5f9e2c1..2b762aa 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -213,6 +213,12 @@
         else:
             raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
 
+    def constrictBatchSize(self, shape):
+        # Limit the batch size (shape[0]) unless an explicit target shape is set
+        if self.args.max_batch_size and not self.args.target_shapes:
+            shape[0] = min(shape[0], self.args.max_batch_size)
+        return shape  # the caller's object, possibly modified in place
+
     # Argument generators
     # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
     # Where the string descriptor is used to generate the test name and
@@ -2081,6 +2087,48 @@
 
         return acc_out
 
+    def build_fft2d(
+        self, op, val1, val2, inverse, validator_fcns=None, error_name=None
+    ):
+        results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
+
+        input_names = [val1.name, val2.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+
+        output_names = [res.name for res in results]
+        output_shapes = [res.shape for res in results]
+        output_dtypes = [res.dtype for res in results]
+
+        input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_names, output_names
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            inverse=inverse,
+            input1=val1,
+            input2=val2,
+            input_shape=val1.shape,
+            input_dtype=val1.dtype,
+            output_shape=output_shapes,
+            output_dtype=output_dtypes,
+            result_tensors=results,
+            input_list=input_names,
+            output_list=output_names,
+            num_operands=num_operands,
+        ):
+            return None  # ERROR_IF validation failed; operator is not added
+
+        attr = ts.TosaSerializerAttribute()
+        attr.FFTAttribute(inverse)  # attribute carries the inverse flag
+
+        self.ser.addOperator(op["op"], input_names, output_names, attr)
+        return results  # the two output tensors created by fft2dOp
+
     def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
         results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
 
@@ -2089,6 +2137,7 @@
         num_operands = pCount + cCount
 
         output_names = [res.name for res in results]
+        output_shapes = [res.shape for res in results]
         output_dtypes = [res.dtype for res in results]
 
         input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -2102,6 +2151,7 @@
             op=op,
             input_shape=val.shape,
             input_dtype=val.dtype,
+            output_shape=output_shapes,
             output_dtype=output_dtypes,
             result_tensors=results,
             input_list=input_names,
@@ -3927,6 +3977,29 @@
                 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
             ),
         },
+        "fft2d": {
+            "op": Op.FFT2D,
+            "operands": (2, 0),
+            "rank": (3, 3),
+            "build_fcn": (
+                build_fft2d,
+                TosaTensorGen.tgFFT2d,
+                TosaTensorValuesGen.tvgDefault,
+                TosaArgGen.agFFT2d,
+            ),
+            "types": [DType.FP32],
+            "error_if_validators": (
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongOutputType,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+                TosaErrorValidator.evWrongRank,
+                TosaErrorValidator.evBatchMismatch,
+                TosaErrorValidator.evKernelNotPowerOfTwo,
+                TosaErrorValidator.evFFTInputShapeMismatch,
+                TosaErrorValidator.evFFTOutputShapeMismatch,
+            ),
+        },
         "rfft2d": {
             "op": Op.RFFT2D,
             "operands": (1, 0),
@@ -3946,6 +4019,7 @@
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evBatchMismatch,
                 TosaErrorValidator.evKernelNotPowerOfTwo,
+                TosaErrorValidator.evFFTOutputShapeMismatch,
             ),
         },
     }
@@ -4770,6 +4844,37 @@
         return ser.addOutput(output_shape, out_dtype)
 
     @staticmethod
+    def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
+        outputs = []
+
+        assert ifm1.dtype == ifm2.dtype
+        input_dtype = ifm1.dtype
+
+        if error_name != ErrorIf.FFTInputShapeMismatch:
+            assert ifm1.shape == ifm2.shape  # except in mismatch tests
+
+        input_shape = ifm1.shape
+        if error_name != ErrorIf.WrongRank:
+            assert len(input_shape) == 3
+
+        output_shape = input_shape.copy()  # copied: error cases mutate it
+        output_dtype = input_dtype
+
+        if error_name == ErrorIf.WrongOutputType:
+            excludes = [DType.FP32]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            output_dtype = rng.choice(wrong_dtypes)
+        elif error_name == ErrorIf.BatchMismatch:
+            output_shape[0] += rng.integers(1, 10)
+        elif error_name == ErrorIf.FFTOutputShapeMismatch:
+            modify_dim = rng.choice([1, 2])  # a non-batch dimension
+            output_shape[modify_dim] += rng.integers(1, 10)
+
+        outputs.append(serializer.addOutput(output_shape, output_dtype))
+        outputs.append(serializer.addOutput(output_shape, output_dtype))
+        return outputs
+
+    @staticmethod
     def rfft2dOp(serializer, rng, value, error_name=None):
         outputs = []
 
@@ -4785,8 +4890,10 @@
             wrong_dtypes = list(usableDTypes(excludes=excludes))
             output_dtype = rng.choice(wrong_dtypes)
         elif error_name == ErrorIf.BatchMismatch:
-            incorrect_batch = input_shape[0] + rng.integers(1, 10)
-            output_shape = [incorrect_batch, *input_shape[1:]]
+            output_shape[0] += rng.integers(1, 10)
+        elif error_name == ErrorIf.FFTOutputShapeMismatch:
+            modify_dim = rng.choice([1, 2])
+            output_shape[modify_dim] += rng.integers(1, 10)
 
         outputs.append(serializer.addOutput(output_shape, output_dtype))
         outputs.append(serializer.addOutput(output_shape, output_dtype))