Add RFFT2d to the reference model

Includes:
* RFFT2d reference implementation
* TFLite framework tests
* Basic TOSA tests
* Serialization submodule upgrade with support for FFT/RFFT

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c29763b..fddf942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -255,7 +255,7 @@
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -293,7 +293,7 @@
             input2=b,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -333,7 +333,7 @@
             input2=b,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -378,7 +378,7 @@
             input2=b,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -414,7 +414,7 @@
             input_shape=a.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -448,7 +448,7 @@
             input_shape=a.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -487,7 +487,7 @@
             input_dtype=a.dtype,
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -523,7 +523,7 @@
             input_dtype=a.dtype,
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -582,7 +582,7 @@
             stride=stride,
             pad=pad,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -938,7 +938,7 @@
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -980,7 +980,7 @@
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1016,7 +1016,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1064,7 +1064,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1122,7 +1122,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1153,7 +1153,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1199,7 +1199,7 @@
             input_dtype=a[0].dtype,
             output_dtype=result_tens.dtype,
             inputs=a,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1250,7 +1250,7 @@
             output_dtype=result_tens.dtype,
             pad=padding,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1283,7 +1283,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1318,7 +1318,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1356,7 +1356,7 @@
             perms=perms,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1391,7 +1391,7 @@
             output_dtype=result_tens.dtype,
             start=start,
             size=size,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1425,7 +1425,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1474,7 +1474,7 @@
             output_shape=result_tens.shape,
             input_dtype=values.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1519,7 +1519,7 @@
             output_shape=result_tens.shape,
             input_dtype=values_in.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1580,7 +1580,7 @@
             border=border,
             input_list=input_list,
             output_list=output_list,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             num_operands=num_operands,
         ):
             return None
@@ -1628,7 +1628,7 @@
             output_shape=result_tens.shape,
             input_dtype=val.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1774,7 +1774,7 @@
             double_round=double_round,
             input_list=input_list,
             output_list=output_list,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             num_operands=num_operands,
         ):
             return None
@@ -2083,6 +2083,38 @@
 
         return acc_out
 
+    def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
+        """Build an RFFT2D op on val; return its outputs, or None if rejected."""
+        out_tensors = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
+        out_dtypes = [tens.dtype for tens in out_tensors]
+
+        # The two "operands" entries add up to the operand count.
+        p_count, c_count = op["operands"]
+
+        # Presumably perturbs the name lists for list error tests - confirm.
+        in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, [val.name], [tens.name for tens in out_tensors]
+        )
+
+        is_valid = TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input_shape=val.shape,
+            input_dtype=val.dtype,
+            output_dtype=out_dtypes,
+            result_tensors=out_tensors,
+            input_list=in_list,
+            output_list=out_list,
+            num_operands=p_count + c_count,
+        )
+        if not is_valid:
+            return None
+
+        self.ser.addOperator(op["op"], in_list, out_list)
+        return out_tensors
+
     def create_filter_lists(
         self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
     ):
@@ -3897,6 +3929,27 @@
                 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
             ),
         },
+        "rfft2d": {
+            "op": Op.RFFT2D,
+            "operands": (1, 0),
+            "rank": (3, 3),
+            "build_fcn": (
+                build_rfft2d,
+                TosaTensorGen.tgRFFT2d,
+                TosaTensorValuesGen.tvgDefault,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.FP32],
+            "error_if_validators": (
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongOutputType,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+                TosaErrorValidator.evWrongRank,
+                TosaErrorValidator.evBatchMismatch,
+                TosaErrorValidator.evKernelNotPowerOfTwo,
+            ),
+        },
     }
 
 
@@ -4717,3 +4770,26 @@
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(output_shape, out_dtype)
+
+    @staticmethod
+    def rfft2dOp(serializer, rng, value, error_name=None):
+        """Add the two RFFT2D outputs to the serializer and return them.
+
+        Both outputs share one shape (last input dim halved plus one) and
+        dtype; error_name may select a deliberately wrong dtype or batch.
+        """
+        input_shape = value.shape
+        if error_name != ErrorIf.WrongRank:
+            assert len(input_shape) == 3
+
+        output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
+        output_dtype = value.dtype
+        if error_name == ErrorIf.WrongOutputType:
+            wrong_dtypes = list(usableDTypes(excludes=[DType.FP32]))
+            output_dtype = rng.choice(wrong_dtypes)
+        elif error_name == ErrorIf.BatchMismatch:
+            # Perturb only the batch dim; keep the halved last dim valid.
+            output_shape[0] += rng.integers(1, 10)
+
+        # Call addOutput twice: the op has two outputs of equal shape/dtype.
+        return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]