[MLBEDSW-2928] Add batching to softmax
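Added batching to softmax by reshaping the input.
The supported-operator check no longer rejects 2-D and 4-D inputs whose
batch size is not 1; it now rejects inputs with more than 4 dimensions
or whose shape differs from the output shape.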
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I0b516f9bf2410fb86372b229beba4a7280c498cc
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index e0ee616..86cc3c0 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -420,8 +420,8 @@
if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
return False
- # check batch size
- if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
+ # check shape
+ if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
return False
return True