TOSA: Added Depthwise support
This is mainly to add support for depthwise conv2d
with depth_multiplier = 1.
(No suitable test cases exist yet: all the models I have found
use depth_multiplier = 2, which is not supported.)
- Added support for depthwise conv2d.
- Added support for removing Transpose of constant data.
- Added support for removing Reshape.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I143e6246becfa78fd9f7510af0bf0d6b3fbbf2c7
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index eb31716..268d43c 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -33,9 +33,11 @@
from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
+from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
+from .tosa_mapping import TOSA_IFM_INDICES
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators
@@ -89,7 +91,7 @@
op_code = op_data.Op()
if op_code in unsupported_tosa_operators:
print("Unsupported Operator", op_code)
- assert False
+ return
op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
inputs = []
@@ -104,6 +106,15 @@
outputs.append(output_tens)
assert output_tens is not None
+ # Permutation attribute for TRANSPOSE is an input tensor in TOSA
+ # TODO In order to optimise Depthwise spawning from TFLite Support for removing
+ # Transpose of constant data.
+ # Moving permutation to an attribute, to match internal graph representation for now
+ perms = None
+ if op_code == TosaOp.TRANSPOSE:
+ perms = perms = inputs.pop(1)
+ indices = TOSA_IFM_INDICES
+
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -148,6 +159,7 @@
stride = op.attrs["stride"]
if len(stride) == 2:
op.attrs["strides"] = (1, stride[0], stride[1], 1)
+ del op.attrs["stride"]
else:
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(stride))
@@ -167,6 +179,11 @@
# TODO CONV3D more to be done....
print("Unsupported kernel dimensions: ", len(kernel))
assert False
+ if op.type.is_depthwise_conv2d_op():
+ op.attrs["depth_multiplier"] = op.weights.shape[3]
+
+ elif op.type == Op.Transpose:
+ op.attrs["perms"] = perms.values
if quant_serializer is not None:
quant_info = quant_serializer.deserialize(op_data)