Add RFFT2d to the reference model

Includes:
* RFFT2d reference implementation
* TFLite framework tests
* Basic TOSA tests
* Serialization submodule upgrade with support for FFT/RFFT

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0121ccf..0d56161 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -113,6 +113,9 @@
             DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
             DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
             break;
+        case Op_RFFT2D:
+            DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
+            break;
         case Op_TRANSPOSE_CONV2D:
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b9ac94a..dff9e08 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -1453,6 +1453,140 @@
     return GraphNode::eval();
 }
 
+template <DType Dtype>
+OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
+                          TosaAttributeBase* attribute_,
+                          uint64_t id_)
+    : GraphNode(sgt_, Op_RFFT2D, id_)
+{
+    setRequiredOperands(1, 2);
+    setRequiredRank(3);
+}
+
+template <DType Dtype>
+OpRFFT2d<Dtype>::~OpRFFT2d() {}
+
+
+template <DType Dtype>
+int OpRFFT2d<Dtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) ||
+    validateRequiredRank(outputs[1]))
+    {
+        return 1;
+    }
+
+    if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
+    {
+        printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
+        return 1;
+    }
+
+    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
+
+    ASSERT_MEM(in && out_real && out_imag);
+
+    auto is_power_of_two = [](int32_t n) -> bool
+    {
+        return (n & (n-1)) == 0 && n > 0;
+    };
+
+    // Input shape: [N, H, W]
+    if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2]))
+    {
+        printNodeValidationError("OpRFFT2d: input height and width must be a power of two");
+        return 1;
+    }
+
+    // Output shape: [N, H, W / 2 + 1]
+    bool output_check = true;
+    for (int32_t i = 0; i < out_real->getRank(); i++)
+    {
+        if (out_real->getShape()[i] != out_imag->getShape()[i])
+        {
+            output_check = false;
+            break;
+        }
+    }
+    if (!output_check)
+    {
+        printNodeValidationError(
+            "OpRFFT2d: Mismatch between real output shape and imaginary output shape");
+        return 1;
+    }
+
+    if (in->getShape()[0] != out_real->getShape()[0]) {
+        printNodeValidationError("OpRFFT2d: input and output batch size don't match");
+        return 1;
+    }
+    if (in->getShape()[1] != out_real->getShape()[1]) {
+        printNodeValidationError("OpRFFT2d: input and output height don't match");
+        return 1;
+    }
+    if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) {
+        printNodeValidationError("OpRFFT2d:  output width is expected to match input width / 2 + 1");
+        return 1;
+    }
+
+    return 0;
+}
+
+template <DType Dtype>
+int OpRFFT2d<Dtype>::eval()
+{
+    int32_t in_batch = in->getShape()[0];
+    int32_t in_height = in->getShape()[1];
+    int32_t in_width = in->getShape()[2];
+
+    int32_t out_real_batch = out_real->getShape()[0];
+    int32_t out_real_height = out_real->getShape()[1];
+    int32_t out_real_width = out_real->getShape()[2];
+
+    int32_t out_imag_batch = out_imag->getShape()[0];
+    int32_t out_imag_height = out_imag->getShape()[1];
+    int32_t out_imag_width = out_imag->getShape()[2];
+
+    DEBUG_INFO(OP,
+               "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
+               "output_imag.shape=[%d,%d,%d]",
+               in_batch, in_height, in_width,
+               out_real_batch, out_real_height, out_real_width,
+               out_imag_batch, out_imag_height, out_imag_width);
+
+    OutEigenType sum_real, sum_imag, a;
+
+    for (int n = 0; n < in_batch; n++)
+    {
+        for (int oy = 0; oy < out_real_height; oy++)
+        {
+            for (int ox = 0; ox < out_real_width; ox++)
+            {
+                sum_real = 0.0;
+                sum_imag = 0.0;
+                for (int iy = 0; iy < in_height; iy++)
+                {
+                    for (int ix = 0; ix < in_width; ix++)
+                    {
+                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
+                        a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
+                        sum_real += this->in->getTensor()(n, iy, ix) * cos(a);
+                        sum_imag += -this->in->getTensor()(n, iy, ix) * sin(a);
+                    }
+                }
+                this->out_real->getTensor()(n, oy, ox) = sum_real;
+                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
+            }
+        }
+    }
+
+    return GraphNode::eval();
+}
+
 template <DType InDtype, DType WeightDtype, DType AccDtype>
 OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
@@ -1738,6 +1872,8 @@
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
 
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index fd6dd25..ed9a55c 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -248,6 +248,27 @@
     tosa::TosaPoolAttribute* attribute;
 };
 
+template <DType Dtype>
+class OpRFFT2d : public GraphNode
+{
+public:
+    OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+    virtual ~OpRFFT2d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    using InEigenType   = typename GetEigenType<Dtype>::type;
+    using OutEigenType  = typename GetEigenType<Dtype>::type;
+    using TIn           = Eigen::Tensor<InEigenType, 3>;
+    using TOut          = Eigen::Tensor<OutEigenType, 3>;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* in;
+    TosaReference::TensorTemplate<TOut>* out_real;
+    TosaReference::TensorTemplate<TOut>* out_imag;
+};
+
 template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpTransposeConv2d : public GraphNode
 {
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
index e36f4f7..c15f7d5 160000
--- a/thirdparty/serialization_lib
+++ b/thirdparty/serialization_lib
@@ -1 +1 @@
-Subproject commit e36f4f70b51c03712db96ea284e6e54b3e60a74c
+Subproject commit c15f7d52aa4f360eba2344449baa418b7608ac7c
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index d81c3dd..61a1de0 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import math
+
 import numpy as np
 
 
@@ -851,3 +853,29 @@
             else:
                 axes.append(["_axis_m{}".format(-i), [i]])
         return axes
+
+    def agRFFT2d(op, shape, rng):
+        args = []
+
+        # Must be rank 3 input tensor
+        if len(shape) != 3:
+            return []
+
+        # Check rfft2d with enforced fft_length
+        for fft_length_h in [2, 32]:
+            for fft_length_w in [2, 8, 16]:
+                fft_length = [fft_length_h, fft_length_w]
+                args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])
+
+        # Check rfft2d with no fft_length provided (fft_length=None).
+        # In this case, the height and width of the input should be
+        # used for the calculation. Therefore, we need to check that
+        # the input shape is already a power of two.
+        def is_power_of_two(x):
+            return math.log(x, 2).is_integer()
+
+        height, width = shape[1:3]
+        if is_power_of_two(height) and is_power_of_two(width):
+            args.append(["_fft_length_None", [None]])
+
+        return args
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 767989e..c534a58 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -274,3 +274,12 @@
             )
 
         return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgRFFT2d(op, shape, dtype, rng):
+        # Require rank 3 shape
+        if len(shape) != 3:
+            return [], []
+
+        tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
+        return tf_placeholders, []
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 8870f41..6e7b6a5 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1243,3 +1243,11 @@
 
         def eval(self, a):
             return self.dense(a)
+
+    class RFFT2d:
+        def __init__(self, fft_length, name):
+            self.fft_length = fft_length
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 3597f2a..c55864a 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import argparse
 import glob
@@ -483,6 +483,20 @@
         except KeyError:
             assert 0, "fail to load tflite result numpy"
 
+    # TOSA has no notion of complex datatypes, it represents complex values using two
+    # fp32 output tensors representing real and imaginary values. When legalizing
+    # complex operations from frameworks, these two output tensors are combined into
+    # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
+    # represents the real and imaginary parts of a complex value. This is completed
+    # by inserting reshape and concatenate TOSA operations during the legalization to
+    # maintain a one-to-one correspondance with framework outputs, thus simplifying
+    # legalization. Here tf_result should also match this format before being
+    # compared to the ref model output.
+    if tf_result.dtype == np.complex64:
+        ifm_shape = tf_result.shape + (2,)
+        tf_result = tf_result.view(np.float32)
+        tf_result = tf_result.reshape(ifm_shape)
+
     # Generate test descriptor per flatbuffer generation
     # Input .npy will be shared across different frameworks
     # Output .npy will be generated in its corresponding flatbuffer
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 5b8856d..36ddda5 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import argparse
 import os
@@ -839,6 +839,13 @@
             ]
         },
     },
+    "rfft2d": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
+        "types": {
+            "tflite": TYPE_F,
+        },
+    },
 }
 
 # Shapes to be tested; default can be overwritten
@@ -847,6 +854,7 @@
     (64,),
     (14, 19),
     (13, 21, 3),
+    (1, 8, 16),
     (1, 4, 4, 4),
     (1, 8, 4, 17),
     (1, 4, 8, 19),
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 4e15b06..fed91f6 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
 import itertools
 import math
@@ -417,6 +417,41 @@
         return [ifm_shape, filter_shape, bias_shape]
 
     @staticmethod
+    def tgRFFT2d(testGen, op, rank, error_name=None):
+        pl, const = op["operands"]
+
+        if error_name != ErrorIf.WrongRank:
+            assert rank == 3
+        assert pl == 1 and const == 0
+
+        # IFM dimensions are NHW
+        ifm_shape = testGen.makeShape(rank)
+
+        # Select nearest lower power of two from input height and width
+        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
+        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
+
+        # Constrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
+
+        # Generate an invalid kernel that is not a power of two
+        if error_name == ErrorIf.KernelNotPowerOfTwo:
+            # We must increment by 2 if current size is 1
+            inc_h = 2 if ifm_shape[1] == 1 else 1
+            inc_w = 2 if ifm_shape[2] == 1 else 1
+            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
+            selected_inc = testGen.rng.choice(inc_choices)
+            ifm_shape[1] += selected_inc[0]
+            ifm_shape[2] += selected_inc[1]
+
+        # Constrict the batch size
+        if testGen.args.max_batch_size:
+            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+        return [ifm_shape]
+
+    @staticmethod
     def tgFullyConnected(testGen, op, rank, error_name=None):
         pl, const = op["operands"]
 
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index c9d35c7..40c5d13 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import math
+
 import numpy as np
 from generator.tosa_utils import MAX_RESIZE_DIMENSION
 from generator.tosa_utils import product
@@ -76,6 +78,7 @@
     CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
     CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
     CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
+    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
 
 
 class TosaErrorIfArgGen:
@@ -548,6 +551,10 @@
                 ):
                     error_result = True
 
+            elif op["op"] == Op.RFFT2D:
+                if not all([ty == input_dtype for ty in output_dtype]):
+                    error_result = True
+
             elif op["op"] in {
                 Op.CONV2D,
                 Op.CONV3D,
@@ -665,9 +672,13 @@
         error_reason = "Op output list does not match expected output"
 
         if check:
+            op = kwargs["op"]
             output_list = kwargs["output_list"]
-            # Note this will be incorrect if an operator returns more than one output
-            if len(output_list) != 1:
+            expected_length = 1
+            if op["op"] == Op.RFFT2D:
+                expected_length = 2
+
+            if len(output_list) != expected_length:
                 error_result = True
 
         info_dict = {
@@ -711,7 +722,7 @@
     @staticmethod
     def evBatchMismatch(check=False, **kwargs):
         error_name = ErrorIf.BatchMismatch
-        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
         error_result = False
         error_reason = "Input batch size not equal to output batch size"
 
@@ -722,12 +733,15 @@
 
         if check:
             input_shape = kwargs["input_shape"]
-            output_shape = kwargs[
-                "result_tensor"
-            ].shape  # Note this is just (N, OH, OW, C)
 
-            if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
-                error_result = True
+            for output in kwargs["result_tensors"]:
+                output_shape = (
+                    output.shape
+                )  # Note batch is expected to be the first dim
+                if (len(input_shape) in rank_range) and (
+                    input_shape[0] != output_shape[0]
+                ):
+                    error_result = True
 
         info_dict = {
             "error_name": error_name,
@@ -751,11 +765,12 @@
 
         if check:
             input_shape = kwargs["input_shape"]
-            output_shape = kwargs[
-                "result_tensor"
-            ].shape  # Note this is just (N, OH, OW, C)
-            if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
-                error_result = True
+            for output in kwargs["result_tensors"]:
+                output_shape = output.shape  # Note this is just (N, OH, OW, C)
+                if (len(input_shape) in rank_range) and (
+                    input_shape[3] != output_shape[3]
+                ):
+                    error_result = True
 
         info_dict = {
             "error_name": error_name,
@@ -1044,13 +1059,15 @@
             input3_shape = (
                 kwargs["input3"].shape if "input3" in kwargs else input2_shape
             )
-            output_shape = kwargs["result_tensor"].shape
-            if (
-                (len(input1_shape) != len(output_shape))
-                or (len(input2_shape) != len(output_shape))
-                or (len(input3_shape) != len(output_shape))
-            ):
-                error_result = True
+
+            for output in kwargs["result_tensors"]:
+                output_shape = output.shape
+                if (
+                    (len(input1_shape) != len(output_shape))
+                    or (len(input2_shape) != len(output_shape))
+                    or (len(input3_shape) != len(output_shape))
+                ):
+                    error_result = True
 
         info_dict = {
             "error_name": error_name,
@@ -1074,16 +1091,18 @@
             input3_shape = (
                 kwargs["input3"].shape if "input3" in kwargs else input2_shape
             )
-            output_shape = kwargs["result_tensor"].shape
-            for i in range(
-                min(len(input1_shape), len(input2_shape), len(input3_shape))
-            ):
-                if (
-                    (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
-                    or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
-                    or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+
+            for output in kwargs["result_tensors"]:
+                output_shape = output.shape
+                for i in range(
+                    min(len(input1_shape), len(input2_shape), len(input3_shape))
                 ):
-                    error_result = True
+                    if (
+                        (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
+                        or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
+                        or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+                    ):
+                        error_result = True
 
         info_dict = {
             "error_name": error_name,
@@ -2392,6 +2411,30 @@
         }
         return info_dict
 
+    @staticmethod
+    def evKernelNotPowerOfTwo(check=False, **kwargs):
+        error_name = ErrorIf.KernelNotPowerOfTwo
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "kernel height and/or width not a power of two"
+
+        def is_power_of_two(x):
+            return math.log(x, 2).is_integer()
+
+        if check:
+            shape = kwargs["input_shape"]
+            if len(shape) == 3:
+                valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
+                error_result = not valid_kernel
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
 
 class TosaInvalidValidator:
     @staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c29763b..fddf942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -255,7 +255,7 @@
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -293,7 +293,7 @@
             input2=b,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -333,7 +333,7 @@
             input2=b,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -378,7 +378,7 @@
             input2=b,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -414,7 +414,7 @@
             input_shape=a.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -448,7 +448,7 @@
             input_shape=a.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -487,7 +487,7 @@
             input_dtype=a.dtype,
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -523,7 +523,7 @@
             input_dtype=a.dtype,
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -582,7 +582,7 @@
             stride=stride,
             pad=pad,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -938,7 +938,7 @@
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -980,7 +980,7 @@
             output_shape=result_tens.shape,
             output_dtype=result_tens.dtype,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1016,7 +1016,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1064,7 +1064,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1122,7 +1122,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1153,7 +1153,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1199,7 +1199,7 @@
             input_dtype=a[0].dtype,
             output_dtype=result_tens.dtype,
             inputs=a,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1250,7 +1250,7 @@
             output_dtype=result_tens.dtype,
             pad=padding,
             qinfo=qinfo,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1283,7 +1283,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1318,7 +1318,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1356,7 +1356,7 @@
             perms=perms,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1391,7 +1391,7 @@
             output_dtype=result_tens.dtype,
             start=start,
             size=size,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1425,7 +1425,7 @@
             output_shape=result_tens.shape,
             input_dtype=a.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1474,7 +1474,7 @@
             output_shape=result_tens.shape,
             input_dtype=values.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1519,7 +1519,7 @@
             output_shape=result_tens.shape,
             input_dtype=values_in.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1580,7 +1580,7 @@
             border=border,
             input_list=input_list,
             output_list=output_list,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             num_operands=num_operands,
         ):
             return None
@@ -1628,7 +1628,7 @@
             output_shape=result_tens.shape,
             input_dtype=val.dtype,
             output_dtype=result_tens.dtype,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1774,7 +1774,7 @@
             double_round=double_round,
             input_list=input_list,
             output_list=output_list,
-            result_tensor=result_tens,
+            result_tensors=[result_tens],
             num_operands=num_operands,
         ):
             return None
@@ -2083,6 +2083,38 @@
 
         return acc_out
 
+    def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
+        results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
+
+        input_names = [val.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+
+        output_names = [res.name for res in results]
+        output_dtypes = [res.dtype for res in results]
+
+        input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_names, output_names
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input_shape=val.shape,
+            input_dtype=val.dtype,
+            output_dtype=output_dtypes,
+            result_tensors=results,
+            input_list=input_names,
+            output_list=output_names,
+            num_operands=num_operands,
+        ):
+            return None
+
+        self.ser.addOperator(op["op"], input_names, output_names)
+        return results
+
     def create_filter_lists(
         self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
     ):
@@ -3897,6 +3929,27 @@
                 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
             ),
         },
+        "rfft2d": {
+            "op": Op.RFFT2D,
+            "operands": (1, 0),
+            "rank": (3, 3),
+            "build_fcn": (
+                build_rfft2d,
+                TosaTensorGen.tgRFFT2d,
+                TosaTensorValuesGen.tvgDefault,
+                TosaArgGen.agNone,
+            ),
+            "types": [DType.FP32],
+            "error_if_validators": (
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongOutputType,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+                TosaErrorValidator.evWrongRank,
+                TosaErrorValidator.evBatchMismatch,
+                TosaErrorValidator.evKernelNotPowerOfTwo,
+            ),
+        },
     }
 
 
@@ -4717,3 +4770,26 @@
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(output_shape, out_dtype)
+
+    @staticmethod
+    def rfft2dOp(serializer, rng, value, error_name=None):
+        outputs = []
+
+        input_shape = value.shape
+        if error_name != ErrorIf.WrongRank:
+            assert len(input_shape) == 3
+
+        output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
+
+        output_dtype = value.dtype
+        if error_name == ErrorIf.WrongOutputType:
+            excludes = [DType.FP32]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            output_dtype = rng.choice(wrong_dtypes)
+        elif error_name == ErrorIf.BatchMismatch:
+            incorrect_batch = input_shape[0] + rng.integers(1, 10)
+            output_shape = [incorrect_batch, *input_shape[1:]]
+
+        outputs.append(serializer.addOutput(output_shape, output_dtype))
+        outputs.append(serializer.addOutput(output_shape, output_dtype))
+        return outputs