Fix TransposeConv2d in operator API
- Rename the TransposeConv2d attribute output_shape to out_shape
in generate_api.py to match the TOSA specification
- Fix the serialization attribute mapping for the TransposeConv2d operator
- Add a unit test for the TransposeConv2d operator
Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com>
Change-Id: I6613c0d093aeea0af30012bcc1c8e5d26dec746c
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 99639f4..d9077f0 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -64,7 +64,7 @@
"fully_connected": "FullyConnected",
"matmul": "MatMul",
"max_pool2d": "Pool",
- "transpose_conv2d": "Conv",
+ "transpose_conv2d": "TransposeConv",
"clamp": "Clamp",
"arithmetic_right_shift": "ArithmeticRightShift",
"mul": "Mul",
@@ -99,9 +99,16 @@
serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
- # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
- if tosaOpName == "reshape":
- serLibOpAtts[0]["name"] = "shape"
+ serAttsToFix = {
+ "reshape": {"new_shape": "shape"},
+ "transpose_conv2d": {"output_shape": "out_shape"},
+ }
+ if tosaOpName in serAttsToFix:
+ # Fix attribute names to match tosa.xml
+ for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items():
+ for opAtts in serLibOpAtts:
+ if opAtts["name"] == attDefName:
+ opAtts["name"] = tosaSpecName
for att in serLibOpAtts:
attName = att["name"]
attType = att["dType"]