[MLBEDSW-2730] Implement LUT generation for softmax uint8/int8

Implemented LUT generation for softmax uint8/int8 to match the
reference.

Change-Id: Ib9acaa295ee1066591e800023d75f364520b44c1
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index c67cc37..eb97c79 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -1,22 +1,28 @@
 # Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
 #
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
 # SPDX-License-Identifier: Apache-2.0
 #
-# Licensed under the Apache License, Version 2.0 (the License); you may
-# not use this file except in compliance with the License.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an AS IS BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
 # Description:
 # Contains SoftMax
+import math
+
 import numpy as np
 
+from . import fp_math
 from . import scaling
 from .data_type import DataType
 from .operation import Operation
@@ -30,76 +36,6 @@
     # Turn off black formatting for the LUT tables to keep them compact
     # fmt: off
 
-    EXP_LUT_U8 = [
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
-        0x00000291, 0x000006fa, 0x000012f6, 0x0000338b, 0x00008c1b, 0x00017cd8, 0x00040b3d, 0x000afe11,
-        0x001de16c, 0x00513949, 0x00dcca03, 0x02582ac2, 0x065f6c52, 0x1152aaf6, 0x2f16ad4c, 0x7fffffff
-    ]
-
-    EXP_LUT_I8 = [
-        0x000011c9, 0x000012b8, 0x000013b4, 0x000014bd, 0x000015d4, 0x000016fa, 0x0000182f, 0x00001975,
-        0x00001acb, 0x00001c34, 0x00001daf, 0x00001f3f, 0x000020e3, 0x0000229e, 0x00002470, 0x0000265a,
-        0x0000285e, 0x00002a7d, 0x00002cb9, 0x00002f13, 0x0000318c, 0x00003427, 0x000036e5, 0x000039c8,
-        0x00003cd1, 0x00004004, 0x00004361, 0x000046ec, 0x00004aa6, 0x00004e93, 0x000052b4, 0x0000570d,
-        0x00005ba1, 0x00006072, 0x00006583, 0x00006ada, 0x00007077, 0x00007661, 0x00007c9a, 0x00008327,
-        0x00008a0c, 0x0000914d, 0x000098f1, 0x0000a0fb, 0x0000a971, 0x0000b259, 0x0000bbb9, 0x0000c597,
-        0x0000cffa, 0x0000dae9, 0x0000e66b, 0x0000f288, 0x0000ff48, 0x00010cb3, 0x00011ad3, 0x000129b1,
-        0x00013957, 0x000149d0, 0x00015b26, 0x00016d65, 0x0001809b, 0x000194d2, 0x0001aa1a, 0x0001c080,
-        0x0001d814, 0x0001f0e4, 0x00020b03, 0x00022681, 0x00024371, 0x000261e7, 0x000281f7, 0x0002a3b5,
-        0x0002c73b, 0x0002ec9e, 0x000313f8, 0x00033d64, 0x000368fd, 0x000396e0, 0x0003c72e, 0x0003fa05,
-        0x00042f89, 0x000467dd, 0x0004a326, 0x0004e18e, 0x0005233d, 0x00056860, 0x0005b126, 0x0005fdbf,
-        0x00064e5f, 0x0006a33b, 0x0006fc8e, 0x00075a93, 0x0007bd89, 0x000825b3, 0x00089356, 0x000906bd,
-        0x00098034, 0x000a000f, 0x000a86a2, 0x000b1447, 0x000ba95f, 0x000c464d, 0x000ceb7c, 0x000d9959,
-        0x000e505a, 0x000f10f9, 0x000fdbb8, 0x0010b120, 0x001191c0, 0x00127e2f, 0x0013770b, 0x00147cfc,
-        0x001590b2, 0x0016b2e6, 0x0017e45c, 0x001925e1, 0x001a784c, 0x001bdc81, 0x001d536f, 0x001ede14,
-        0x00207d76, 0x002232af, 0x0023fee3, 0x0025e348, 0x0027e125, 0x0029f9ce, 0x002c2ead, 0x002e813e,
-        0x0030f30f, 0x003385c7, 0x00363b1e, 0x003914e9, 0x003c150f, 0x003f3d97, 0x004290a0, 0x00461065,
-        0x0049bf40, 0x004d9fac, 0x0051b444, 0x0055ffc2, 0x005a850e, 0x005f472f, 0x00644959, 0x00698eea,
-        0x006f1b6b, 0x0074f298, 0x007b185e, 0x008190dd, 0x00886073, 0x008f8bad, 0x00971761, 0x009f08a0,
-        0x00a764c0, 0x00b03163, 0x00b9746c, 0x00c3341a, 0x00cd76f8, 0x00d843eb, 0x00e3a23a, 0x00ef9981,
-        0x00fc31d0, 0x0109739d, 0x011767cf, 0x012617cd, 0x01358d6e, 0x0145d319, 0x0156f3be, 0x0168fadc,
-        0x017bf49d, 0x018fedb3, 0x01a4f391, 0x01bb1457, 0x01d25ede, 0x01eae2e1, 0x0204b0c5, 0x021fd9e9,
-        0x023c708e, 0x025a87f5, 0x027a343a, 0x029b8ac1, 0x02bea1ea, 0x02e39148, 0x030a71be, 0x03335d49,
-        0x035e6f88, 0x038bc564, 0x03bb7d53, 0x03edb776, 0x0422956d, 0x045a3add, 0x0494cd23, 0x04d27398,
-        0x051357c1, 0x0557a511, 0x059f8990, 0x05eb3585, 0x063adbc4, 0x068eb1f7, 0x06e6f042, 0x0743d212,
-        0x07a595d0, 0x080c7d1f, 0x0878cd5d, 0x08eacf11, 0x0962cefe, 0x09e11dc0, 0x0a661028, 0x0af1ffdf,
-        0x0b854a8e, 0x0c205363, 0x0cc38284, 0x0d6f4577, 0x0e241032, 0x0ee25ba2, 0x0faaa7e6, 0x107d7b92,
-        0x115b64b1, 0x1244f774, 0x133ad1b8, 0x143d9876, 0x154df988, 0x166cac69, 0x179a70c9, 0x18d81250,
-        0x1a266643, 0x1b864d38, 0x1cf8b430, 0x1e7e9307, 0x2018f0a9, 0x21c8e098, 0x238f850c, 0x256e1033,
-        0x2765c273, 0x2977ef40, 0x2ba5faa9, 0x2df15b73, 0x305b9d6b, 0x32e65e8a, 0x3593552c, 0x38644d67,
-        0x3b5b2b66, 0x3e79ee87, 0x41c2adcb, 0x45379f4e, 0x48db158a, 0x4caf81e6, 0x50b7797f, 0x54f5af16,
-        0x596cfe2f, 0x5e2066d0, 0x631310c8, 0x684852d8, 0x6dc3a909, 0x7388c421, 0x799b84b7, 0x7fffffff,
-    ]
-
     EXP_LUT = [
         0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
         0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002, 0x00000002,
@@ -239,8 +175,27 @@
         self.op = op
 
     def generate_exp_table(self, beta, input_scale):
-        # TODO: Generate the exp table using the same math as the reference
-        return self.EXP_LUT_U8 if input_scale == 1.0 else self.EXP_LUT_I8
+        integer_bits = 5
+        total_signed_bits = 31
+        # Calculate scaling: real_beta = beta * input_scale * 2^(31 - integer_bits), capped at 2^31 - 1
+        real_beta = min(
+            np.double(beta) * np.double(input_scale) * (1 << (31 - integer_bits)), np.double((1 << 31) - 1.0)
+        )
+        scale, shift = scaling.quantise_scale(real_beta)
+        shift = 31 - shift
+        diff_min = -1.0 * math.floor(
+            1.0 * ((1 << integer_bits) - 1) * (1 << (total_signed_bits - integer_bits)) / (1 << shift)
+        )
+        # Generate the exp LUT: entry x is for input_diff = x - 255 (<= 0); entries with input_diff < diff_min are 0
+        lut = []
+        for x in range(256):
+            input_diff = x - 255
+            if input_diff >= diff_min:
+                rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
+                lut.append(fp_math.exp_on_negative_values(rescale))
+            else:
+                lut.append(0)
+        return lut
 
     def get_graph(self):
         ifm = self.op.inputs[0]
@@ -339,7 +294,12 @@
         sub5_op = Operation("SubAct", self.op.name + "_sub5")
         sub5_op.add_input_tensor(
             create_const_tensor(
-                "headroom_offset_const", [1, 1, 1, 1], DataType.int32, [12 + 31 - 8], np.int32, quantization=no_scale_quant
+                "headroom_offset_const",
+                [1, 1, 1, 1],
+                DataType.int32,
+                [12 + 31 - 8],
+                np.int32,
+                quantization=no_scale_quant,
             ),
         )
         sub5_op.add_input_tensor(headroom_plus_one)
@@ -348,9 +308,7 @@
         sub5_op.set_output_tensor(right_shift)
 
         # PASS 6 - Sub
-        one = create_const_tensor(
-            "one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
-        )
+        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
         sub6_op = Operation("SubAct", self.op.name + "_sub6")
         sub6_op.add_input_tensor(headroom_plus_one)
         sub6_op.add_input_tensor(one)
@@ -404,7 +362,12 @@
         mul11_op.add_input_tensor(half_denominator)
         mul11_op.add_input_tensor(
             create_const_tensor(
-                "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], np.int32, quantization=one_scale_quant
+                "neg_32_over_17_const",
+                [1, 1, 1, 1],
+                DataType.int32,
+                [-1010580540],
+                np.int32,
+                quantization=one_scale_quant,
             ),
         )
         rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
@@ -428,9 +391,7 @@
         F2_one = create_const_tensor(
             "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
         )
-        two = create_const_tensor(
-            "two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant
-        )
+        two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
         for i in range(3):
             # PASS 13, 18, 23 - MUL
             mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
@@ -448,7 +409,7 @@
             one_minus_half_denominator_times_x.quantization = one_scale_quant
             sub_op.set_output_tensor(one_minus_half_denominator_times_x)
             # PASS 15, 20, 25 - MUL
-            mul_op = Operation("MulAct", self.op.name + "_mul%d" %+ (15 + i * 5))
+            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
             mul_op.add_input_tensor(nr_x)
             mul_op.add_input_tensor(one_minus_half_denominator_times_x)
             to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")