Add FFT2d to the reference model

Includes:
* FFT2d reference implementation
* Basic TOSA tests

Change-Id: Ie79fcb713542345d550ec013646810c1e890e388
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index b1a405a..8d84135 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -89,6 +89,9 @@
             DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
             DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
             break;
+        case Op_FFT2D:
+            DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
+            break;
         case Op_FULLY_CONNECTED:
             DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
             DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 4663c47..af808e8 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -238,6 +238,86 @@
     return 0;
 }
 
+int check_fft_shape(const std::vector<int32_t>& in_real,
+                    const std::vector<int32_t>& in_imag,
+                    const std::vector<int32_t>& out_real,
+                    const std::vector<int32_t>& out_imag,
+                    std::string& msg) {
+    const bool is_rfft = in_imag.empty();
+    auto is_power_of_two = [](int32_t n) -> bool
+    {
+        return (n & (n-1)) == 0 && n > 0;
+    };
+
+    if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
+    {
+        msg = "Input height and width must be a power of two";
+        return 1;
+    }
+
+    // RFFT does not have a second input
+    if (!is_rfft)
+    {
+        bool input_check = true;
+        for (size_t i = 0; i < in_real.size(); i++)
+        {
+            if (in_real[i] != in_imag[i])
+            {
+                input_check = false;
+                break;
+            }
+        }
+        if (!input_check)
+        {
+            msg = "Mismatch between real input shape and imaginary input shape";
+            return 1;
+        }
+    }
+
+    bool output_check = true;
+    for (size_t i = 0; i < out_real.size(); i++)
+    {
+        if (out_real[i] != out_imag[i])
+        {
+            output_check = false;
+            break;
+        }
+    }
+    if (!output_check)
+    {
+        msg = "Mismatch between real output shape and imaginary output shape";
+        return 1;
+    }
+
+    if (in_real[0] != out_real[0])
+    {
+        msg = "Input and output batch size don't match";
+        return 1;
+    }
+    if (in_real[1] != out_real[1])
+    {
+        msg = "Input and output height don't match";
+        return 1;
+    }
+
+    if (is_rfft)
+    {
+        if (in_real[2] / 2 + 1 != out_real[2])
+        {
+            msg = "Output width is expected to match input width / 2 + 1";
+            return 1;
+        }
+    } else {
+        if (in_real[2] != out_real[2])
+        {
+            msg = "Input and output width don't match";
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
 template <int Rank, DType Dtype>
 OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
@@ -1448,6 +1528,124 @@
 }
 
 template <DType Dtype>
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
+                        TosaAttributeBase* attribute_,
+                        uint64_t id_)
+    : GraphNode(sgt_, Op_FFT2D, id_)
+{
+    setRequiredOperands(2, 2);
+    setRequiredRank(3);
+
+    INIT_ATTRIBUTE(FFT);
+}
+
+template <DType Dtype>
+OpFFT2d<Dtype>::~OpFFT2d() {
+    if (attribute)
+        delete attribute;
+}
+
+
+template <DType Dtype>
+int OpFFT2d<Dtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) ||
+        validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
+    {
+        return 1;
+    }
+
+    if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) ||
+        inputs[0]->matchType(*inputs[1]))
+    {
+        printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
+        return 1;
+    }
+
+    in_real  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    in_imag  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+    out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
+
+    ASSERT_MEM(in_real && in_imag && out_real && out_imag);
+
+    std::string msg;
+    if (check_fft_shape(in_real->getShape(), in_imag->getShape(),
+                  out_real->getShape(), out_imag->getShape(), msg))
+    {
+        msg = "OpFFT2d: " + msg;
+        printNodeValidationError(msg.c_str());
+        return 1;
+    }
+
+    return 0;
+}
+
+template <DType Dtype>
+int OpFFT2d<Dtype>::eval()
+{
+    int in_real_batch = this->in_real->getShape()[0];
+    int in_real_height = this->in_real->getShape()[1];
+    int in_real_width = this->in_real->getShape()[2];
+
+    int in_imag_batch = this->in_imag->getShape()[0];
+    int in_imag_height = this->in_imag->getShape()[1];
+    int in_imag_width = this->in_imag->getShape()[2];
+
+    int out_real_batch = this->out_real->getShape()[0];
+    int out_real_height = this->out_real->getShape()[1];
+    int out_real_width = this->out_real->getShape()[2];
+
+    int out_imag_batch = this->out_imag->getShape()[0];
+    int out_imag_height = this->out_imag->getShape()[1];
+    int out_imag_width = this->out_imag->getShape()[2];
+
+    DEBUG_INFO(OP,
+               "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
+               in_real_batch, in_real_height, in_real_width,
+               in_imag_batch, in_imag_height, in_imag_width,
+               out_real_batch, out_real_height, out_real_width,
+               out_imag_batch, out_imag_height, out_imag_width);
+
+    OutEigenType sum_real, sum_imag, a, sign_val = 1.0;
+
+    if (attribute->inverse()) {
+        sign_val = -1.0;
+    }
+
+    for (int n = 0; n < in_real_batch; n++)
+    {
+        for (int oy = 0; oy < out_real_height; oy++)
+        {
+            for (int ox = 0; ox < out_real_width; ox++)
+            {
+                sum_real = 0.0;
+                sum_imag = 0.0;
+                for (int iy = 0; iy < in_real_height; iy++)
+                {
+                    for (int ix = 0; ix < in_real_width; ix++)
+                    {
+                        OutEigenType val_real = this->in_real->getTensor()(n, iy, ix);
+                        OutEigenType val_imag = this->in_imag->getTensor()(n, iy, ix);
+                        // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
+                        a = sign_val * 2 * M_PI * ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
+                        sum_real += val_real * cos(a) + val_imag * sin(a);
+                        sum_imag += -val_real * sin(a) + val_imag * cos(a);
+                    }
+                }
+                this->out_real->getTensor()(n, oy, ox) = sum_real;
+                this->out_imag->getTensor()(n, oy, ox) = sum_imag;
+            }
+        }
+    }
+
+    return GraphNode::eval();
+}
+
+template <DType Dtype>
 OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
@@ -1485,45 +1683,12 @@
 
     ASSERT_MEM(in && out_real && out_imag);
 
-    auto is_power_of_two = [](int32_t n) -> bool
+    std::string msg;
+    if (check_fft_shape(in->getShape(), {},
+                  out_real->getShape(), out_imag->getShape(), msg))
     {
-        return (n & (n-1)) == 0 && n > 0;
-    };
-
-    // Input shape: [N, H, W]
-    if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2]))
-    {
-        printNodeValidationError("OpRFFT2d: input height and width must be a power of two");
-        return 1;
-    }
-
-    // Output shape: [N, H, W / 2 + 1]
-    bool output_check = true;
-    for (int32_t i = 0; i < out_real->getRank(); i++)
-    {
-        if (out_real->getShape()[i] != out_imag->getShape()[i])
-        {
-            output_check = false;
-            break;
-        }
-    }
-    if (!output_check)
-    {
-        printNodeValidationError(
-            "OpRFFT2d: Mismatch between real output shape and imaginary output shape");
-        return 1;
-    }
-
-    if (in->getShape()[0] != out_real->getShape()[0]) {
-        printNodeValidationError("OpRFFT2d: input and output batch size don't match");
-        return 1;
-    }
-    if (in->getShape()[1] != out_real->getShape()[1]) {
-        printNodeValidationError("OpRFFT2d: input and output height don't match");
-        return 1;
-    }
-    if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) {
-        printNodeValidationError("OpRFFT2d:  output width is expected to match input width / 2 + 1");
+        msg = "OpRFFT2d: " + msg;
+        printNodeValidationError(msg.c_str());
         return 1;
     }
 
@@ -1843,6 +2008,8 @@
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
 
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 0d2b3eb..9ef4a58 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -249,6 +249,29 @@
 };
 
 template <DType Dtype>
+class OpFFT2d : public GraphNode
+{
+public:
+    OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+    virtual ~OpFFT2d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    using InEigenType   = typename GetEigenType<Dtype>::type;
+    using OutEigenType  = typename GetEigenType<Dtype>::type;
+    using TIn           = Eigen::Tensor<InEigenType, 3>;
+    using TOut          = Eigen::Tensor<OutEigenType, 3>;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* in_real;
+    TosaReference::TensorTemplate<TIn>* in_imag;
+    TosaReference::TensorTemplate<TOut>* out_real;
+    TosaReference::TensorTemplate<TOut>* out_imag;
+    tosa::TosaFFTAttribute* attribute;
+};
+
+template <DType Dtype>
 class OpRFFT2d : public GraphNode
 {
 public:
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 05a7d2b..370570c 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -417,6 +417,45 @@
         return [ifm_shape, filter_shape, bias_shape]
 
     @staticmethod
+    def tgFFT2d(testGen, op, rank, error_name=None):
+        pl, const = op["operands"]
+
+        if error_name != ErrorIf.WrongRank:
+            assert rank == 3
+        assert pl == 2 and const == 0
+
+        # IFM dimensions are NHW
+        ifm_shape = testGen.makeShape(rank)
+
+        # Select nearest lower power of two from input height and width
+        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
+        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
+
+        # Constrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
+
+        # Generate an invalid kernel that is not a power of two
+        if error_name == ErrorIf.KernelNotPowerOfTwo:
+            inc_h = 2 if ifm_shape[1] == 1 else 1
+            inc_w = 2 if ifm_shape[2] == 1 else 1
+            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
+            selected_inc = testGen.rng.choice(inc_choices)
+            ifm_shape[1] += selected_inc[0]
+            ifm_shape[2] += selected_inc[1]
+
+        ifm_shape = testGen.constrictBatchSize(ifm_shape)
+
+        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
+        if error_name == ErrorIf.FFTInputShapeMismatch:
+            modify_shape = testGen.rng.choice([0, 1])
+            # Only modify kernel (H, W)
+            modify_dim = testGen.rng.choice([1, 2])
+            ifm_shapes[modify_shape][modify_dim] *= 2
+
+        return [ifm_shapes[0], ifm_shapes[1]]
+
+    @staticmethod
     def tgRFFT2d(testGen, op, rank, error_name=None):
         pl, const = op["operands"]
 
@@ -1613,6 +1652,15 @@
 
         return arg_list
 
+    @staticmethod
+    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
+        arg_list = []
+
+        arg_list.append(("inverseTrue", [True]))
+        arg_list.append(("inverseFalse", [False]))
+
+        return arg_list
+
     # Helper function for reshape.  Gets some factors of a larger number.
     @staticmethod
     def getFactors(val, start=1):
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 93f975d..ee227b3 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -79,6 +79,8 @@
     CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
     CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
     KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
+    FFTInputShapeMismatch = "FFTInputShapeMismatch"
+    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
 
 
 class TosaErrorIfArgGen:
@@ -562,7 +564,7 @@
                 ):
                     error_result = True
 
-            elif op["op"] == Op.RFFT2D:
+            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                 if not all([ty == input_dtype for ty in output_dtype]):
                     error_result = True
 
@@ -686,7 +688,7 @@
             op = kwargs["op"]
             output_list = kwargs["output_list"]
             expected_length = 1
-            if op["op"] == Op.RFFT2D:
+            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                 expected_length = 2
 
             if len(output_list) != expected_length:
@@ -2446,6 +2448,64 @@
         }
         return info_dict
 
+    @staticmethod
+    def evFFTInputShapeMismatch(check=False, **kwargs):
+        error_name = ErrorIf.FFTInputShapeMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Mismatch between real and imaginary input shapes"
+
+        if check:
+            input1 = kwargs["input1"]
+            input2 = kwargs["input2"]
+
+            if input1.shape != input2.shape:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
+    @staticmethod
+    def evFFTOutputShapeMismatch(check=False, **kwargs):
+        error_name = ErrorIf.FFTOutputShapeMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = (
+            "Mismatch between provided and expected output kernel (H, W) shape"
+        )
+
+        if check:
+            op = kwargs["op"]
+            input_shape = kwargs["input_shape"]
+
+            if len(input_shape) == 3:
+                output_shapes = kwargs["output_shape"]
+
+                # Ignoring batch size (N) from input shape
+                expected_shape = input_shape[1:]
+                if op["op"] == Op.RFFT2D:
+                    expected_shape[1] = expected_shape[1] // 2 + 1
+
+                # Ignoring batch size (N) from output shapes
+                output_shape_0 = output_shapes[0][1:]
+                output_shape_1 = output_shapes[1][1:]
+                # Ensure sure the kernel sizes (H, W) of both outputs match the expected
+                if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
+                    error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
 
 class TosaInvalidValidator:
     @staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 5f9e2c1..2b762aa 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -213,6 +213,12 @@
         else:
             raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
 
+    def constrictBatchSize(self, shape):
+        # Limit the batch size unless an explicit target shape set
+        if self.args.max_batch_size and not self.args.target_shapes:
+            shape[0] = min(shape[0], self.args.max_batch_size)
+        return shape
+
     # Argument generators
     # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
     # Where the string descriptor is used to generate the test name and
@@ -2081,6 +2087,48 @@
 
         return acc_out
 
+    def build_fft2d(
+        self, op, val1, val2, inverse, validator_fcns=None, error_name=None
+    ):
+        results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
+
+        input_names = [val1.name, val2.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+
+        output_names = [res.name for res in results]
+        output_shapes = [res.shape for res in results]
+        output_dtypes = [res.dtype for res in results]
+
+        input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+            self, error_name, input_names, output_names
+        )
+
+        if not TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            inverse=inverse,
+            input1=val1,
+            input2=val2,
+            input_shape=val1.shape,
+            input_dtype=val1.dtype,
+            output_shape=output_shapes,
+            output_dtype=output_dtypes,
+            result_tensors=results,
+            input_list=input_names,
+            output_list=output_names,
+            num_operands=num_operands,
+        ):
+            return None
+
+        attr = ts.TosaSerializerAttribute()
+        attr.FFTAttribute(inverse)
+
+        self.ser.addOperator(op["op"], input_names, output_names, attr)
+        return results
+
     def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
         results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
 
@@ -2089,6 +2137,7 @@
         num_operands = pCount + cCount
 
         output_names = [res.name for res in results]
+        output_shapes = [res.shape for res in results]
         output_dtypes = [res.dtype for res in results]
 
         input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -2102,6 +2151,7 @@
             op=op,
             input_shape=val.shape,
             input_dtype=val.dtype,
+            output_shape=output_shapes,
             output_dtype=output_dtypes,
             result_tensors=results,
             input_list=input_names,
@@ -3927,6 +3977,29 @@
                 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
             ),
         },
+        "fft2d": {
+            "op": Op.FFT2D,
+            "operands": (2, 0),
+            "rank": (3, 3),
+            "build_fcn": (
+                build_fft2d,
+                TosaTensorGen.tgFFT2d,
+                TosaTensorValuesGen.tvgDefault,
+                TosaArgGen.agFFT2d,
+            ),
+            "types": [DType.FP32],
+            "error_if_validators": (
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongOutputType,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+                TosaErrorValidator.evWrongRank,
+                TosaErrorValidator.evBatchMismatch,
+                TosaErrorValidator.evKernelNotPowerOfTwo,
+                TosaErrorValidator.evFFTInputShapeMismatch,
+                TosaErrorValidator.evFFTOutputShapeMismatch,
+            ),
+        },
         "rfft2d": {
             "op": Op.RFFT2D,
             "operands": (1, 0),
@@ -3946,6 +4019,7 @@
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evBatchMismatch,
                 TosaErrorValidator.evKernelNotPowerOfTwo,
+                TosaErrorValidator.evFFTOutputShapeMismatch,
             ),
         },
     }
@@ -4770,6 +4844,37 @@
         return ser.addOutput(output_shape, out_dtype)
 
     @staticmethod
+    def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
+        outputs = []
+
+        assert ifm1.dtype == ifm2.dtype
+        input_dtype = ifm1.dtype
+
+        if error_name != ErrorIf.FFTInputShapeMismatch:
+            assert ifm1.shape == ifm2.shape
+
+        input_shape = ifm1.shape
+        if error_name != ErrorIf.WrongRank:
+            assert len(input_shape) == 3
+
+        output_shape = input_shape.copy()
+        output_dtype = input_dtype
+
+        if error_name == ErrorIf.WrongOutputType:
+            excludes = [DType.FP32]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
+            output_dtype = rng.choice(wrong_dtypes)
+        elif error_name == ErrorIf.BatchMismatch:
+            output_shape[0] += rng.integers(1, 10)
+        elif error_name == ErrorIf.FFTOutputShapeMismatch:
+            modify_dim = rng.choice([1, 2])
+            output_shape[modify_dim] += rng.integers(1, 10)
+
+        outputs.append(serializer.addOutput(output_shape, output_dtype))
+        outputs.append(serializer.addOutput(output_shape, output_dtype))
+        return outputs
+
+    @staticmethod
     def rfft2dOp(serializer, rng, value, error_name=None):
         outputs = []
 
@@ -4785,8 +4890,10 @@
             wrong_dtypes = list(usableDTypes(excludes=excludes))
             output_dtype = rng.choice(wrong_dtypes)
         elif error_name == ErrorIf.BatchMismatch:
-            incorrect_batch = input_shape[0] + rng.integers(1, 10)
-            output_shape = [incorrect_batch, *input_shape[1:]]
+            output_shape[0] += rng.integers(1, 10)
+        elif error_name == ErrorIf.FFTOutputShapeMismatch:
+            modify_dim = rng.choice([1, 2])
+            output_shape[modify_dim] += rng.integers(1, 10)
 
         outputs.append(serializer.addOutput(output_shape, output_dtype))
         outputs.append(serializer.addOutput(output_shape, output_dtype))