MLBEDSW-6435: Implement support for ArgMax along depth dimension
- Add support for ArgMax along the depth dimension, with a depth limit of 127.
- Only 8-bit input and 32-bit output are supported.
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I5f6f0503135bebabbb1ca637f9729587b7c60740
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 9f53a1e..495d71a 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -77,7 +77,9 @@
)
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize))
+ shapeless_input_ops = binary_elem_wise_main_ops | set(
+ (Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize, Op.ArgMax)
+ )
reshape_ops = set(
(
Op.Reshape,
@@ -187,6 +189,9 @@
self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)
+ # ArgMax specific checks:
+ self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
+
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -226,6 +231,9 @@
TFLiteSemantic.constraint_tens_no_dynamic,
TFLiteSemantic.constraint_tens_output_scalar,
],
+ Op.ArgMax: [
+ TFLiteSemantic.constraint_tens_quant_none_check,
+ ],
}
return generic_constraints_exclude_list