MLBEDSW-7196 Add LSTM support
Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.
The implementation does not include support for:
* CIFG
* Peephole
* Projection
* Normalisation
This change also:
* Removed unused Op.BlockLSTM operation type.
* Removed the single-consumer limitation on placing the SplitSliceRead
  on the tensor consumer(s); it can now be placed if all consumers fulfil the requirements
* Added Op.VariableTensorWrite as an Operation.memory_function to make
  sure writes to variable tensors:
  * Always use linear mode
  * Are not moved to fast scratch
  * Are not fused with other elementwise operation tensor ranges
Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25f19b7..457c35e 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -69,8 +69,8 @@
)
)
mac_main_ops = (
- # RNN/LSTM/GRU
- set((Op.BlockLSTM,))
+ # LSTM
+ set((Op.UnidirectionalSequenceLstm,))
# conv/depthwiseconv/transposeconv
| convolution_like_ops
# pooling
@@ -320,6 +320,14 @@
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+ # UnidirectionalSequenceLstm specific checks:
+ op_type = Op.UnidirectionalSequenceLstm
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_cifg)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_peep_hole)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -888,3 +896,35 @@
"IFM depth must be no greater than 127"
ifm_depth = op.inputs[0].shape[-1]
return ifm_depth <= 127, f"IFM depth is {ifm_depth}"
+
+    @staticmethod
+    def constraint_lstm_no_cifg(op):
+        "Must not use CIFG"
+        # The op is classed as CIFG when inputs 2-4 and 6-8 are all present while
+        # inputs 1 and 5 are both None (presumably the input-gate weights, per the
+        # TFLite LSTM operand order). The `and` chain short-circuits left to right,
+        # so inputs 1 and 5 are only inspected once the other weights are present.
+        uses_cifg = (
+            None not in op.inputs[2:5] + op.inputs[6:9]
+            and op.inputs[1] is None
+            and op.inputs[5] is None
+        )
+        return not uses_cifg, "Op uses CIFG"
+
+    @staticmethod
+    def constraint_lstm_no_peep_hole(op):
+        "Must not use Peephole"
+        # Inputs 9-11 are the peephole tensors checked here; all must be absent.
+        # A generator (not a list) lets all() stop at the first present tensor.
+        valid = all(tens is None for tens in op.inputs[9:12])
+        return valid, "Op uses peephole"
+
+    @staticmethod
+    def constraint_lstm_no_projection(op):
+        "Must not use Projection"
+        # Inputs 16-17 are the projection tensors checked here; both must be absent.
+        # A generator (not a list) lets all() stop at the first present tensor.
+        valid = all(tens is None for tens in op.inputs[16:18])
+        return valid, "Op uses projection"
+
+    @staticmethod
+    def constraint_lstm_no_normalisation(op):
+        "Must not use Normalisation"
+        # Inputs 20-23 are the normalisation tensors checked here; all must be
+        # absent. If the op has fewer inputs the slice is empty and the check passes.
+        valid = all(tens is None for tens in op.inputs[20:24])
+        return valid, "Op uses normalisation"
+
+    @staticmethod
+    def constraint_lstm_weights(op):
+        "All input and recurrent weights must be available"
+        # Inputs 1-8 are the input and recurrent weight tensors; none may be None.
+        weight_tensors = op.inputs[1:9]
+        return None not in weight_tensors, "Op has missing weights"