Fix TransposeConv2d in operator API

- Change name of the TransposeConv2d attribute output_shape to out_shape
  in generate_api.py to match with TOSA specification
- Fix serialization attributes mapping for operator TransposeConv2d
- Add a unit test for TransposeConv2d operator

Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com>
Change-Id: I6613c0d093aeea0af30012bcc1c8e5d26dec746c
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index e56b882..1650ea4 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -112,13 +112,11 @@
     tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
                                             tosa_tensor_t client_weight,
                                             tosa_tensor_t client_bias,
+                                            const int32_t client_out_pad[4],
                                             const int32_t client_stride[2],
+                                            const int32_t client_out_shape[4],
                                             const int32_t client_input_zp,
                                             const int32_t client_weight_zp,
-                                            const int32_t client_pad_len,
-                                            const int32_t client_pad[],
-                                            const int32_t client_dilation_len,
-                                            const int32_t client_dilation[],
                                             tosa_tensor_t client_output,
                                             const func_ctx_t& func_ctx);
 
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 842847e..9c7f9ef 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -465,21 +465,19 @@
     tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
                                             tosa_tensor_t client_weight,
                                             tosa_tensor_t client_bias,
+                                            const int32_t client_out_pad[4],
                                             const int32_t client_stride[2],
+                                            const int32_t client_out_shape[4],
                                             const int32_t client_input_zp,
                                             const int32_t client_weight_zp,
-                                            const int32_t client_pad_len,
-                                            const int32_t client_pad[],
-                                            const int32_t client_dilation_len,
-                                            const int32_t client_dilation[],
                                             tosa_tensor_t client_output,
                                             const func_ctx_t& func_ctx)
     {
         // Create operator attributes
-        const std::vector<int32_t> pad(&client_pad[0], &client_pad[0] + client_pad_len);
+        const std::vector<int32_t> out_pad(&client_out_pad[0], &client_out_pad[4]);
         const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
-        const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[0] + client_dilation_len);
-        TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+        const std::vector<int32_t> out_shape(&client_out_shape[0], &client_out_shape[4]);
+        TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp);
 
         // Create tensors
         tosa::TosaSerializationTensor* input  = translate_client_tensor(client_input, "input");
@@ -489,7 +487,7 @@
 
         // Create operator
         auto op = new tosa::TosaSerializationOperator(
-            tosa::Op::Op_TRANSPOSE_CONV2D, tosa::Attribute::Attribute_ConvAttribute, &attr,
+            tosa::Op::Op_TRANSPOSE_CONV2D, tosa::Attribute::Attribute_TransposeConvAttribute, &attr,
             { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
 
         // Create a tosa single-op basic block
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
index e838ea1..71e26c9 100644
--- a/reference_model/test/model_runner_tests.cpp
+++ b/reference_model/test/model_runner_tests.cpp
@@ -180,6 +180,66 @@
         compareOutput(dstData, expectedData, expectedData.size());
     }
 
+    TEST_CASE("op_entry_transpose_conv2d")
+    {
+        // Transpose Conv 2D parameters
+        const int32_t stride[2]    = { 1, 1 };
+        const int32_t out_pad[4]   = { 0, 0, 0, 0 };
+        const int32_t out_shape[4] = { 1, 32, 32, 16 };
+
+        // Inputs/Outputs
+        tosa_datatype_t dt                = tosa_datatype_fp32_t;
+        std::vector<int32_t> input_shape  = { 1, 32, 32, 8 };
+        std::vector<int32_t> output_shape = { 1, 32, 32, 16 };
+        std::vector<int32_t> weight_shape = { 16, 1, 1, 8 };
+        std::vector<int32_t> bias_shape   = { 16 };
+
+        std::vector<float> srcData(32 * 32 * 8, 1.0f);
+        std::vector<float> dstData(32 * 32 * 16, 0.f);
+        std::vector<float> biasData(16, 0.f);
+        std::vector<float> weightData(16 * 8, 1.0f);
+
+        tosa_tensor_t input;
+        input.shape     = input_shape.data();
+        input.num_dims  = input_shape.size();
+        input.data_type = dt;
+        input.data      = reinterpret_cast<uint8_t*>(srcData.data());
+        input.size      = srcData.size() * sizeof(float);
+
+        tosa_tensor_t weight;
+        weight.shape     = weight_shape.data();
+        weight.num_dims  = weight_shape.size();
+        weight.data_type = dt;
+        weight.data      = reinterpret_cast<uint8_t*>(weightData.data());
+        weight.size      = weightData.size() * sizeof(float);
+
+        tosa_tensor_t bias;
+        bias.shape     = bias_shape.data();
+        bias.num_dims  = bias_shape.size();
+        bias.data_type = dt;
+        bias.data      = reinterpret_cast<uint8_t*>(biasData.data());
+        bias.size      = biasData.size() * sizeof(float);
+
+        tosa_tensor_t output;
+        output.shape     = output_shape.data();
+        output.num_dims  = output_shape.size();
+        output.data_type = dt;
+        output.data      = reinterpret_cast<uint8_t*>(dstData.data());
+        output.size      = dstData.size() * sizeof(float);
+
+        const int32_t input_zp  = 0;
+        const int32_t weight_zp = 0;
+
+        // Execution
+        auto status =
+            tosa_run_transpose_conv2d(input, weight, bias, out_pad, stride, out_shape, input_zp, weight_zp, output, {});
+        CHECK((status == tosa_status_valid));
+
+        // Compare results
+        std::vector<float> expectedData(32 * 32 * 16, 8.0f);
+        compareOutput(dstData, expectedData, expectedData.size());
+    }
+
     TEST_CASE("op_entry_conv2d_abs_mode")
     {
         // Conv parameters
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 99639f4..d9077f0 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -64,7 +64,7 @@
         "fully_connected": "FullyConnected",
         "matmul": "MatMul",
         "max_pool2d": "Pool",
-        "transpose_conv2d": "Conv",
+        "transpose_conv2d": "TransposeConv",
         "clamp": "Clamp",
         "arithmetic_right_shift": "ArithmeticRightShift",
         "mul": "Mul",
@@ -99,9 +99,16 @@
         serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
         tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
         serTosaTypeMap = {"ResizeMode": "tosa_mode"}
-        # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
-        if tosaOpName == "reshape":
-            serLibOpAtts[0]["name"] = "shape"
+        serAttsToFix = {
+            "reshape": {"new_shape": "shape"},
+            "transpose_conv2d": {"output_shape": "out_shape"},
+        }
+        if tosaOpName in serAttsToFix:
+            # Fix attributes names to match with tosa.xml
+            for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items():
+                for opAtts in serLibOpAtts:
+                    if opAtts["name"] == attDefName:
+                        opAtts["name"] = tosaSpecName
         for att in serLibOpAtts:
             attName = att["name"]
             attType = att["dType"]