MLBEDSW-7196 Add LSTM support

Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.
The implementation does not include support for:
* CIFG
* Peephole
* Projection
* Normalisation

This change also:
* Removed unused Op.BlockLSTM operation type.
* Removed the single-consumer limitation on putting the SplitSliceRead
  on the tensor consumer(s), provided all consumers fulfil the requirements
* Added Op.VariableTensorWrite as an Operation.memory_function (see the
  sketch below) to make sure writes to variable tensors:
  * Always use linear mode
  * Are not moved to fast scratch
  * Are not fused with other elementwise operation tensor ranges
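
A minimal sketch of how such a write is marked (simplified from the new
lstm.py in this change):

    op.ofm.equivalence_id = state.equivalence_id  # ofm aliases the state buffer
    op.memory_function = Op.VariableTensorWrite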

Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 08c63e7..f641d3f 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -55,6 +55,7 @@
 | SUB | [Generic](#tflite-generic-constraints), [Specific](#tflite-sub-constraints) |
 | TANH | [Generic](#tflite-generic-constraints) |
 | TRANSPOSE_CONV | [Generic](#tflite-generic-constraints), [Specific](#tflite-transpose_conv-constraints) |
+| UNIDIRECTIONAL_SEQUENCE_LSTM | [Generic](#tflite-generic-constraints), [Specific](#tflite-unidirectional_sequence_lstm-constraints) |
 | UNPACK | [Generic](#tflite-generic-constraints) |
 
 ### TFLite Generic Constraints
@@ -356,3 +357,19 @@
 - SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride
 - VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,  
         minus difference between kernel size and stride
+
+### TFLite UNIDIRECTIONAL_SEQUENCE_LSTM Constraints
+
+This is a list of constraints that the UNIDIRECTIONAL_SEQUENCE_LSTM operator must satisfy in order to be scheduled on the NPU.
+
+- IFM must be int8 or int16
+- IFM and OFM data types must match
+- IFM and OFM must have 3D shape
+- Must have 24 input tensors
+- Must have 5 intermediate tensors
+- State tensors must be variable
+- Must not use CIFG
+- Must not use Peephole
+- Must not use Projection
+- Must not use Normalisation
+- All input and recurrent weights must be available
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index e1341d8..8279036 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,6 +27,7 @@
 from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
+from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
 from .tensor import QuantizationParameters
@@ -101,6 +102,10 @@
     ):
         return
 
+    # Writes to the buffer of a variable tensor must use linear format
+    if tens.ops[0].memory_function == Op.VariableTensorWrite:
+        return
+
     # Check if any of the producers/consumers is run on CPU
     if not all(cons.run_on_npu for cons in tens.consumer_list):
         return
@@ -222,7 +227,8 @@
         cons_op.ifm_shapes[1] = op.ifm_shapes[0]
     op.ofm.consumer_list.remove(cons_op)
     op.ofm.ops = []
-    op.ifm.consumer_list.remove(op)
+    if op in op.ifm.consumer_list:
+        op.ifm.consumer_list.remove(op)
 
 
 def check_memory_only_removed(op, arch):
@@ -357,3 +363,20 @@
     op.set_ifm_ofm_shapes()
     DebugDatabase.add_optimised(op, op)
     return op
+
+
+def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Creates an average pool for the given concat op/input feature map"""
+    ofm = concat_op.ofm
+    avgpool_op = create_avgpool_nop(name)
+    avgpool_op.inputs = [ifm]
+    avgpool_op.outputs = [ofm]
+
+    avgpool_op.write_offset = write_offset
+    avgpool_op.write_shape = ifm_shape
+    ofm.ops.append(avgpool_op)
+    avgpool_op.ifm_shapes.append(ifm_shape)
+    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+    avgpool_op.memory_function = Op.ConcatSliceWrite
+    DebugDatabase.add_optimised(concat_op, avgpool_op)
+    return avgpool_op
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 995a0cc..3abcfcf 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -166,9 +166,9 @@
 
 def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
     ifm_tens = None
-    if sched_op.op_type.is_elementwise_op():
+    elem_op = sched_op.parent_op
+    if sched_op.op_type.is_elementwise_op() and elem_op.memory_function is not Op.VariableTensorWrite:
         # Check if possible to merge ifm/ofm live ranges of elementwise op
-        elem_op = sched_op.parent_op
         if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
             # Check if overwriting the inputs can be allowed
             OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
diff --git a/ethosu/vela/lstm.py b/ethosu/vela/lstm.py
new file mode 100644
index 0000000..5a50788
--- /dev/null
+++ b/ethosu/vela/lstm.py
@@ -0,0 +1,447 @@
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
+from enum import Enum
+from typing import Tuple
+
+import numpy as np
+
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .graph_optimiser_util import create_avg_pool_for_concat
+from .operation import ActivationFunction
+from .operation import ExplicitScaling
+from .operation import Op
+from .operation import Operation
+from .operation_util import create_add
+from .operation_util import create_fullyconnected
+from .operation_util import create_fused_activation
+from .operation_util import create_mul
+from .scaling import elementwise_mul_scale
+from .shape4d import Shape4D
+from .tensor import QuantizationParameters
+from .tensor import Tensor
+
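+# Note: approximately 2**-15 (1/32768 = 0.000030517578125)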
+Q0_15_SCALE = np.float32(0.00003051757)
+"""Q0.15 scale like the reference defines it"""
+
+
+class Lstm:
+    """Lstm graph optimisation.
+
+    Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.
+
+    Usage:
+
+    unrolled_op = Lstm(op).get_graph()
+    """
+
+    class State(Enum):
+        """States (variable tensors)"""
+
+        OUTPUT = 18  # Value = tensor index
+        CELL = 19  # Value = tensor index
+
+    def __init__(self, op):
+        self.op = op
+
+    def get_graph(self) -> Operation:
+        """Return the generated graph implementation"""
+        self.op.ofm.ops = []
+        if self.time_major:
+            output_state = self.get_initial_state(Lstm.State.OUTPUT)
+            cell_state = self.get_initial_state(Lstm.State.CELL)
+            for time in range(self.n_time):
+                feature = self.get_feature(time)
+                output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
+                op = self.put_ofm(output_state, time)
+        else:
+            for batch in range(self.n_batch):
+                output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
+                cell_state = self.get_initial_state(Lstm.State.CELL, batch)
+                for time in range(self.n_time):
+                    feature = self.get_feature(time, batch)
+                    output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
+                    op = self.put_ofm(output_state, time, batch)
+        return op
+
+    def get_feature(self, time: int, batch: int = 0) -> Tensor:
+        """Get input feature for provided time and batch"""
+        feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
+        feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
+        op = Operation(Op.SplitSliceRead, feature.name)
+        op.add_input_tensor(self.op.ifm)
+        op.set_output_tensor(feature)
+        op.set_ifm_ofm_shapes()
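+        # Select one time step: the IFM layout is (time, batch, feature) when
+        # time_major, otherwise (batch, time, feature)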
+        offset = [time, 0, 0] if self.time_major else [batch, time, 0]
+        op.read_offsets[0] = Shape4D.from_list(offset, 0)
+        op.read_shapes[0] = op.ofm_shapes[0]
+        DebugDatabase.add_optimised(self.op, op)
+        return feature
+
+    def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
+        """Get state tensor for provided state type and batch"""
+        state = self.state(state_type)
+        if self.time_major:
+            # For time major just return the 2D state, since all batches
+            # are calculated at the same time
+            return state
+        else:
+            # For non time major return one batch of the 2D state
+            # by setting the read offset to the provided batch
+
+            # The cloned state tensor will share equivalence id and buffer
+            # with the variable state tensor
+            n_state = state.shape[-1]
+            state_ofm = state.clone(f"_state#{batch}")
+            # Set shape to be one batch
+            state_ofm.set_all_shapes([1, n_state])
+            # Create the op for reading one batch of the state
+            # (will be optimised away at a later stage)
+            op = Operation(Op.SplitSliceRead, state_ofm.name)
+            op.add_input_tensor(state)
+            op.set_output_tensor(state_ofm)
+            op.set_ifm_ofm_shapes()
+            # Set the read offset to the provided batch
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            # Set the read shape to one batch, see above
+            op.read_shapes[0] = op.ofm_shapes[0]
+            DebugDatabase.add_optimised(self.op, op)
+            return state_ofm
+
+    def get_state(self, op: Operation, batch: int = 0) -> Operation:
+        """Setup the correct read offset for reading the state from
+        a variable tensor state"""
+        if not self.time_major and self.n_batch > 1:
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            op.read_shapes[0] = Shape4D(op.ifm.shape)
+            op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
+        return op
+
+    def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
+        """Save the state for the provided batch by pointing the operations
+        ofm to the variable state tensor"""
+        # The create op functions always return 4D shape, however the state
+        # should have 2D shape for correct operation
+        op.ofm.shape = op.ofm.shape[-2:]
+        # Get state from type
+        state = self.state(state_type)
+        # By using the same equivalence_id the backing buffer for the ofm
+        # tensor will be the state variable tensor buffer
+        op.ofm.equivalence_id = state.equivalence_id
+        # Set the memory function, which will keep the tensor in linear format
+        # just like the state variable tensor
+        op.memory_function = Op.VariableTensorWrite
+        # Set the batch write offset into the state tensor buffer, unless in
+        # time_major mode when all batches are written at once
+        if not self.time_major:
+            op.write_offset = Shape4D.from_list([batch, 0], 0)
+            op.write_shape = Shape4D(op.ofm.shape)
+            op.ofm_shapes = [Shape4D(state.shape)]
+        DebugDatabase.add_optimised(self.op, op)
+        return op
+
+    def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
+        """Save the output state for the provided batch and time to OFM"""
+        name = f"{self.op.ofm.name}#{batch}.{time}"
+        offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
+        op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
+        # The provided state tensor uses the output state tensor's buffer, so unless
+        # in time_major mode we need to set the correct batch read offset
+        if not self.time_major:
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            op.read_shapes[0] = Shape4D(state.shape)
+            op.ifm_shapes[0] = Shape4D(self.output_state.shape)
+        return op
+
+    def lstm_step(
+        self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
+    ) -> Tuple[Tensor, Tensor]:
+        """Generate one step of the LSTM implementation for the provided feature, batch and time"""
+        input_gate = self.calculate_gate(
+            f"input_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_input_weights,
+            self.input_bias,
+            self.recurrent_to_input_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        forget_gate = self.calculate_gate(
+            f"forget_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_forget_weights,
+            self.forget_bias,
+            self.recurrent_to_forget_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        cell_gate = self.calculate_gate(
+            f"cell_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_cell_weights,
+            self.cell_bias,
+            self.recurrent_to_cell_weights,
+            None,
+            Op.Tanh,
+            batch,
+        )
+        cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
+        output_gate = self.calculate_gate(
+            f"output_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_output_weights,
+            self.output_bias,
+            self.recurrent_to_output_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
+        return (output_state, cell_state)
+
+    def calculate_gate(
+        self,
+        name: str,
+        input: Tensor,
+        state: Tensor,
+        input_weights: Tensor,
+        input_bias: Tensor,
+        recurrent_weights: Tensor,
+        recurrent_bias: Tensor,
+        activation: Op,
+        batch: int = 0,
+    ):
+        """Generate a gate for the provided input and weights"""
+        # Activation( Add( FC(input), FC(output state) ) )
+        # Setup fullyconnected quantization
+        q_fc = QuantizationParameters()
+        q_fc.scale_f32 = np.float32(2**-12)
+        q_fc.zero_point = 0
+        # Create fullyconnected
+        in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
+        re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
+        self.get_state(re_fc, batch)
+        # Change fullyconnected ofm data type
+        in_fc.ofm.dtype = DataType.int16
+        re_fc.ofm.dtype = DataType.int16
+        # Setup add quantization
+        q_add = q_fc.clone()
+        q_add.scale_f32 = np.float32(2**-15)
+        # Create add + activation
+        add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
+        if activation is Op.Sigmoid:
+            # For Sigmoid we need to set the activation min/max values to match the possible range
+            # in the reference. The values below are the quantized min/max values that the reference
+            # can achieve for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
+            # due to intermediate higher precision.)
+            # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
+            # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
+            # fused activation function). This will yield the dequantized min/max values which are later
+            # quantized again by the command stream generator.
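+            # For example: max = 32757 / 0x3000 ≈ 2.666 and min = 11 / 0x3000 ≈ 0.0009.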
+            add.activation.max = 32757 / 0x3000
+            add.activation.min = 11 / 0x3000
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, in_fc)
+        DebugDatabase.add_optimised(self.op, re_fc)
+        DebugDatabase.add_optimised(self.op, add)
+        return add.ofm
+
+    def calculate_cell_state(
+        self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
+    ):
+        """Update the cell state from the provided gate output"""
+        # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
+        base_name = f"cell_state#{batch}.{time}"
+        # Cell scale
+        cell_scale = cell_state.quantization.scale_f32
+        # Create mul(cell_state, forget_gate)
+        mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
+        self.get_state(mul_cf, batch)
+        # Calculate explicit scales to match reference
+        multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
+        mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Create mul(cell_gate, input_gate)
+        mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
+        # Calculate explicit scales to match reference
+        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
+        mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Setup cell clip
+        activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
+        if activation:
+            activation.max = self.cell_clip
+            activation.min = -self.cell_clip
+        # Create add + activation
+        add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
+        add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+        # Save new state
+        self.put_state(add, Lstm.State.CELL, batch)
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, mul_cf)
+        DebugDatabase.add_optimised(self.op, mul_ci)
+        DebugDatabase.add_optimised(self.op, add)
+        return add.ofm
+
+    def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
+        """Generate the output state from the provided gate output"""
+        # Mul( Tanh(cell state), output gate )
+        base_name = f"output_state#{batch}.{time}"
+        # Setup tanh quantization
+        q_out_tanh = QuantizationParameters()
+        q_out_tanh.scale_f32 = np.float32(2**-15)
+        q_out_tanh.zero_point = 0
+        # Create tanh(cell state)
+        tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
+        self.get_state(tanh, batch)
+        # Create Mul( Tanh(cell state), output gate )
+        q_mul = self.output_state.quantization
+        mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
+        # Use explicit scaling to match the reference; the following line would have been the preferred way:
+        # mul.forced_output_quantization = self.hidden_quantization
+        out_scale = self.hidden_quantization.scale_f32
+        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
+        mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Save new state
+        self.put_state(mul, Lstm.State.OUTPUT, batch)
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, tanh)
+        DebugDatabase.add_optimised(self.op, mul)
+        return mul.ofm
+
+    def state(self, state_type: State) -> Tensor:
+        """Get state tensor from type"""
+        return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state
+
+    # Dimensions
+    @property
+    def n_feature(self) -> int:
+        return self.op.ifm.shape[-1]
+
+    @property
+    def n_time(self) -> int:
+        return self.op.ifm.shape[0 if self.time_major else 1]
+
+    @property
+    def n_batch(self) -> int:
+        return self.op.ifm.shape[1 if self.time_major else 0]
+
+    # Attributes
+    @property
+    def cell_clip(self) -> int:
+        return self.op.attrs.get("cell_clip", 0)
+
+    @property
+    def projection_clip(self) -> int:
+        return self.op.attrs.get("proj_clip", 0)
+
+    @property
+    def time_major(self) -> bool:
+        return self.op.attrs.get("time_major", False)
+
+    # Hidden (intermediate)
+    @property
+    def hidden_quantization(self) -> QuantizationParameters:
+        return self.op.intermediates[4].quantization
+
+    # Input weights
+    @property
+    def input_to_input_weights(self) -> Tensor:
+        return self.op.inputs[1]
+
+    @property
+    def input_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[2]
+
+    @property
+    def input_to_cell_weights(self) -> Tensor:
+        return self.op.inputs[3]
+
+    @property
+    def input_to_output_weights(self) -> Tensor:
+        return self.op.inputs[4]
+
+    # Recurrent weights
+    @property
+    def recurrent_to_input_weights(self) -> Tensor:
+        return self.op.inputs[5]
+
+    @property
+    def recurrent_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[6]
+
+    @property
+    def recurrent_to_cell_weights(self) -> Tensor:
+        return self.op.inputs[7]
+
+    @property
+    def recurrent_to_output_weights(self) -> Tensor:
+        return self.op.inputs[8]
+
+    # Peephole weights
+    @property
+    def cell_to_input_weights(self) -> Tensor:
+        return self.op.inputs[9]
+
+    @property
+    def cell_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[10]
+
+    @property
+    def cell_to_output_weights(self) -> Tensor:
+        return self.op.inputs[11]
+
+    # Bias tensors
+    @property
+    def input_bias(self) -> Tensor:
+        return self.op.inputs[12]
+
+    @property
+    def forget_bias(self) -> Tensor:
+        return self.op.inputs[13]
+
+    @property
+    def cell_bias(self) -> Tensor:
+        return self.op.inputs[14]
+
+    @property
+    def output_bias(self) -> Tensor:
+        return self.op.inputs[15]
+
+    # Projection tensors
+    @property
+    def projection_weights(self) -> Tensor:
+        return self.op.inputs[16]
+
+    @property
+    def projection_bias(self) -> Tensor:
+        return self.op.inputs[17]
+
+    # State tensors (variable)
+    @property
+    def output_state(self) -> Tensor:
+        return self.op.inputs[Lstm.State.OUTPUT.value]
+
+    @property
+    def cell_state(self) -> Tensor:
+        return self.op.inputs[Lstm.State.CELL.value]
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6771710..d167053 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -37,6 +37,7 @@
 
 # Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import.
 if TYPE_CHECKING:
+    from .tensor import QuantizationParameters
     from .tensor import Tensor
 
 PointXY = namedtuple("PointXY", "x y")
@@ -142,8 +143,6 @@
     BatchToSpaceND = OperatorInfo()
     BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
     BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
-    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)
-
     CLZ = OperatorInfo(
         block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
     )  # NPU specific operation
@@ -297,6 +296,7 @@
     Unique = OperatorInfo()
     Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
     UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
+    VariableTensorWrite = OperatorInfo()
     Where = OperatorInfo()
     While = OperatorInfo()
     ZerosLike = OperatorInfo()
@@ -516,8 +516,8 @@
         self.memory_function: Optional[Op] = None
         # If not none: contains QuantizationParameters to be used as output quantization
         # (which overrides the ofm tensor's quantization), used in LUT
-        self.forced_input_quantization = None
-        self.forced_output_quantization = None
+        self.forced_input_quantization: Optional[QuantizationParameters] = None
+        self.forced_output_quantization: Optional[QuantizationParameters] = None
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
         self.activation_lut = None
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 74836eb..ef4949f 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -27,8 +27,10 @@
 from .operation import Op
 from .operation import Operation
 from .operation import Padding
+from .reader_util import clone_and_reshape_tensor
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
@@ -117,6 +119,55 @@
     return op
 
 
+def create_fused_activation(op_type: Op, name: str, ifm: Tensor, quantization: QuantizationParameters) -> Operation:
+    assert op_type.is_activation_op()
+    op = create_avgpool_nop(name)
+    op.activation = ActivationFunction(op_type)
+    ofm = Tensor(ifm.shape, ifm.dtype, f"{op.name}_tens0")
+    ofm.quantization = quantization
+    op.add_input_tensor(ifm)
+    op.set_output_tensor(ofm)
+    op.set_ifm_ofm_shapes()
+    return op
+
+
+def create_fullyconnected(
+    name: str,
+    ifm: Tensor,
+    weights: Tensor,
+    bias: Optional[Tensor],
+    quantization: QuantizationParameters,
+    vela_weight_order: bool = True,
+) -> Operation:
+    # Reshape weights if needed
+    if not vela_weight_order:
+        weights = clone_and_reshape_tensor(weights, (1, 0), False)
+
+    n_ofm = weights.shape[-1]
+
+    # Setup bias if needed
+    if not bias:
+        bias_values = [0] * n_ofm
+        dtype = DataType.int64 if ifm.dtype == DataType.int16 else DataType.int32
+        bias = create_const_tensor(f"{name}_bias", [n_ofm], dtype, bias_values)
+        # Set equivalence_id based on values to avoid placing duplicate data in flash
+        bias.equivalence_id = create_equivalence_id(tuple(bias_values))
+        bias.value_id = bias.equivalence_id
+
+    # Setup ofm
+    ofm = Tensor([ifm.shape[0], n_ofm], ifm.dtype, f"{name}_tens0")
+    ofm.quantization = quantization
+
+    # Create op and add tensors
+    op = Operation(Op.FullyConnected, name)
+    op.add_input_tensor(ifm)
+    op.add_input_tensor(weights)
+    op.add_input_tensor(bias)
+    op.set_output_tensor(ofm)
+    op.set_ifm_ofm_shapes()
+    return op
+
+
 def create_depthwise_maxpool(
     name: str,
     ifm: Tensor,
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index e43a919..932f701 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -55,8 +55,6 @@
         Op.QuantizedMatMul,
         Op.MatMul,
         Op.FullyConnected,
-        # RNN/LSTM/GRU
-        Op.BlockLSTM,
         # pooling
         Op.QuantizedMaxPool,
         Op.QuantizedAvgPool,
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 6fcb6c1..cbd7ce4 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1242,7 +1242,11 @@
             cost = schedule.cost_map[sched_op]
             if cost.cascade == 0 and sched_op.get_dependants():
                 ofm_tens = sched_op.ofm.connection.parent_tens
-                if not any(cons is None for cons in ofm_tens.consumer_list):
+                # Do not move subgraph outputs or Variable Tensor Writes
+                if (
+                    not any(cons is None for cons in ofm_tens.consumer_list)
+                    and sched_op.parent_op.memory_function is not Op.VariableTensorWrite
+                ):
                     if ofm_tens not in self.scratched_fms:
                         # Remember default mem area and mem type, only done once
                         self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
@@ -1260,6 +1264,7 @@
                 mem_type_set,
                 lr_graph,
             )
+
         max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
 
         # If max_mem_usage does not exceed staging limit at any point all lrs fit and can stay in fast storage
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index fd23d04..d4c9255 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -576,3 +576,37 @@
     dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
     op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
     assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_lstm_semantics():
+    # Test valid configurations
+    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+    assert semantic_checker.is_operator_semantic_valid(op)
+    assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16))
+    # Test invalid datatype
+    assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8))
+    # Test invalid shape
+    ifm_shape = op.ifm.shape
+    ofm_shape = op.ofm.shape
+    op.ifm.shape = [12, 24]
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.ifm.shape = ifm_shape
+    op.ofm.shape = [12, 20]
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.ofm.shape = ofm_shape
+    # Test invalid number of intermediates
+    intermediate = op.intermediates.pop()
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.intermediates.append(intermediate)
+    op.intermediates.append(intermediate)
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.intermediates.pop()
+    # Test invalid number of inputs
+    input = op.inputs.pop()
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.inputs.append(input)
+    op.inputs.append(input)
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.inputs.pop()
+    # Test restored valid configuration
+    assert semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 2713adf..04f10e9 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -623,3 +623,49 @@
     assert support.is_operator_supported(op)
     op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
     assert not support.is_operator_supported(op)
+
+
+def test_lstm_support():
+    # Test valid configuration
+    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+    assert support.is_operator_supported(op)
+    # Test CIFG not supported
+    input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5]
+    op.inputs[1] = None
+    assert not support.is_operator_supported(op)
+    op.inputs[1] = input_to_input_weights
+    op.inputs[5] = None
+    assert not support.is_operator_supported(op)
+    op.inputs[5] = recurrent_to_input_weights
+    # Test Peephole not supported
+    op.inputs[9] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[9] = None
+    op.inputs[10] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[10] = None
+    op.inputs[11] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[11] = None
+    # Test Projection not supported
+    op.inputs[16] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[16] = None
+    op.inputs[17] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[17] = None
+    # Test Normalisation not supported
+    op.inputs[20] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[20] = None
+    op.inputs[21] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[21] = None
+    op.inputs[22] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[22] = None
+    op.inputs[23] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[23] = None
+    # Test restored valid configuration
+    assert support.is_operator_supported(op)
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 88fc874..e08bde2 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -103,7 +103,10 @@
 def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     op = Operation(op_type, output.name + "_op")
     for input in inputs:
-        op.add_input_tensor(input)
+        if input:  # Add regular tensor input
+            op.add_input_tensor(input)
+        else:  # Add optional (None) inputs for operators with sparse input positioning
+            op.inputs.append(input)
     op.set_output_tensor(output)
     if attrs is not None:
         op.attrs = attrs
@@ -112,6 +115,63 @@
     return op
 
 
+def create_lstm_op(batches, times, features, outputs, datatype):
+    input_shape = [batches, times, features]
+    output_shape = [batches, times, outputs]
+    weight_shape = [features, outputs]
+    state_shape = [batches, outputs]
+    bias_shape = [outputs]
+    ifm = Tensor(input_shape, datatype, "in")
+    ifm.quantization = default_quant_params()
+    ofm = Tensor(output_shape, datatype, "out")
+    ofm.quantization = default_quant_params()
+    bias_dtype = DataType.int64 if datatype == DataType.int16 else DataType.int32
+    bias = create_const_tensor("bias", bias_shape, bias_dtype, [0] * outputs)
+    weight_q = default_quant_params()
+    weight = create_const_tensor("weight", weight_shape, DataType.int8, np.ones(weight_shape), quantization=weight_q)
+    output_state = Tensor(state_shape, datatype, "output_state")
+    output_state.quantization = default_quant_params()
+    output_state.is_variable = True
+    cell_state = Tensor(state_shape, DataType.int16, "cell_state")
+    cell_state.quantization = default_quant_params()
+    cell_state.is_variable = True
+    intermediate = Tensor([], DataType.float32, "intermediate")
+    hidden_scale_intermediate = Tensor([], datatype, "effective_hidden_scale_intermediate")
+    hidden_scale_intermediate.quantization = default_quant_params()
+    peephole = None
+    projection = None
+    normalisation = None
+    inputs = [
+        ifm,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        peephole,
+        peephole,
+        peephole,
+        bias,
+        bias,
+        bias,
+        bias,
+        projection,
+        projection,
+        output_state,
+        cell_state,
+        normalisation,
+        normalisation,
+        normalisation,
+        normalisation,
+    ]
+    op = create_op(Op.UnidirectionalSequenceLstm, inputs, ofm)
+    op.intermediates = [intermediate, intermediate, intermediate, intermediate, hidden_scale_intermediate]
+    return op
+
+
 def create_subgraph(op_list):
     # Creates subgraph using the given list of operations
     sg = Subgraph()
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 077f4af..478d018 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -35,11 +35,13 @@
 from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import convert_depthwise_to_conv
 from .graph_optimiser_util import convert_to_lut
+from .graph_optimiser_util import create_avg_pool_for_concat
 from .graph_optimiser_util import memory_only_ops
 from .graph_optimiser_util import move_splitsliceread_to_consumer
 from .graph_optimiser_util import needed_total_padding
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
+from .lstm import Lstm
 from .numeric_util import clamp_sigmoid
 from .numeric_util import full_shape
 from .numeric_util import round_away_zero
@@ -69,23 +71,6 @@
 passthrough_nodes = (Op.Identity,)
 
 
-def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
-    """Creates an average pool for the given concat op/input feature map"""
-    ofm = concat_op.ofm
-    avgpool_op = create_avgpool_nop(name)
-    avgpool_op.inputs = [ifm]
-    avgpool_op.outputs = [ofm]
-
-    avgpool_op.write_offset = write_offset
-    avgpool_op.write_shape = ifm_shape
-    ofm.ops.append(avgpool_op)
-    avgpool_op.ifm_shapes.append(ifm_shape)
-    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
-    avgpool_op.memory_function = Op.ConcatSliceWrite
-    DebugDatabase.add_optimised(concat_op, avgpool_op)
-    return avgpool_op
-
-
 def remove_passthrough_tensor(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
         assert len(tens.ops[0].inputs) == 1
@@ -196,17 +181,15 @@
 def remove_SplitSliceRead(op, arch):
 
     if op.type == Op.SplitSliceRead:
-        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool need to be inserted
-        if (
-            len(op.ofm.consumer_list) == 1
-            and op.ofm.consumer_list[0] is not None
-            and op.ofm.consumer_list[0].run_on_npu
-            and op.ofm.consumer_list[0].type not in memory_only_ops
-            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+        # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
+        # or if an avgpool needs to be inserted
+        if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
+            consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
+            for consumer in op.ofm.consumer_list
         ):
-            # SplitSliceRead can be performed by tensor consumer
-            cons_op = op.ofm.consumer_list[0]
-            move_splitsliceread_to_consumer(op, cons_op)
+            # SplitSliceRead can be performed by tensor consumer(s)
+            for cons_op in list(op.ofm.consumer_list):
+                move_splitsliceread_to_consumer(op, cons_op)
         else:
             avgpool_op = create_avgpool_nop(op.name + "_avgpool")
             avgpool_op.add_input_tensor(op.ifm)
@@ -801,8 +784,9 @@
 
 
 def rewrite_fully_connected_input(op: Operation, arch, nng):
-
-    if op.type == Op.FullyConnected:
+    # If the operation already has a read shape, do not modify
+    # the ifm shape, since that will already be correct
+    if op.type == Op.FullyConnected and not op.read_shapes[0]:
         new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
         assert new_shape is not None, "Tensor can not be reshaped to 2D"
         op.ifm_shapes[0] = new_shape
@@ -1080,6 +1064,13 @@
     return op
 
 
+def convert_lstm(op, arch, nng):
+    if op.type == Op.UnidirectionalSequenceLstm:
+        lstm = Lstm(op)
+        op = lstm.get_graph()
+    return op
+
+
 def convert_softmax(op, arch, nng):
     if op.type == Op.Softmax and op.run_on_npu:
         softmax = SoftMax(op)
@@ -2144,6 +2135,7 @@
         convert_mean_to_depthwise_conv_or_avgpool,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
+        convert_lstm,
         convert_softmax,
         convert_prelu,
         convert_mul_max_to_abs_or_lrelu,
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 5661f36..6ba7b83 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -193,6 +193,14 @@
         self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
         self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output)
 
+        # UnidirectionalSequenceLstm specific checks:
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_input_signed)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_matching_in_out_types)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_dimensions)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_inputs)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_intermediates)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_variables)
+
     def is_operator_semantic_valid(self, op):
         ext_type = optype_to_builtintype(op.type)
 
@@ -628,6 +636,13 @@
         return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
 
     @staticmethod
+    def constraint_input_signed(op):
+        "IFM must be int8 or int16"
+        ifm_dtype = op.ifm.dtype
+        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.int16)
+        return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+    @staticmethod
     def constraint_input_8bit(op):
         "IFM must be int8 or uint8"
         ifm_dtype = op.ifm.dtype
@@ -689,6 +704,36 @@
             return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
         return True, "IFM and OFM number of elements are equal."
 
+    @staticmethod
+    def constraint_lstm_dimensions(op):
+        "IFM and OFM must have 3D shape"
+        valid = len(op.ifm.shape) == len(op.ofm.shape) == 3
+        return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}"
+
+    @staticmethod
+    def constraint_lstm_inputs(op):
+        "Must have 24 input tensors"
+        n_inputs = len(op.inputs)
+        return n_inputs == 24, f"Op has {n_inputs} inputs"
+
+    @staticmethod
+    def constraint_lstm_intermediates(op):
+        "Must have 5 intermediate tensors"
+        n_intermediates = len(op.intermediates)
+        return n_intermediates == 5, f"Op has {n_intermediates} intermediates"
+
+    @staticmethod
+    def constraint_lstm_variables(op):
+        "State tensors must be variable"
+        valid = True
+        extra = []
+        for tens in op.inputs[18:20]:
+            if not tens.is_variable:
+                valid = False
+                extra.append(tens.name)
+        extra = ", ".join(extra)
+        return valid, f"Op has non-variable state tensor(s): {extra}"
+
 
 def tflite_semantic_checker(nng):
     semantic_checker = TFLiteSemantic()
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25f19b7..457c35e 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -69,8 +69,8 @@
         )
     )
     mac_main_ops = (
-        # RNN/LSTM/GRU
-        set((Op.BlockLSTM,))
+        # LSTM
+        set((Op.UnidirectionalSequenceLstm,))
         # conv/depthwiseconv/transposeconv
         | convolution_like_ops
         # pooling
@@ -320,6 +320,14 @@
         self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
         self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
 
+        # UnidirectionalSequenceLstm specific checks:
+        op_type = Op.UnidirectionalSequenceLstm
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_cifg)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_peep_hole)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -888,3 +896,35 @@
         "IFM depth must be no greater than 127"
         ifm_depth = op.inputs[0].shape[-1]
         return ifm_depth <= 127, f"IFM depth is {ifm_depth}"
+
+    @staticmethod
+    def constraint_lstm_no_cifg(op):
+        "Must not use CIFG"
+        cifg = None not in op.inputs[2:5] + op.inputs[6:9]
+        cifg = cifg and op.inputs[1] is None
+        cifg = cifg and op.inputs[5] is None
+        return not cifg, "Op uses CIFG"
+
+    @staticmethod
+    def constraint_lstm_no_peep_hole(op):
+        "Must not use Peephole"
+        valid = all([tens is None for tens in op.inputs[9:12]])
+        return valid, "Op uses peephole"
+
+    @staticmethod
+    def constraint_lstm_no_projection(op):
+        "Must not use Projection"
+        valid = all([tens is None for tens in op.inputs[16:18]])
+        return valid, "Op uses projection"
+
+    @staticmethod
+    def constraint_lstm_no_normalisation(op):
+        "Must not use Normalisation"
+        valid = all([tens is None for tens in op.inputs[20:24]])
+        return valid, "Op uses normalisation"
+
+    @staticmethod
+    def constraint_lstm_weights(op):
+        "All input and recurrent weights must be available"
+        valid = None not in op.inputs[1:9]
+        return valid, "Op has missing weights"