MLBEDSW-2082: Add Exp support

- Added int8 and int16 Exp support, implemented as a LUT.
- Added generic 8-bit and 16-bit LUT generation functions that follow
the implementation in the latest reference. If the reference adds
new LUT-based ops, they can easily be implemented in Vela using
these generic functions.
- Moved convert_to_lut to lut.py to have all LUT-related code in
one file.
- Updated SUPPORTED_OPS.md

Change-Id: I388e76ea4b39162313599a5341cfb9bad71a782c
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index ab9b009..80647f8 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.7.1.dev17+g7b3008a.d20230420`
+Vela version: `3.7.1.dev23+g3734897.d20230427`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -25,6 +25,7 @@
 | CONCATENATION | [Generic](#tflite-generic-constraints), [Specific](#tflite-concatenation-constraints) |
 | CONV_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-conv_2d-constraints) |
 | DEPTHWISE_CONV_2D | [Generic](#tflite-generic-constraints), [Specific](#tflite-depthwise_conv_2d-constraints) |
+| EXP | [Generic](#tflite-generic-constraints), [Specific](#tflite-exp-constraints) |
 | EXPAND_DIMS | [Generic](#tflite-generic-constraints), [Specific](#tflite-expand_dims-constraints) |
 | FULLY_CONNECTED | [Generic](#tflite-generic-constraints), [Specific](#tflite-fully_connected-constraints) |
 | HARD_SWISH | [Generic](#tflite-generic-constraints), [Specific](#tflite-hard_swish-constraints) |
@@ -63,6 +64,7 @@
 This is a list of constraints most NPU operators must satisfy in order to be scheduled on the NPU.
 (Operators excluded from certain constraints are shown in brackets [ ] )
 
+- All required operator attributes must be specified
 - Input(s) and Output tensors must not be dynamic - [QUANTIZE]
 - Input(s) and Output tensors must have a defined shape
 - Output tensors cannot be scalar - [QUANTIZE]
@@ -161,6 +163,14 @@
 - Stride values for both width and height must be between 1 and 3
 - For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier
 
+### TFLite EXP Constraints
+
+This is a list of constraints that the EXP operator must satisfy in order to be scheduled on the NPU.
+
+- At least one Input's shape must match the OFM's shape
+- IFM and OFM data types must match
+- IFM must be int8 or int16
+
 ### TFLite EXPAND_DIMS Constraints
 
 This is a list of constraints that the EXPAND_DIMS operator must satisfy in order to be scheduled on the NPU.
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 8279036..da3fe13 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -20,7 +20,6 @@
 
 import numpy as np
 
-from . import lut
 from .architecture_features import Accelerator
 from .data_type import DataType
 from .debug_database import DebugDatabase
@@ -29,8 +28,6 @@
 from .operation import Op
 from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
-from .tensor import create_const_tensor
-from .tensor import QuantizationParameters
 from .tensor import Tensor
 
 memory_only_ops = (
@@ -329,42 +326,6 @@
     return op
 
 
-def convert_to_lut(op, lut_values, lut_name):
-    # Rewrite the operation by Add with scalar 0 + LUT activation
-    ifm = op.ifm
-    ofm = op.ofm
-    if ifm is None:
-        return op
-    assert ifm.dtype.size_in_bytes() == 1
-    op.type = Op.Add
-    op.name = op.name + "_lut_" + lut_name
-    # Mark as no-op to enable potential fusing optimizations
-    op.attrs["is_nop"] = True
-    # Create an input tensor containing scalar zero
-    quantization = QuantizationParameters(0.0, 255.0)
-    quantization.scale_f32 = ifm.quantization.scale_f32
-    quantization.zero_point = 0
-    tens = create_const_tensor(ifm.name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
-    op.add_input_tensor(tens)
-    op.ifm_shapes.append(Shape4D(tens.shape))  # TODO no shape?
-
-    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
-    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
-    # should be the same as the IFM
-    op.forced_output_quantization = ifm.quantization
-
-    # the lut tensor datatype needs to match both; the ofm datatype, because these are the values output; and the
-    # datatype used to generate the lut values (which is probably the ifm datatype), because we want to avoid any
-    # potential overflow errors in create_lut_tensor() caused by converting Python int (which could represent a uint)
-    # to NumPy int. this can be guaranteed by checking that the ifm and ofm datatypes are the same
-    assert ifm.dtype == ofm.dtype
-    lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, ofm.dtype)
-    op.set_activation_lut(lut_tensor)
-    op.set_ifm_ofm_shapes()
-    DebugDatabase.add_optimised(op, op)
-    return op
-
-
 def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
     """Creates an average pool for the given concat op/input feature map"""
     ofm = concat_op.ofm
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index d0ac970..c8fb7bc 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -21,10 +21,15 @@
 import numpy as np
 
 from . import numeric_util
+from .data_type import DataType
+from .debug_database import DebugDatabase
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
+from .numeric_util import round_away_zero
+from .operation import Op
 from .tensor import create_const_tensor
 from .tensor import create_equivalence_id
+from .tensor import QuantizationParameters
 from .tensor import TensorPurpose
 
 
@@ -88,6 +93,8 @@
     # address in constant memory, and unnecessary DMA operations can be avoided.
     sz = len(values)
     assert sz in (256, 512)
+    # An int16 LUT is stored as uint32 entries, each combining a base value and a slope
+    dtype = DataType.uint32 if dtype == DataType.int16 else dtype
     tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, TensorPurpose.LUT)
     tens.equivalence_id = create_equivalence_id(tuple(values))
     return tens
@@ -128,3 +135,110 @@
         lut_state = lut_state.put(lut_tens)
         cmd_stream.append(cmd)
     sg.high_level_command_stream = cmd_stream
+
+
+def convert_to_lut(op, lut_values, lut_name):
+    # Rewrite the operation by Add with scalar 0 + LUT activation
+    ifm = op.ifm
+    ofm = op.ofm
+    if ifm is None:
+        return op
+    assert ifm.dtype in (DataType.int8, DataType.uint8, DataType.int16)
+    op.type = Op.Add
+    op.name = f"{op.name}_lut_{lut_name}"
+    # Mark as no-op to enable potential fusing optimizations
+    op.attrs["is_nop"] = True
+    # Create an input tensor containing scalar zero
+    _max = 65536.0 if ifm.dtype == DataType.int16 else 255.0
+    quantization = QuantizationParameters(0.0, _max)
+    quantization.scale_f32 = ifm.quantization.scale_f32
+    quantization.zero_point = 0
+    tens = create_const_tensor(ifm.name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
+    op.add_input_tensor(tens)
+
+    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
+    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
+    # should be the same as the IFM
+    op.forced_output_quantization = ifm.quantization
+
+    # the lut tensor datatype needs to match both; the ofm datatype, because these are the values output; and the
+    # datatype used to generate the lut values (which is probably the ifm datatype), because we want to avoid any
+    # potential overflow errors in create_lut_tensor() caused by converting Python int (which could represent a uint)
+    # to NumPy int. this can be guaranteed by checking that the ifm and ofm datatypes are the same
+    assert ifm.dtype == ofm.dtype
+    lut_tensor = create_lut_tensor(op.name + "_values", lut_values, ofm.dtype)
+    op.set_activation_lut(lut_tensor)
+    op.set_ifm_ofm_shapes()
+    DebugDatabase.add_optimised(op, op)
+    return op
+
+
+def create_lut_8bit_op(op, lut_fn, fn_name):
+    ifm_scale = op.ifm.quantization.scale_f32
+    ofm_scale = op.ofm.quantization.scale_f32
+    zp_in = op.ifm.quantization.zero_point
+    zp_out = op.ofm.quantization.zero_point
+
+    values = []
+    ix = range(256) if op.ifm.dtype == DataType.uint8 else range(-128, 128)
+    quantized_min = min(ix)
+    quantized_max = max(ix)
+    for x in ix:
+        x_real = ifm_scale * (x - zp_in)
+        y_real = lut_fn(x_real)
+        lut_result = round_away_zero(y_real / ofm_scale) + zp_out
+        lut_result = min(quantized_max, max(quantized_min, lut_result))
+        values.append(lut_result)
+
+    return convert_to_lut(op, values, fn_name)
+
+
+def create_lut_int16_op(op, lut_fn, fn_name):
+    ifm_scale = op.ifm.quantization.scale_f32
+    ofm_scale = op.ofm.quantization.scale_f32
+    zp_in = op.ifm.quantization.zero_point
+    zp_out = op.ofm.quantization.zero_point
+
+    input_min = ifm_scale * (np.iinfo(np.int16).min - zp_in)
+    input_max = ifm_scale * (np.iinfo(np.int16).max - zp_in)
+    output_min = ofm_scale * (np.iinfo(np.int16).min - zp_out)
+    output_max = ofm_scale * (np.iinfo(np.int16).max - zp_out)
+
+    # Create 16bit lut following the reference
+    nbr_steps = 512
+    step = (input_max - input_min) / nbr_steps
+    half_step = step / 2
+    output_scaling_inv = (np.iinfo(np.int16).max - np.iinfo(np.int16).min + 1) / (output_max - output_min)
+
+    table_min = np.iinfo(np.int16).min
+    table_max = np.iinfo(np.int16).max
+
+    values = []
+    for i in range(nbr_steps):
+        val = lut_fn(input_min + i * step)
+        val_midpoint = lut_fn(input_min + i * step + half_step)
+        val_next = lut_fn(input_min + (i + 1) * step)
+
+        sample_val = round_away_zero(val * output_scaling_inv)
+        midpoint_interp_val = round_away_zero(
+            (val_next * output_scaling_inv + round_away_zero(val * output_scaling_inv)) / 2
+        )
+        midpoint_val = round_away_zero(val_midpoint * output_scaling_inv)
+        midpoint_err = midpoint_interp_val - midpoint_val
+        bias = round_away_zero(midpoint_err / 2)
+
+        lut_result = min(max(sample_val - bias, table_min), table_max)
+        values.append(lut_result)
+
+    val = round_away_zero(lut_fn(input_max) * output_scaling_inv)
+    lut_result = min(max(val, table_min), table_max)
+    values.append(lut_result)
+
+    # Convert to hardware 16bit lut with base and slope
+    lut = [0] * nbr_steps
+    for i in range(nbr_steps):
+        slope = (int(values[i + 1]) - int(values[i])) << 16
+        base = int(values[i])
+        lut[i] = slope + base
+
+    return convert_to_lut(op, lut, fn_name)
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index eafe3bd..6959652 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -179,7 +179,7 @@
     EmbeddingLookup = OperatorInfo()
     EmbeddingLookupSparse = OperatorInfo()
     Equal = OperatorInfo()
-    Exp = OperatorInfo()
+    Exp = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
     ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES)
     FakeQuantWithMinMaxArgs = OperatorInfo()
     Fill = OperatorInfo()
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index c79f154..1b70165 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -34,7 +34,6 @@
 from .graph_optimiser_util import bypass_memory_only_ops
 from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import convert_depthwise_to_conv
-from .graph_optimiser_util import convert_to_lut
 from .graph_optimiser_util import create_avg_pool_for_concat
 from .graph_optimiser_util import memory_only_ops
 from .graph_optimiser_util import move_splitsliceread_to_consumer
@@ -42,6 +41,9 @@
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
 from .lstm import Lstm
+from .lut import convert_to_lut
+from .lut import create_lut_8bit_op
+from .lut import create_lut_int16_op
 from .numeric_util import clamp_sigmoid
 from .numeric_util import full_shape
 from .numeric_util import round_away_zero
@@ -1935,6 +1937,19 @@
     return op
 
 
+def convert_ops_to_lut(op, arch, nng):
+    if op.type == Op.Exp:
+        if op.ifm.dtype == DataType.int8:
+            return create_lut_8bit_op(op, math.exp, "exp")
+        elif op.ifm.dtype == DataType.int16:
+            return create_lut_int16_op(op, math.exp, "exp")
+        else:
+            # Should already be caught by the tflite supported ops checks
+            assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
+
+    return op
+
+
 def optimise_quantize(op: Operation, arch, nng):
 
     if op.type == Op.Quantize and op.run_on_npu:
@@ -2214,6 +2229,7 @@
     # Rewrite of operators
     op_rewrite_list = [
         set_tensor_equivalence,
+        convert_ops_to_lut,
         convert_mean_to_depthwise_conv,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index bb45a7f..dda418c 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -734,7 +734,7 @@
         ),
         TFLITE_IFM_WEIGHTS_INDICES,
     ),
-    BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions"), TFLITE_NO_INDICES),
+    BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions"), TFLITE_IFM_INDICES),
     BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options"), TFLITE_NO_INDICES),
     BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",)), TFLITE_SPLIT_IFM_INDICES),
     BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions"), TFLITE_NO_INDICES),
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 7537d7d..24c0794 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -201,6 +201,9 @@
         self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_intermediates)
         self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_variables)
 
+        # Exp specific checks
+        self.specific_constraints[Op.Exp].append(TFLiteSemantic.constraint_input_signed)
+
     def is_operator_semantic_valid(self, op):
         ext_type = optype_to_builtintype(op.type)
 
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 2a599aa..b347414 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -25,11 +25,11 @@
 from .graph_optimiser_util import bypass_memory_only_ops
 from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import convert_depthwise_to_conv
-from .graph_optimiser_util import convert_to_lut
 from .graph_optimiser_util import move_splitsliceread_to_consumer
 from .graph_optimiser_util import needed_total_padding
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
+from .lut import convert_to_lut
 from .operation import ExplicitScaling
 from .operation import Op
 from .operation_util import create_add_nop