MLBEDSW-4840 Move setting of input indices to tflite reader

Mapping to internal input indexing has been added to
tflite_reader.py and tosa_reader.py.
And the other way around in tflite_writer.py.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I4d8596e747cfa7c4203884c4e785eb1977e2bcc1
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 0558e52..ffa4717 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -100,16 +100,16 @@
 
 TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])
 
-NO_INDICES = TensorIndices([], [], [])
-IFM_INDICES = TensorIndices([0], [], [])
-IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
-IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
-IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
-CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
-TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
-CONCAT_INDICES = TensorIndices([1, 2], [], [])
-SPLIT_IFM_INDICES = TensorIndices([1], [], [])
-BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+NNG_NO_INDICES = TensorIndices([], [], [])
+NNG_IFM_INDICES = TensorIndices([0], [], [])
+NNG_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+NNG_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+NNG_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+NNG_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+NNG_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+NNG_CONCAT_INDICES = TensorIndices([1, 2], [], [])
+NNG_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+NNG_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
 
 
 # Static information related to operation codes
@@ -117,7 +117,7 @@
     __slots__ = ("id", "block_type", "indices", "is_unary")
     _id = 0
 
-    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
+    def __init__(self, block_type=NpuBlockType.Default, indices=NNG_NO_INDICES, is_unary=False):
         OperatorInfo._id += 1
         self.id = OperatorInfo._id
         self.block_type = block_type
@@ -127,37 +127,38 @@
 
 # Internally used operation codes
 class Op(Enum):
-    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
-    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
+    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     AddN = OperatorInfo()
     Any = OperatorInfo()
     ArgMax = OperatorInfo()
     ArgMin = OperatorInfo()
-    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
     BatchMatMul = OperatorInfo()
     BatchToSpaceND = OperatorInfo()
-    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
-    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
-    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)
+    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)
 
     CLZ = OperatorInfo(
-        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
+        block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
     )  # NPU specific operation
     Call = OperatorInfo()
     Cast = OperatorInfo()
     Ceil = OperatorInfo()
+    Clamp = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
     Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
-    Concat = OperatorInfo(indices=CONCAT_INDICES)
+    Concat = OperatorInfo(indices=NNG_CONCAT_INDICES)
     ConcatEmbeddings = OperatorInfo()
-    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
-    ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
+    ConcatSliceWrite = OperatorInfo(indices=NNG_IFM_INDICES)
+    ConcatTFLite = OperatorInfo(indices=NNG_CONCAT_INDICES)
     Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
-    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
-    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
+    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
+    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_CONV2D_BACKPROP_INDICES)
     Conv2DBackpropInputSwitchedBias = OperatorInfo(
-        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
+        block_type=NpuBlockType.ConvolutionMxN, indices=NNG_TRANSPOSE_CONV_INDICES
     )
-    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
+    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
     Cos = OperatorInfo()
     Cumsum = OperatorInfo()
     Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
@@ -165,26 +166,28 @@
     Delegate = OperatorInfo()
     Densify = OperatorInfo()
     DepthToSpace = OperatorInfo()
-    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
-    Dequantize = OperatorInfo(indices=IFM_INDICES)
+    DepthwiseConv2DBias = OperatorInfo(
+        block_type=NpuBlockType.ConvolutionDepthWise, indices=NNG_IFM_WEIGHTS_BIAS_INDICES
+    )
+    Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
     Div = OperatorInfo()
     Elu = OperatorInfo()
     EmbeddingLookup = OperatorInfo()
     EmbeddingLookupSparse = OperatorInfo()
     Equal = OperatorInfo()
     Exp = OperatorInfo()
-    ExpandDims = OperatorInfo(indices=IFM_INDICES)
+    ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES)
     FakeQuantWithMinMaxArgs = OperatorInfo()
     Fill = OperatorInfo()
     Floor = OperatorInfo()
     FloorDiv = OperatorInfo()
     FloorMod = OperatorInfo()
-    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
+    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
     GatherNd = OperatorInfo()
     GatherV2 = OperatorInfo()
     Greater = OperatorInfo()
     GreaterEqual = OperatorInfo()
-    HardSwish = OperatorInfo(indices=IFM_INDICES)
+    HardSwish = OperatorInfo(indices=NNG_IFM_INDICES)
     HashtableLookup = OperatorInfo()
     Identity = OperatorInfo()
     If = OperatorInfo()
@@ -192,7 +195,7 @@
     L2Pool2D = OperatorInfo()
     LRN = OperatorInfo()
     LSHProjection = OperatorInfo()
-    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
     Less = OperatorInfo()
     LessEqual = OperatorInfo()
     Log = OperatorInfo()
@@ -200,92 +203,92 @@
     LogicalAnd = OperatorInfo()
     LogicalNot = OperatorInfo()
     LogicalOr = OperatorInfo()
-    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
     LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
-    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
     MatrixDiag = OperatorInfo()
     MatrixSetDiag = OperatorInfo()
     Max = OperatorInfo()
-    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
-    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
-    Mean = OperatorInfo(indices=IFM_INDICES)
+    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
+    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
+    Mean = OperatorInfo(indices=NNG_IFM_INDICES)
     Min = OperatorInfo()
-    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     MirrorPad = OperatorInfo()
-    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     Neg = OperatorInfo()
     NonMaxSuppressionV4 = OperatorInfo()
     NonMaxSuppressionV5 = OperatorInfo()
     NotEqual = OperatorInfo()
     OneHot = OperatorInfo()
-    Pack = OperatorInfo(indices=IFM_INDICES)
-    PackReshaped = OperatorInfo(indices=IFM_INDICES)
-    Pad = OperatorInfo(indices=IFM_INDICES)
+    Pack = OperatorInfo(indices=NNG_IFM_INDICES)
+    PackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
+    Pad = OperatorInfo(indices=NNG_IFM_INDICES)
     PadV2 = OperatorInfo()
     Placeholder = OperatorInfo()  # Only used in CPU subgraphs
     Pow = OperatorInfo()
     Prelu = OperatorInfo()
     Prod = OperatorInfo()
-    Quantize = OperatorInfo(indices=IFM_INDICES)
-    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
-    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
-    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
-    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
-    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
+    Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
+    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
+    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
+    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
+    QuantizedReshape = OperatorInfo(indices=NNG_IFM_INDICES)
     Range = OperatorInfo()
     Rank = OperatorInfo()
-    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
-    Relu = OperatorInfo(indices=IFM_INDICES)
-    Relu6 = OperatorInfo(indices=IFM_INDICES)
-    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
-    ReluN = OperatorInfo(indices=IFM_INDICES)  # TOSA specific
-    Rescale = OperatorInfo(indices=IFM_INDICES)  # TOSA specific
-    RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
-    Reshape = OperatorInfo(indices=IFM_INDICES)
-    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=NNG_IFM_INDICES)
+    Relu = OperatorInfo(indices=NNG_IFM_INDICES)
+    Relu6 = OperatorInfo(indices=NNG_IFM_INDICES)
+    ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES)
+    ReluN = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
+    Rescale = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
+    RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
+    Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
+    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
     ResizeNearestNeighbor = OperatorInfo()
     ReverseSequence = OperatorInfo()
     ReverseV2 = OperatorInfo()
-    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
     Round = OperatorInfo()
     Rsqrt = OperatorInfo()
-    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
-    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
+    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
+    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
     ScatterNd = OperatorInfo()
     SegmentSum = OperatorInfo()
     Select = OperatorInfo()
     SelectV2 = OperatorInfo()
     Shape = OperatorInfo()
-    Sigmoid = OperatorInfo(indices=IFM_INDICES)
+    Sigmoid = OperatorInfo(indices=NNG_IFM_INDICES)
     SignBit = OperatorInfo()
     Sin = OperatorInfo()
     SkipGram = OperatorInfo()
-    Slice = OperatorInfo(indices=IFM_INDICES)
-    Softmax = OperatorInfo(indices=IFM_INDICES)
+    Slice = OperatorInfo(indices=NNG_IFM_INDICES)
+    Softmax = OperatorInfo(indices=NNG_IFM_INDICES)
     SpaceToBatchND = OperatorInfo()
     SpaceToDepth = OperatorInfo()
     SparseToDense = OperatorInfo()
-    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
-    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
-    SplitV = OperatorInfo(indices=IFM_INDICES)
+    Split = OperatorInfo(indices=NNG_SPLIT_IFM_INDICES)
+    SplitSliceRead = OperatorInfo(indices=NNG_IFM_INDICES)
+    SplitV = OperatorInfo(indices=NNG_IFM_INDICES)
     Sqrt = OperatorInfo()
     Square = OperatorInfo()
     SquaredDifference = OperatorInfo()
-    Squeeze = OperatorInfo(indices=IFM_INDICES)
-    StridedSlice = OperatorInfo(indices=IFM_INDICES)
-    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Squeeze = OperatorInfo(indices=NNG_IFM_INDICES)
+    StridedSlice = OperatorInfo(indices=NNG_IFM_INDICES)
+    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
     Sum = OperatorInfo()
     Svdf = OperatorInfo()
-    Tanh = OperatorInfo(indices=IFM_INDICES)
+    Tanh = OperatorInfo(indices=NNG_IFM_INDICES)
     Tile = OperatorInfo()
     TopKV2 = OperatorInfo()
     Transpose = OperatorInfo()
-    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
-    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
     Unique = OperatorInfo()
-    Unpack = OperatorInfo(indices=IFM_INDICES)
-    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
+    Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
+    UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
     Where = OperatorInfo()
     While = OperatorInfo()
     ZerosLike = OperatorInfo()
@@ -323,7 +326,7 @@
         return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary
 
     def is_relu_op(self):
-        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip)
+        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp)
 
     def is_activation_op(self):
         return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)
@@ -408,7 +411,7 @@
         act.max = 1.0
     elif op_type == Op.HardSwish:
         act.min = 0.0
-    if op_type == Op.Clip:
+    if op_type == Op.Clamp:
         assert min is not None and max is not None
         act.min = min
         act.max = max
diff --git a/ethosu/vela/reader_util.py b/ethosu/vela/reader_util.py
index 5b454b5..233286c 100644
--- a/ethosu/vela/reader_util.py
+++ b/ethosu/vela/reader_util.py
@@ -58,3 +58,30 @@
         if not tens.ops:
             op = Operation(Op.Const, tens.name)
             op.set_output_tensor(tens)
+
+
+def align_inputs_indices(from_indices, to_indices, inputs):
+    to_list = to_indices.ifms + to_indices.weights + to_indices.biases
+    from_list = from_indices.ifms + from_indices.weights + from_indices.biases
+
+    assert len(to_list) == len(from_list)
+    if to_list != from_list:
+        for idx, t_idx in enumerate(to_list):
+            if t_idx >= len(inputs):
+                # Biases are allowed to be left out
+                assert t_idx in from_indices.biases and t_idx in to_indices.biases
+                continue
+            if to_list[idx] != from_list[idx]:
+                # find t_idx in from list and swap.
+                for jdx in from_list[idx:]:
+                    if from_list[jdx] == t_idx:
+                        inputs[idx], inputs[jdx] = inputs[jdx], inputs[idx]
+                        from_list[idx], from_list[jdx] = from_list[jdx], from_list[idx]
+                        break
+    assert from_list == to_list
+    return inputs
+
+
+def align_tensor_indices_to_nng(op_type, indices, inputs):
+    nng_op = Op(op_type)
+    return align_inputs_indices(indices, nng_op.info.indices, inputs)
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
index a69e8d3..664a58c 100644
--- a/ethosu/vela/test/test_tflite_reader.py
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -23,6 +23,8 @@
 
 from ethosu.vela.operation import Op
 from ethosu.vela.tflite.TensorType import TensorType
+from ethosu.vela.tflite_mapping import TFLITE_CONV2D_BACKPROP_INDICES
+from ethosu.vela.tflite_mapping import TFLITE_IFM_WEIGHTS_BIAS_INDICES
 from ethosu.vela.tflite_reader import TFLiteSubgraph
 
 
@@ -43,23 +45,25 @@
         assert output == expected
 
     parse_op_testdata = [
-        # op_type, opt_serializer, inputs, output, expected
-        (Op.FullyConnected, None, [0, 1, 2], 3, 3),  # FC
-        (Op.FullyConnected, None, [0, 1, -1], 3, 3),  # FC disabled Bias
-        (Op.FullyConnected, None, [0, 1], 3, 3),  # FC no Bias
-        (Op.Conv2D, None, [2, 1, 3], 0, 3),  # Conv2D
-        (Op.Conv2DBackpropInput, None, [0, 1, 2, 3], 4, 4),  # TransposeConv
-        (Op.Conv2DBackpropInput, None, [0, 1, 2], 4, 4),  # TransposeConv no Bias
-        pytest.param(Op.Conv2D, None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail),  # Conv2D no Weights
+        # op_type, opt_serializer, indices, inputs, output, expected
+        (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1, 2], 3, 3),  # FC
+        (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1, -1], 3, 3),  # FC disabled Bias
+        (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1], 3, 3),  # FC no Bias
+        (Op.Conv2DBias, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [2, 1, 3], 0, 3),  # Conv2D
+        (Op.Conv2DBackpropInput, None, TFLITE_CONV2D_BACKPROP_INDICES, [0, 1, 2, 3], 4, 4),  # TransposeConv
+        (Op.Conv2DBackpropInput, None, TFLITE_CONV2D_BACKPROP_INDICES, [0, 1, 2], 4, 4),  # TransposeConv no Bias
+        pytest.param(
+            Op.Conv2DBias, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, -1, 1], 3, 3, marks=pytest.mark.xfail
+        ),  # Conv2D no Weights
     ]
 
-    @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata)
-    def test_parse_operator(self, op_type, opt_serializer, inputs, output, expected):
+    @pytest.mark.parametrize("op_type, opt_serializer, indices, inputs, output, expected", parse_op_testdata)
+    def test_parse_operator(self, op_type, opt_serializer, indices, inputs, output, expected):
         with patch.object(TFLiteSubgraph, "__init__", lambda self, graph, subraph: None):
             # Mock a TFLiteSubGraph
             sg = TFLiteSubgraph(None, None)
             sg.graph = MagicMock()
-            sg.graph.operator_codes = [(op_type, opt_serializer, "")]
+            sg.graph.operator_codes = [(op_type, opt_serializer, "", indices)]
 
             # Mock a couple of tensors
             sg.tensors = [MagicMock() for _ in range(5)]
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index b526ec5..23a1a2b 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -25,6 +25,7 @@
 from .operation import CustomType
 from .operation import Op
 from .operation import Padding as opPad
+from .operation import TensorIndices
 from .tflite import AbsOptions
 from .tflite import AddNOptions
 from .tflite import AddOptions
@@ -489,50 +490,89 @@
 
 is_int_vec = True
 
+TFLITE_NO_INDICES = TensorIndices([], [], [])
+TFLITE_IFM_INDICES = TensorIndices([0], [], [])
+TFLITE_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+TFLITE_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+TFLITE_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+TFLITE_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+TFLITE_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+TFLITE_CONCAT_INDICES = TensorIndices([1, 2], [], [])
+TFLITE_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+TFLITE_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+
 builtin_operator_map = {
-    BuiltinOperator.ADD: (Op.Add, OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16"))),
-    BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts),
-    BuiltinOperator.CONCATENATION: (Op.ConcatTFLite, OptionsSerializer("ConcatenationOptions", ("axis", fused_act))),
-    BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts),
-    BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts),
-    BuiltinOperator.DEPTH_TO_SPACE: (Op.DepthToSpace, OptionsSerializer("DepthToSpaceOptions", ("block_size",))),
-    BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions")),
-    BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None),
-    BuiltinOperator.FLOOR: (Op.Floor, None),
+    BuiltinOperator.ADD: (
+        Op.Add,
+        OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16")),
+        TFLITE_IFM_IFM2_INDICES,
+    ),
+    BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts, TFLITE_IFM_INDICES),
+    BuiltinOperator.CONCATENATION: (
+        Op.ConcatTFLite,
+        OptionsSerializer("ConcatenationOptions", ("axis", fused_act)),
+        TFLITE_CONCAT_INDICES,
+    ),
+    BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts, TFLITE_IFM_WEIGHTS_BIAS_INDICES),
+    BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts, TFLITE_IFM_WEIGHTS_BIAS_INDICES),
+    BuiltinOperator.DEPTH_TO_SPACE: (
+        Op.DepthToSpace,
+        OptionsSerializer("DepthToSpaceOptions", ("block_size",)),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None, TFLITE_NO_INDICES),
+    BuiltinOperator.FLOOR: (Op.Floor, None, TFLITE_NO_INDICES),
     BuiltinOperator.FULLY_CONNECTED: (
         Op.FullyConnected,
         OptionsSerializer(
             "FullyConnectedOptions", (fused_act, "weights_format", "asymmetric_quantize_inputs", "keep_num_dims")
         ),
+        TFLITE_IFM_WEIGHTS_BIAS_INDICES,
     ),
-    BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None),
-    BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,))),
-    BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts),
+    BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None, TFLITE_NO_INDICES),
+    BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,)), TFLITE_NO_INDICES),
+    BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts, TFLITE_NO_INDICES),
     BuiltinOperator.LOCAL_RESPONSE_NORMALIZATION: (
         Op.LRN,
         OptionsSerializer("LocalResponseNormalizationOptions", ("radius", "bias", "alpha", "beta")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.LOGISTIC: (Op.Sigmoid, None),
-    BuiltinOperator.LSH_PROJECTION: (Op.LSHProjection, OptionsSerializer("LSHProjectionOptions", ("type",))),
-    BuiltinOperator.LSTM: (Op.Lstm, lstm_opts),
-    BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts),
-    BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,))),
-    BuiltinOperator.RELU: (Op.Relu, None),
-    BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None),
-    BuiltinOperator.RELU6: (Op.Relu6, None),
-    BuiltinOperator.RESHAPE: (Op.Reshape, OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),))),
+    BuiltinOperator.LOGISTIC: (Op.Sigmoid, None, TFLITE_IFM_INDICES),
+    BuiltinOperator.LSH_PROJECTION: (
+        Op.LSHProjection,
+        OptionsSerializer("LSHProjectionOptions", ("type",)),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.LSTM: (Op.Lstm, lstm_opts, TFLITE_IFM_WEIGHTS_INDICES),
+    BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts, TFLITE_IFM_INDICES),
+    BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,)), TFLITE_IFM_IFM2_INDICES),
+    BuiltinOperator.RELU: (Op.Relu, None, TFLITE_IFM_INDICES),
+    BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None, TFLITE_IFM_INDICES),
+    BuiltinOperator.RELU6: (Op.Relu6, None, TFLITE_IFM_INDICES),
+    BuiltinOperator.RESHAPE: (
+        Op.Reshape,
+        OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),)),
+        TFLITE_IFM_INDICES,
+    ),
     BuiltinOperator.RESIZE_BILINEAR: (
         Op.ResizeBilinear,
         OptionsSerializer("ResizeBilinearOptions", ("align_corners", "half_pixel_centers")),
+        TFLITE_IFM_INDICES,
     ),
-    BuiltinOperator.RNN: (Op.Rnn, rnn_opts),
-    BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",))),
-    BuiltinOperator.SPACE_TO_DEPTH: (Op.SpaceToDepth, OptionsSerializer("SpaceToDepthOptions", ("block_size",))),
+    BuiltinOperator.RNN: (Op.Rnn, rnn_opts, TFLITE_IFM_WEIGHTS_INDICES),
+    BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",)), TFLITE_IFM_INDICES),
+    BuiltinOperator.SPACE_TO_DEPTH: (
+        Op.SpaceToDepth,
+        OptionsSerializer("SpaceToDepthOptions", ("block_size",)),
+        TFLITE_NO_INDICES,
+    ),
     BuiltinOperator.SVDF: (
         Op.Svdf,
         OptionsSerializer("SVDFOptions", ("rank", fused_act, "asymmetric_quantize_inputs")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.TANH: (Op.Tanh, None),
+    BuiltinOperator.TANH: (Op.Tanh, None, TFLITE_IFM_INDICES),
     BuiltinOperator.CONCAT_EMBEDDINGS: (
         Op.ConcatEmbeddings,
         OptionsSerializer(
@@ -547,40 +587,76 @@
                 "embedding_dim_per_channel_as_length",
             ),
         ),
+        TFLITE_NO_INDICES,
     ),
     BuiltinOperator.SKIP_GRAM: (
         Op.SkipGram,
         OptionsSerializer("SkipGramOptions", ("ngram_size", "max_skip_size", "include_all_ngrams")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",))),
+    BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",)), TFLITE_NO_INDICES),
     BuiltinOperator.EMBEDDING_LOOKUP_SPARSE: (
         Op.EmbeddingLookupSparse,
         OptionsSerializer("EmbeddingLookupSparseOptions", ("combiner",)),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions")),
-    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: (Op.UnidirectionalSequenceRnn, seq_rnn_opts),
-    BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",))),
-    BuiltinOperator.BATCH_TO_SPACE_ND: (Op.BatchToSpaceND, OptionsSerializer("BatchToSpaceNDOptions")),
-    BuiltinOperator.SPACE_TO_BATCH_ND: (Op.SpaceToBatchND, OptionsSerializer("SpaceToBatchNDOptions")),
-    BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions")),
-    BuiltinOperator.MEAN: (Op.Mean, reducer_opts),
-    BuiltinOperator.SUB: (Op.Sub, OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",))),
-    BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,))),
-    BuiltinOperator.SQUEEZE: (Op.Squeeze, OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),))),
-    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: (Op.UnidirectionalSequenceLstm, unidir_seq_lstm_opts),
+    BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: (
+        Op.UnidirectionalSequenceRnn,
+        seq_rnn_opts,
+        TFLITE_IFM_WEIGHTS_INDICES,
+    ),
+    BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",)), TFLITE_NO_INDICES),
+    BuiltinOperator.BATCH_TO_SPACE_ND: (
+        Op.BatchToSpaceND,
+        OptionsSerializer("BatchToSpaceNDOptions"),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.SPACE_TO_BATCH_ND: (
+        Op.SpaceToBatchND,
+        OptionsSerializer("SpaceToBatchNDOptions"),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.MEAN: (Op.Mean, reducer_opts, TFLITE_IFM_INDICES),
+    BuiltinOperator.SUB: (
+        Op.Sub,
+        OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",)),
+        TFLITE_IFM_IFM2_INDICES,
+    ),
+    BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,)), TFLITE_NO_INDICES),
+    BuiltinOperator.SQUEEZE: (
+        Op.Squeeze,
+        OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),)),
+        TFLITE_IFM_INDICES,
+    ),
+    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: (
+        Op.UnidirectionalSequenceLstm,
+        unidir_seq_lstm_opts,
+        TFLITE_IFM_WEIGHTS_INDICES,
+    ),
     BuiltinOperator.STRIDED_SLICE: (
         Op.StridedSlice,
         OptionsSerializer(
-            "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask")
+            "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask"),
         ),
+        TFLITE_IFM_INDICES,
     ),
-    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: (Op.BidirectionalSequenceRnn, bidir_seq_rnn_opts),
-    BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions")),
-    BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options")),
-    BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",))),
-    BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions")),
-    BuiltinOperator.DELEGATE: (Op.Delegate, None),
-    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: (Op.BidirectionalSequenceLstm, bidir_seq_lstm_opts),
+    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: (
+        Op.BidirectionalSequenceRnn,
+        bidir_seq_rnn_opts,
+        TFLITE_IFM_WEIGHTS_INDICES,
+    ),
+    BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options"), TFLITE_NO_INDICES),
+    BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",)), TFLITE_SPLIT_IFM_INDICES),
+    BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.DELEGATE: (Op.Delegate, None, TFLITE_NO_INDICES),
+    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: (
+        Op.BidirectionalSequenceLstm,
+        bidir_seq_lstm_opts,
+        TFLITE_IFM_WEIGHTS_INDICES,
+    ),
     BuiltinOperator.CAST: (
         Op.Cast,
         OptionsSerializer(
@@ -590,117 +666,152 @@
                 ("out_data_type", datatype_deserialize, datatype_serialize),
             ),
         ),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.PRELU: (Op.Prelu, None),
-    BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions")),
+    BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_NO_INDICES),
+    BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.ARG_MAX: (
         Op.ArgMax,
         OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions")),
-    BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions")),
-    BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions")),
-    BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options")),
-    BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions")),
-    BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions")),
-    BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions")),
-    BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions")),
-    BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions")),
-    BuiltinOperator.SIN: (Op.Sin, None),
+    BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
+    BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options"), TFLITE_NO_INDICES),
+    BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.SIN: (Op.Sin, None, TFLITE_NO_INDICES),
     BuiltinOperator.TRANSPOSE_CONV: (
         Op.Conv2DBackpropInput,
         OptionsSerializer("TransposeConvOptions", (padding, "stride_w", "stride_h")),
+        TFLITE_CONV2D_BACKPROP_INDICES,
     ),
     BuiltinOperator.SPARSE_TO_DENSE: (
         Op.SparseToDense,
         OptionsSerializer("SparseToDenseOptions", ("validate_indices",)),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions")),
-    BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions")),
-    BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions")),
-    BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions")),
-    BuiltinOperator.LOG: (Op.Log, None),
-    BuiltinOperator.SUM: (Op.Sum, reducer_opts),
-    BuiltinOperator.SQRT: (Op.Sqrt, None),
-    BuiltinOperator.RSQRT: (Op.Rsqrt, None),
+    BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.LOG: (Op.Log, None, TFLITE_NO_INDICES),
+    BuiltinOperator.SUM: (Op.Sum, reducer_opts, TFLITE_NO_INDICES),
+    BuiltinOperator.SQRT: (Op.Sqrt, None, TFLITE_NO_INDICES),
+    BuiltinOperator.RSQRT: (Op.Rsqrt, None, TFLITE_NO_INDICES),
     BuiltinOperator.SHAPE: (
         Op.Shape,
         OptionsSerializer("ShapeOptions", (("out_type", datatype_deserialize, datatype_serialize),)),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions")),
+    BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions"), TFLITE_NO_INDICES),
     BuiltinOperator.ARG_MIN: (
         Op.ArgMin,
         OptionsSerializer("ArgMinOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
+        TFLITE_NO_INDICES,
     ),
     BuiltinOperator.FAKE_QUANT: (
         Op.FakeQuantWithMinMaxArgs,
         OptionsSerializer("FakeQuantOptions", ("min", "max", "num_bits", "narrow_range")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts),
-    BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts),
-    BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis"))),
-    BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions")),
-    BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",))),
-    BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions")),
-    BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions")),
-    BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis"))),
-    BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts),
-    BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions")),
-    BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts),
-    BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions")),
-    BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions")),
-    BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions")),
-    BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions")),
-    BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions")),
+    BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts, TFLITE_NO_INDICES),
+    BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts, TFLITE_NO_INDICES),
+    BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis")), TFLITE_IFM_INDICES),
+    BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",)), TFLITE_NO_INDICES),
+    BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis")), TFLITE_IFM_INDICES),
+    BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts, TFLITE_NO_INDICES),
+    BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts, TFLITE_NO_INDICES),
+    BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions"), TFLITE_NO_INDICES),
     BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: (
         Op.ResizeNearestNeighbor,
         OptionsSerializer("ResizeNearestNeighborOptions", ("align_corners", "half_pixel_centers")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",))),
-    BuiltinOperator.SQUARED_DIFFERENCE: (Op.SquaredDifference, OptionsSerializer("SquaredDifferenceOptions")),
-    BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",))),
-    BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions")),
-    BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",))),
+    BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",)), TFLITE_IFM_INDICES),
+    BuiltinOperator.SQUARED_DIFFERENCE: (
+        Op.SquaredDifference,
+        OptionsSerializer("SquaredDifferenceOptions"),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",)), TFLITE_NO_INDICES),
+    BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",)), TFLITE_IFM_INDICES),
     BuiltinOperator.UNIQUE: (
         Op.Unique,
         OptionsSerializer("UniqueOptions", (("idx_out_type", datatype_deserialize, datatype_serialize),)),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.CEIL: (Op.Ceil, None),
-    BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options")),
-    BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions")),
-    BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions")),
-    BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions")),
-    BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions")),
-    BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions")),
-    BuiltinOperator.ELU: (Op.Elu, None),
+    BuiltinOperator.CEIL: (Op.Ceil, None, TFLITE_NO_INDICES),
+    BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options"), TFLITE_NO_INDICES),
+    BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.ELU: (Op.Elu, None, TFLITE_NO_INDICES),
     BuiltinOperator.REVERSE_SEQUENCE: (
         Op.ReverseSequence,
         OptionsSerializer("ReverseSequenceOptions", ("seq_dim", "batch_dim")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions")),
-    BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions")),
-    BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions")),
-    BuiltinOperator.ROUND: (Op.Round, None),
-    BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions")),
-    BuiltinOperator.IF: (Op.If, OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))),
+    BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.ROUND: (Op.Round, None, TFLITE_NO_INDICES),
+    BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions"), TFLITE_IFM_INDICES),
+    BuiltinOperator.IF: (
+        Op.If,
+        OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index")),
+        TFLITE_NO_INDICES,
+    ),
     BuiltinOperator.WHILE: (
         Op.While,
         OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index")),
+        TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.NON_MAX_SUPPRESSION_V4: (Op.NonMaxSuppressionV4, OptionsSerializer("NonMaxSuppressionV4Options")),
-    BuiltinOperator.NON_MAX_SUPPRESSION_V5: (Op.NonMaxSuppressionV5, OptionsSerializer("NonMaxSuppressionV5Options")),
-    BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions")),
-    BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options")),
-    BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions")),
-    BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions")),
-    BuiltinOperator.BATCH_MATMUL: (Op.BatchMatMul, OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y"))),
-    BuiltinOperator.CUMSUM: (Op.Cumsum, OptionsSerializer("CumsumOptions", ("exclusive", "reverse"))),
-    BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer()),
+    BuiltinOperator.NON_MAX_SUPPRESSION_V4: (
+        Op.NonMaxSuppressionV4,
+        OptionsSerializer("NonMaxSuppressionV4Options"),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.NON_MAX_SUPPRESSION_V5: (
+        Op.NonMaxSuppressionV5,
+        OptionsSerializer("NonMaxSuppressionV5Options"),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options"), TFLITE_NO_INDICES),
+    BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.BATCH_MATMUL: (
+        Op.BatchMatMul,
+        OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y")),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.CUMSUM: (
+        Op.Cumsum,
+        OptionsSerializer("CumsumOptions", ("exclusive", "reverse")),
+        TFLITE_NO_INDICES,
+    ),
+    BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer(), TFLITE_NO_INDICES),
 }
 
-builtin_operator_inv_map = {v[0]: (k, v[1]) for k, v in builtin_operator_map.items()}
+builtin_operator_inv_map = {v[0]: (k, v[1], v[2]) for k, v in builtin_operator_map.items()}
 
-builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer())
+builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer(), TFLITE_NO_INDICES)
 
 BUILTIN_OPERATOR_UNKNOWN = "UNKNOWN"
 
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 1a45a5e..30bf32a 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -27,6 +27,7 @@
 from .operation import create_activation_function
 from .operation import Op
 from .operation import Operation
+from .reader_util import align_tensor_indices_to_nng
 from .reader_util import clone_and_reshape_tensor
 from .reader_util import decode_str
 from .reader_util import fixup_tensors
@@ -112,7 +113,7 @@
         return tens
 
     def parse_operator(self, op_index, op_data):
-        op_type, opt_serializer, custom_code = self.graph.operator_codes[op_data.OpcodeIndex()]
+        op_type, opt_serializer, custom_code, indices = self.graph.operator_codes[op_data.OpcodeIndex()]
         inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()]
         outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()]
         intermediates = []
@@ -122,6 +123,7 @@
         name = "unknown_op_name"
         if len(outputs):
             name = outputs[0].name
+        inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
         op = Operation(op_type, name)
         op.op_index = op_index
         op.inputs = inputs
@@ -263,11 +265,11 @@
             raise InputFileError(
                 self.name, f"The input file contains operator code '{c}' which is currently not supported"
             )
-        op_type, ser = builtin_operator_map[c]
+        op_type, ser, indices = builtin_operator_map[c]
         custom_code = None
         if c == BuiltinOperator.CUSTOM:
             custom_code = decode_str(code.CustomCode())
-        return op_type, ser, custom_code
+        return op_type, ser, custom_code, indices
 
 
 def read_tflite(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 8cabb0a..3701893 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -24,6 +24,7 @@
 from .errors import VelaError
 from .nn_graph import PassPlacement
 from .operation import Op
+from .reader_util import align_inputs_indices
 from .tensor import MemType
 from .tensor import TensorPurpose
 from .tflite import Buffer
@@ -38,7 +39,6 @@
 from .tflite_mapping import BuiltinOperator
 from .tflite_mapping import datatype_inv_map
 
-
 # ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
 
 tflite_version = 3
@@ -90,6 +90,8 @@
             for ps in sg.passes:
                 for op in ps.ops:
                     if op.type not in self.ops_to_ignore:
+                        # swap from nng input indexing to TensorFlow Lite input indexing
+                        self.align_nng_inputs_to_tflite(op)
                         all_ops.append(op)
                     if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                         # If values are None op has non-constant weights
@@ -104,6 +106,11 @@
         self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
         self.operator_code_map = {}
 
+    def align_nng_inputs_to_tflite(self, op):
+        from_indices = op.type.info.indices
+        _, _, to_indices = builtin_operator_inv_map[op.type]
+        op.inputs = align_inputs_indices(from_indices, to_indices, op.inputs)
+
     def write_byte_vector(self, v, alignment=1):
         builder = self.builder
         builder.StartVector(1, len(v), alignment)
@@ -170,13 +177,13 @@
         builder = self.builder
         custom_code_offset = None
         if op_type == Op.Custom:
-            tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+            tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type]
             custom_code_offset = builder.CreateString(custom_code)
         else:
             assert (
                 op_type in builtin_operator_inv_map
             ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
-            tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+            tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type]
 
             if op_type == Op.CustomNpuOp:
                 assert (
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 94e6f99..fe18ce3 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -60,7 +60,7 @@
 
 
 def rewrite_activation(op, arch, nng):
-    if not op.type.is_relu_op():
+    if op.type not in (Op.ReluN, Op.Clamp):
         return op
 
     ifm = op.ifm
@@ -82,7 +82,7 @@
     if op.ofm.quantization.zero_point is None:
         op.ofm.quantization.zero_point = zp
 
-    if op.type == Op.Clip:
+    if op.type == Op.Clamp:
         op.attrs["min"] = op.attrs["min_int"] - zp
         op.attrs["max"] = op.attrs["max_int"] - zp
     elif op.type == Op.ReluN:
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 82f61f7..312ac92 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -249,7 +249,7 @@
     # TODO TosaOp.MATMUL:
     TosaOp.MAX_POOL2D: (Op.MaxPool, pool2d_attrs, None, TOSA_IFM_INDICES),
     # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv2d_attrs, conv_quant_info)
-    TosaOp.CLAMP: (Op.Clip, clamp_attrs, None, TOSA_IFM_INDICES),
+    TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
     TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
     # TODO TosaOp.SIGMOID
     # TODO TosaOp.TANH
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index ac0b396..e51ead1 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -25,6 +25,7 @@
 from .nn_graph import Subgraph
 from .operation import Op
 from .operation import Operation
+from .reader_util import align_tensor_indices_to_nng
 from .reader_util import clone_and_reshape_tensor
 from .reader_util import decode_str
 from .reader_util import fixup_tensors
@@ -104,8 +105,8 @@
         name = "unknown_op_name"
         if len(outputs):
             name = outputs[0].name
+        inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
         op = Operation(op_type, name)
-        op.type.info.indices = indices
         op.op_index = op_index
         op.inputs = inputs
         op.outputs = outputs
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index c87d653..51f80eb 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -32,7 +32,7 @@
     mac_main_ops = convolution_like_ops
 
     type_conversion_ops = set((Op.Rescale,))
-    relu_ops = set((Op.Clip, Op.ReluN,))
+    relu_ops = set((Op.Clamp, Op.ReluN,))
     activation_ops = relu_ops
 
     npu_post_ops = activation_ops