MLBEDSW-2688: Improved LUT support

- Support for more than one 256-byte LUT in SHRAM (see the sketch below)
- No DMA is performed for a LUT that is already located in SHRAM
- Added MemArea.Shram (used for LUTs) to avoid false address-collision
  asserts during SRAM tensor allocation
- Added read access to LUTs in the memory access calculation

Change-Id: If4d1eded5ed029d253f4f5efb2d80495fc3eac99
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
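
The slot arithmetic behind these changes, as a standalone sketch. The bank
count and bank size here are assumptions for illustration; vela takes the
real values from the accelerator configuration:

    # Minimal sketch of the SHRAM LUT addressing; the constants are assumed,
    # not taken from a real accelerator configuration.
    SHRAM_BANK_SIZE = 1024   # bytes per SHRAM bank (assumed)
    SHRAM_TOTAL_BANKS = 16   # assumed config with no reserved unused banks

    def available_shram_banks(uses_activation_lut):
        # When an activation LUT is in use, it occupies the last 2 banks
        banks = SHRAM_TOTAL_BANKS
        if uses_activation_lut:
            banks -= 2
        return banks

    # The LUT area starts right above the banks still usable for IFM/acc data
    lut_address = SHRAM_BANK_SIZE * available_shram_banks(True)  # 14336
    lut_size = 2048  # 2 banks

    def lut_slot(address, storage_size=256):
        # A 256-byte LUT lands in one of eight 256-byte slots of the LUT area
        return (address - lut_address) // storage_size

    assert lut_slot(lut_address) == 0
    assert lut_slot(lut_address + 3 * 256) == 3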
diff --git a/ethosu/mlw_codec/test/test_mlw_codec.py b/ethosu/mlw_codec/test/test_mlw_codec.py
index 31a3bc0..d37462d 100644
--- a/ethosu/mlw_codec/test/test_mlw_codec.py
+++ b/ethosu/mlw_codec/test/test_mlw_codec.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 # Simple example of the usage of mlw_codec.
 import pytest
+
 from ethosu import mlw_codec
 
 
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 021597e..265af42 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -316,6 +316,9 @@
         self.shram_reserved_unused_banks = 2 if accel_config.shram_banks > 16 else 0
         self.shram_total_banks = accel_config.shram_banks - self.shram_reserved_unused_banks
         self.shram_bank_granules = np.array(accel_config.shram_granules, np.int32)
+        self.shram_lut_size = 2048
+        # SHRAM base address of the activation lookup table
+        self.shram_lut_address = self.shram_bank_size * self.available_shram_banks(True)
 
         # Build a map of acceptable IFM/OFM block configurations up to the maximum
         # IFM/OFM block size.
@@ -326,6 +329,14 @@
         # Setup supported operators and restriction checkers class
         self.supported_operators = SupportedOperators(softmax_support)
 
+    # Returns the number of available SHRAM banks, depending on whether an
+    # activation lookup table is used or not
+    def available_shram_banks(self, uses_activation_lut):
+        banks = self.shram_total_banks
+        if uses_activation_lut and self.shram_reserved_unused_banks == 0:
+            banks -= 2
+        return banks
+
     # Calculate block configuration for ALL known IFM operations and
     # accumulator sizes. Consumers will need to select their preferred
     # operation and bit-width at read-time.
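A worked example of available_shram_banks, with hypothetical bank counts;
the reserved-banks rule comes from the constructor above:

    # Config A (assumed): accel_config.shram_banks = 16
    #   shram_reserved_unused_banks = 0, shram_total_banks = 16
    #   available_shram_banks(False) -> 16
    #   available_shram_banks(True)  -> 14  (last 2 banks given to the LUT)
    # Config B (assumed): accel_config.shram_banks = 24
    #   shram_reserved_unused_banks = 2, shram_total_banks = 22
    #   available_shram_banks(False) -> 22
    #   available_shram_banks(True)  -> 22  (LUT fits in the reserved banks)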
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index f407fdc..5e9e38f 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -22,6 +22,7 @@
 from . import high_level_command_stream_generator
 from . import insert_dma
 from . import live_range
+from . import lut
 from . import mark_tensors
 from . import npu_performance
 from . import npu_serialisation
@@ -198,6 +199,7 @@
         high_level_command_stream_generator.generate_high_level_command_stream(
             nng, sg, arch, options.verbose_high_level_command_stream
         )
+        lut.optimize_high_level_cmd_stream(sg, arch)
         register_command_stream_generator.generate_register_command_stream(
             nng, sg, arch, options.verbose_register_command_stream
         )
diff --git a/ethosu/vela/greedy_allocation.py b/ethosu/vela/greedy_allocation.py
index e017687..1cbfce3 100644
--- a/ethosu/vela/greedy_allocation.py
+++ b/ethosu/vela/greedy_allocation.py
@@ -77,9 +77,7 @@
             for m in lrs:
                 if n != m and n.overlaps_ranges(m):
                     overlap, tens_n, tens_m = n.overlaps_address(m)
-                    if overlap and not (
-                        tens_n.equivalence_id == tens_m.equivalence_id and tens_n.address == tens_m.address
-                    ):
+                    if overlap and not (tens_n.equivalent(tens_m) and tens_n.address == tens_m.address):
                         print("Solution failed, overlapping buffer!")
                         print(tens_n.address, tens_n.address + n.size, n.name)
                         print(tens_m.address, tens_m.address + m.size, m.name)
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index c669829..95af1cc 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -23,6 +23,9 @@
 from .operation import NpuBlockType
 from .range_set import AccessDirection
 from .range_set import MemoryAccessSet
+from .range_set import MemoryRangeSet
+from .tensor import MemArea
+from .tensor import TensorPurpose
 
 
 class Box:
@@ -233,6 +236,13 @@
                 ),
                 AccessDirection.Read,
             )
+        # Add read access to any LUTs present in SHRAM
+        for tens in self.ps.intermediates:
+            if tens.purpose == TensorPurpose.LUT and tens.mem_area == MemArea.Shram:
+                res.add(
+                    MemoryRangeSet(tens.mem_area, tens.address, tens.address + tens.storage_size()),
+                    AccessDirection.Read,
+                )
         return res
 
     def is_npu_pass_command(self):
@@ -359,8 +369,9 @@
 
 
 class DMA(Command):
-    def __init__(self, in_tensor, out_tensor, box):
+    def __init__(self, ps, in_tensor, out_tensor, box):
         self.cmdtype = CommandType.DMA
+        self.ps = ps
         self.in_tensor = in_tensor
         self.out_tensor = out_tensor
         self.box = box
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index d34fb75..d5a6341 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -32,7 +32,7 @@
     if tensor.needs_dma():
         dma_op = tensor.ops[0]
         in_tensor = dma_op.inputs[0]
-        yield DMA(in_tensor, tensor, box)
+        yield DMA(ps, in_tensor, tensor, box)
 
 
 def match_tensor(source, derived):
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 6c5c803..6cd2202 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -61,13 +61,7 @@
                         dma_cmd.attrs["destination"] = new_tens.mem_area
                         dma_cmd.run_on_npu = True
                         if tens.purpose == TensorPurpose.LUT:
-                            # TODO: Add support more than one LUT at a time
-                            # Reserve last 2 blocks for LUT
-                            if arch.shram_reserved_unused_banks == 0:
-                                arch.shram_reserved_unused_banks = 2
-                                arch.shram_total_banks -= arch.shram_reserved_unused_banks
-                            # Place the LUT in the last 2 blocks of SHRAM
-                            new_tens.address = arch.shram_bank_size * arch.shram_total_banks
+                            new_tens.mem_area = MemArea.Shram
                         op.inputs[idx] = new_tens
     return op
 
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
new file mode 100644
index 0000000..39101fa
--- /dev/null
+++ b/ethosu/vela/lut.py
@@ -0,0 +1,120 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Functionality for lookup table support.
+import uuid
+from functools import lru_cache
+
+from . import numeric_util
+from .high_level_command_stream import CommandType
+from .tensor import TensorPurpose
+
+
+@lru_cache(maxsize=None)
+def create_equivalence_id(key):
+    # Generates an equivalence_id based on the given key; the lru_cache
+    # ensures that the same key always maps to the same id.
+    # The DMA optimization of LUTs assumes that 2 LUT tensors are identical
+    # if they have the same equivalence_id. So, for example, all created
+    # 256-byte tanh LUT tensors should have the same equivalence_id.
+    return uuid.uuid4()
+
+
+class LUTState:
+    # Tracks which LUTs are located in SHRAM.
+    def __init__(self):
+        self.tensors = []
+
+    def get_equivalent(self, lut_tens):
+        # Returns an existing LUT with the same equivalence_id, or None if not found
+        for t in self.tensors:
+            if t.equivalent(lut_tens):
+                return t
+        return None
+
+    def put(self, lut_tens):
+        # Returns a new LUT state containing the given tensor plus all tensors
+        # in this state that do not overlap with the given tensor
+        new_state = LUTState()
+        new_state.tensors.append(lut_tens)
+        start = lut_tens.address
+        end = start + lut_tens.storage_size()
+        for tens in self.tensors:
+            start2 = tens.address
+            end2 = start2 + tens.storage_size()
+            if not numeric_util.overlaps(start, end, start2, end2):
+                new_state.tensors.append(tens)
+        return new_state
+
+    def find_best_address(self, start, stop, step):
+        # Finds the address in the given range that overlaps with the minimum
+        # number of LUTs currently present in SHRAM.
+        # An improvement would be to also take future LUT usage into account
+        best_addr = start
+        best_nr_overlaps = stop  # stop is a safe upper bound on any overlap count
+        for addr in range(start, stop, step):
+            nr_overlaps = 0
+            for tens in self.tensors:
+                start2 = tens.address
+                end2 = start2 + tens.storage_size()
+                if numeric_util.overlaps(addr, addr + step, start2, end2):
+                    nr_overlaps += 1
+            if nr_overlaps < best_nr_overlaps:
+                best_nr_overlaps = nr_overlaps
+                best_addr = addr
+        return best_addr
+
+
+def get_lut_index(arch, lut_tensor):
+    # Returns the index in SHRAM where the given LUT is stored, a value between 0 and 7
+    slot = (lut_tensor.address - arch.shram_lut_address) // lut_tensor.storage_size()
+    assert 0 <= slot < 8
+    return slot
+
+
+def optimize_high_level_cmd_stream(sg, arch):
+    # - Allocates a SHRAM address/LUT index to each LUT tensor
+    # - Removes DMA operations of LUTs that are already present in SHRAM from sg's command stream
+    cmd_stream = []  # will contain existing command stream minus unneeded DMA operations
+    lut_state = LUTState()
+    slot_size = 256
+    lut_start = arch.shram_lut_address
+    lut_end = lut_start + arch.shram_lut_size
+    for cmd in sg.high_level_command_stream:
+        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
+            # The command overwrites the last 2 banks containing the LUT; the next LUT operation will require DMA
+            # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
+            lut_state = LUTState()
+        if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT:
+            # Non-LUT operation; leave untouched
+            cmd_stream.append(cmd)
+            continue
+        # LUT DMA operation
+        lut_tens = cmd.out_tensor
+        existing_tens = lut_state.get_equivalent(lut_tens)
+        if existing_tens is not None:
+            # LUT is already in SHRAM, no need to perform DMA
+            lut_tens.address = existing_tens.address
+            cmd.ps.primary_op.attrs["lut_index"] = get_lut_index(arch, existing_tens)
+            continue
+        # Place the LUT in the last 2 banks of SHRAM
+        # Alignment is always on the size of the LUT: 256 for a 256-byte LUT, 1K for a 1K LUT, etc.
+        address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
+        lut_tens.address = address
+        cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size
+        lut_state = lut_state.put(lut_tens)
+        cmd_stream.append(cmd)
+    sg.high_level_command_stream = cmd_stream
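The placement policy can be exercised in isolation; below is a toy model
using plain (address, size) pairs in place of vela tensors. The helper
names are illustrative only, not part of vela:

    # Toy model of LUTState.find_best_address: pick the slot that evicts
    # the fewest LUTs currently present in SHRAM.
    def overlaps(start1, end1, start2, end2):
        return start1 < end2 and start2 < end1

    def find_best_address(present, start, stop, step):
        best_addr, best_nr_overlaps = start, stop
        for addr in range(start, stop, step):
            nr = sum(1 for (a, size) in present if overlaps(addr, addr + step, a, a + size))
            if nr < best_nr_overlaps:
                best_addr, best_nr_overlaps = addr, nr
        return best_addr

    # Two 256-byte LUTs at offsets 0 and 256 of a 2048-byte LUT area;
    # a 1K LUT can go at offset 0 (evicting both) or 1024 (evicting none):
    present = [(0, 256), (256, 256)]
    assert find_best_address(present, 0, 2048, 1024) == 1024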
diff --git a/ethosu/vela/numeric_util.py b/ethosu/vela/numeric_util.py
index 70209fb..4ebef8e 100644
--- a/ethosu/vela/numeric_util.py
+++ b/ethosu/vela/numeric_util.py
@@ -89,3 +89,7 @@
 
 def full_shape(dim, shape, fill):
     return ([fill] * (dim - len(shape))) + shape
+
+
+def overlaps(start1, end1, start2, end2):
+    return start1 < end2 and start2 < end1
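overlaps treats its arguments as half-open intervals, so adjacent LUTs do
not count as overlapping; a quick check of the boundary cases:

    from ethosu.vela.numeric_util import overlaps

    assert overlaps(0, 256, 128, 384)      # partial overlap
    assert overlaps(0, 1024, 256, 512)     # containment is an overlap
    assert not overlaps(0, 256, 256, 512)  # adjacent half-open ranges do not overlap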
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 8e108db..7b69e35 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -381,12 +381,18 @@
                         input_set.add(input_tens)
 
         ordered_input_list = []
+        # Keep LUTs in a separate list and add them as inputs at the end,
+        # to avoid them accidentally being assigned as ifm or ifm2
+        lut_list = []
         input_refcounts = collections.defaultdict(int)
         for op in ops_list:
             for inp in op.inputs:
                 if inp in input_set:
                     if input_refcounts[inp] == 0:
-                        ordered_input_list.append(inp)
+                        if inp.purpose == TensorPurpose.LUT:
+                            lut_list.append(inp)
+                        else:
+                            ordered_input_list.append(inp)
                     input_refcounts[inp] += 1
 
         name = ops_list[0].name
@@ -416,6 +422,7 @@
         ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
         ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2]
         ps.lut_tensor = ps.get_primary_op_lut()
+        ps.inputs.extend(lut_list)
 
         for op in ps.ops:
             op.scheduled_pass = ps
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 0934881..4a9b071 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -277,10 +277,10 @@
     if prev_cmd is None:
         return False
     if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps):
-        if prev_cmd.ofm_tensor.equivalence_id == cmd.ifm_tensor.equivalence_id:
+        if prev_cmd.ofm_tensor.equivalent(cmd.ifm_tensor):
             return True
         elif cmd.ifm2_tensor is not None:
-            return prev_cmd.ofm_tensor.equivalence_id == cmd.ifm2_tensor.equivalence_id
+            return prev_cmd.ofm_tensor.equivalent(cmd.ifm2_tensor)
     return False
 
 
@@ -560,12 +560,13 @@
                 else:
                     emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0)
 
-                # For elementwise set the required SHRAM to be equal to the total size of SHRAM
-                shram_required = arch.shram_total_banks
+                # For elementwise ops, set the required SHRAM equal to the total available SHRAM
+                uses_lut = primary_op.activation_lut is not None
+                shram_required = arch.available_shram_banks(uses_lut)
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required)
 
                 # Acc buffers not needed so set AB_START to size of SHRAM
-                emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch.shram_total_banks)
+                emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required)
 
                 # Is not a unary operator
                 if cmd.ifm2_tensor is not None:
@@ -852,8 +853,8 @@
                     faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
             elif faf == "LUT":
-                lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", 0)
-                assert lut_index <= activation.LUT_END.value, "LUT index out of range."
+                lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
+                assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
                 faf_min = ofm_quant_qmin
                 faf_max = ofm_quant_qmax
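Net effect on the elementwise SHRAM programming, using the hypothetical
16-bank configuration from the earlier sketch:

    # Without a LUT: NPU_SET_IFM_IB_END = 16, NPU_SET_AB_START = 16 (as before)
    # With a LUT:    NPU_SET_IFM_IB_END = 14, NPU_SET_AB_START = 14
    # The IB/accumulator region now stops below the 2 banks holding the LUT
    # instead of spanning all of SHRAM, so the LUT is no longer clobbered.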
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index ecca0e0..312e8f3 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -54,16 +54,17 @@
     Dram = 2
     OnChipFlash = 3
     OffChipFlash = 4
-    Size = OffChipFlash + 1
+    Shram = 5  # for LUT
+    Size = Shram + 1
 
     def display_name(self):
-        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "Size")[self.value]
+        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]
 
     def identifier_name(self):
-        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "size")[self.value]
+        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]
 
     def all():
-        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash)
+        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)
 
     def __str__(self):
         return self.name
@@ -728,6 +729,9 @@
             return True
         return False
 
+    def equivalent(self, tens):
+        return self.equivalence_id == tens.equivalence_id
+
     def set_all_shapes(self, shape):
         self.shape = shape
         self.storage_shape = shape
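Tensor.equivalent reduces to an equivalence_id comparison; together with the
memoized lut.create_equivalence_id, equal keys yield ids that compare equal.
A sketch of that contract, with hypothetical keys:

    from ethosu.vela import lut

    id_a = lut.create_equivalence_id("tanh-int8")   # hypothetical keys
    id_b = lut.create_equivalence_id("tanh-int8")
    id_c = lut.create_equivalence_id("sigmoid-int8")
    assert id_a == id_b   # same key -> same id, thanks to lru_cache
    assert id_a != id_c   # different keys -> different random UUIDs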
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index f29296d..bb91145 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -26,6 +26,7 @@
 from .nn_graph import TensorAllocator
 from .tensor import MemArea
 from .tensor import MemType
+from .tensor import TensorPurpose
 
 
 def linear_allocate_live_ranges(live_ranges, alloc_granularity=16):
@@ -44,6 +45,11 @@
                 if allocated_tens.weight_compression_config == tens.weight_compression_config:
                     address = allocated_tens.address
                     break
+        if tens.purpose == TensorPurpose.LUT:
+            for allocated_tens in allocated_tensors:
+                if allocated_tens.equivalent(tens):
+                    address = allocated_tens.address
+                    break
         lr.set_address(address)
         allocated_tensors += lr.tensors
         if address == total_sz:
diff --git a/ethosu/vela/test/test_live_range.py b/ethosu/vela/test/test_live_range.py
index 395d0f3..d087dd9 100644
--- a/ethosu/vela/test/test_live_range.py
+++ b/ethosu/vela/test/test_live_range.py
@@ -18,6 +18,7 @@
 from unittest.mock import MagicMock
 
 import pytest
+
 from ethosu.vela.live_range import LiveRange
 
 
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
new file mode 100644
index 0000000..3b7f57b
--- /dev/null
+++ b/ethosu/vela/test/test_lut.py
@@ -0,0 +1,180 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Unit tests for LUT support
+import numpy as np
+
+from ethosu.vela import insert_dma
+from ethosu.vela import lut
+from ethosu.vela import mark_tensors
+from ethosu.vela import pass_packing
+from ethosu.vela.data_type import DataType
+from ethosu.vela.high_level_command_stream import DMA
+from ethosu.vela.nn_graph import Graph
+from ethosu.vela.rewrite_graph import verify_graph_health
+from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import TensorPurpose
+from ethosu.vela.test import testutil
+
+
+def set_256_lut(op, key):
+    values = list(range(256))
+    lut_tensor = create_const_tensor(
+        op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT
+    )
+    lut_tensor.equivalence_id = lut.create_equivalence_id(key)
+    op.set_activation_lut(lut_tensor)
+
+
+def set_1K_lut(op, key):
+    values = list(range(256))
+    lut_tensor = create_const_tensor(
+        op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT
+    )
+    lut_tensor.equivalence_id = lut.create_equivalence_id(key)
+    op.set_activation_lut(lut_tensor)
+
+
+def set_2K_lut(op, key):
+    values = list(range(512))
+    lut_tensor = create_const_tensor(
+        op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT
+    )
+    lut_tensor.equivalence_id = lut.create_equivalence_id(key)
+    op.set_activation_lut(lut_tensor)
+
+
+def process(arch, op_list):
+    # Returns a subgraph with the given operations
+    nng = Graph()
+    sg = testutil.create_subgraph(op_list)
+    nng.subgraphs.append(sg)
+    assert verify_graph_health(nng)
+    nng = mark_tensors.mark_tensor_purpose(nng, arch, False)
+    assert verify_graph_health(nng)
+    nng = insert_dma.insert_dma_commands(nng, arch, False)
+    assert verify_graph_health(nng)
+    pass_packing.pack_into_passes(nng, arch, False)
+    assert verify_graph_health(nng)
+    # Create a DMA command for every intermediate tensor that needs DMA
+    cmd_list = []
+    for ps in sg.passes:
+        for intermediate in ps.intermediates:
+            if intermediate.needs_dma():
+                cmd_list.append(DMA(ps, intermediate.get_dma_src_tensor(), intermediate, None))
+    sg.high_level_command_stream = cmd_list
+    return sg
+
+
+def test_optimize_high_level_cmd_stream_2K():
+    # Tests lut.optimize_high_level_cmd_stream, mixing 256-byte and 2K LUTs
+    arch = testutil.create_arch()
+    shape = [1, 1, 1, 1]
+    # u8 LUT op, should lead to DMA
+    op0 = testutil.create_elemwise_op("AddAct", "op0", shape, shape, shape)
+    set_256_lut(op0, "lut0")
+    # u8 LUT op, should lead to DMA
+    op1 = testutil.create_elemwise_op("AddAct", "op1", shape, shape, shape)
+    set_256_lut(op1, "lut1")
+    # u8 LUT op with different LUT, should lead to DMA
+    op2 = testutil.create_elemwise_op("AddAct", "op2", shape, shape, shape)
+    set_256_lut(op2, "lut2")
+    # u8 LUT op with same LUT as in op1, should not lead to DMA
+    op3 = testutil.create_elemwise_op("AddAct", "op3", shape, shape, shape)
+    set_256_lut(op3, "lut1")
+    # u8 LUT op with same LUT as in op2, should not lead to DMA
+    op4 = testutil.create_elemwise_op("AddAct", "op4", shape, shape, shape)
+    set_256_lut(op4, "lut2")
+    # 2K LUT op, should lead to DMA, and will overwrite all previous LUTs in SHRAM
+    op5_2K = testutil.create_elemwise_op("AddAct", "op5", shape, shape, shape)
+    set_2K_lut(op5_2K, "lut5")
+    # Another 2K LUT op, should lead to DMA, and will overwrite the previous LUT in SHRAM
+    op6_2K = testutil.create_elemwise_op("AddAct", "op6", shape, shape, shape)
+    set_2K_lut(op6_2K, "lut6")
+    # u8 LUT op with same LUT as in op1, should lead to DMA
+    op7 = testutil.create_elemwise_op("AddAct", "op7", shape, shape, shape)
+    set_256_lut(op7, "lut1")
+
+    op_list = [op0, op1, op2, op3, op4, op5_2K, op6_2K, op7]
+    sg = process(arch, op_list)
+    orig_cmd_list = sg.high_level_command_stream
+    lut.optimize_high_level_cmd_stream(sg, arch)
+    cmd_list = sg.high_level_command_stream
+    # Check that only the needed DMA commands are left
+    expected_dma_ops = [op0, op1, op2, op5_2K, op6_2K, op7]
+    assert len(cmd_list) == len(expected_dma_ops)
+    for (cmd, op) in zip(cmd_list, expected_dma_ops):
+        assert cmd.in_tensor == op.activation_lut
+    # Check that lut0, lut1 and lut2 in op0, op1, op2 are stored at different addresses
+    assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[1].out_tensor.address
+    assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[2].out_tensor.address
+    assert orig_cmd_list[1].out_tensor.address != orig_cmd_list[2].out_tensor.address
+    # Check that lut1 in op1 and op3 has the same address
+    assert orig_cmd_list[1].out_tensor.address == orig_cmd_list[3].out_tensor.address
+    # Check that lut2 in op2 and op4 has the same address
+    assert orig_cmd_list[2].out_tensor.address == orig_cmd_list[4].out_tensor.address
+    # Check that the 2K LUTs for 16-bit ops (op5 and op6) are stored at the same address
+    assert orig_cmd_list[5].out_tensor.address == orig_cmd_list[6].out_tensor.address
+
+
+def test_optimize_high_level_cmd_stream_1K():
+    # Tests lut.optimize_high_level_cmd_stream, mixing 256-byte and 1K LUTs
+    arch = testutil.create_arch()
+    shape = [1, 1, 1, 1]
+    # u8 LUT op, should lead to DMA
+    op0 = testutil.create_elemwise_op("AddAct", "op0", shape, shape, shape)
+    set_256_lut(op0, "lut0")
+    # u8 LUT op, should lead to DMA
+    op1 = testutil.create_elemwise_op("AddAct", "op1", shape, shape, shape)
+    set_256_lut(op1, "lut1")
+    # 1K LUT op with different LUT, should lead to DMA
+    op2_1K = testutil.create_elemwise_op("AddAct", "op2", shape, shape, shape)
+    set_1K_lut(op2_1K, "lut2")
+    # u8 LUT op with same LUT as in op1, should not lead to DMA
+    op3 = testutil.create_elemwise_op("AddAct", "op3", shape, shape, shape)
+    set_256_lut(op3, "lut1")
+    # 1K LUT op with same LUT as in op2, should not lead to DMA
+    op4_1K = testutil.create_elemwise_op("AddAct", "op4", shape, shape, shape)
+    set_1K_lut(op4_1K, "lut2")
+    # 1K LUT op, should lead to DMA, and will overwrite lut2
+    op5 = testutil.create_elemwise_op("AddAct", "op5", shape, shape, shape)
+    set_1K_lut(op5, "lut5")
+    # u8 LUT op, lut0 should still be present, should not lead to DMA
+    op6 = testutil.create_elemwise_op("AddAct", "op6", shape, shape, shape)
+    set_256_lut(op6, "lut0")
+    # 1K LUT op with same LUT as in op2, should lead to DMA
+    op7 = testutil.create_elemwise_op("AddAct", "op7", shape, shape, shape)
+    set_1K_lut(op7, "lut2")
+
+    op_list = [op0, op1, op2_1K, op3, op4_1K, op5, op6, op7]
+    sg = process(arch, op_list)
+    orig_cmd_list = sg.high_level_command_stream
+    lut.optimize_high_level_cmd_stream(sg, arch)
+    cmd_list = sg.high_level_command_stream
+    # Check that only the needed DMA commands are left
+    expected_dma_ops = [op0, op1, op2_1K, op5, op7]
+    assert len(cmd_list) == len(expected_dma_ops)
+    for (cmd, op) in zip(cmd_list, expected_dma_ops):
+        assert cmd.in_tensor == op.activation_lut
+    # Check that lut0, lut1 and lut2 in op0, op1, op2 are stored at different addresses
+    assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[1].out_tensor.address
+    assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[2].out_tensor.address
+    assert orig_cmd_list[1].out_tensor.address != orig_cmd_list[2].out_tensor.address
+    # Check that lut1 in op1 and op3 has the same address
+    assert orig_cmd_list[1].out_tensor.address == orig_cmd_list[3].out_tensor.address
+    # Check that lut2 in op2, op4 and op7 has the same address
+    assert orig_cmd_list[2].out_tensor.address == orig_cmd_list[4].out_tensor.address
+    assert orig_cmd_list[2].out_tensor.address == orig_cmd_list[7].out_tensor.address
diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py
index 23e7e90..bd7ca37 100644
--- a/ethosu/vela/test/test_model_reader.py
+++ b/ethosu/vela/test/test_model_reader.py
@@ -16,6 +16,7 @@
 # Description:
 # Unit tests for model_reader.
 import pytest
+
 from ethosu.vela import model_reader
 from ethosu.vela.errors import InputFileError
 
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
index 898e384..1ba0742 100644
--- a/ethosu/vela/test/test_tflite_reader.py
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -16,6 +16,7 @@
 # Description:
 # Contains unit tests for tflite_reader
 import pytest
+
 from ethosu.vela.tflite_reader import TFLiteSubgraph
 
 
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
new file mode 100644
index 0000000..116afa4
--- /dev/null
+++ b/ethosu/vela/test/testutil.py
@@ -0,0 +1,70 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Utilities used in vela unit tests
+import numpy as np
+
+from ethosu.vela import architecture_features
+from ethosu.vela.data_type import DataType
+from ethosu.vela.nn_graph import Subgraph
+from ethosu.vela.operation import NpuBlockType
+from ethosu.vela.operation import Operation
+from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import MemArea
+from ethosu.vela.tensor import Tensor
+
+
+def create_arch():
+    return architecture_features.ArchitectureFeatures(
+        vela_config=None,
+        system_config=None,
+        accelerator_config=architecture_features.Accelerator.Ethos_U55_128.value,
+        permanent_storage=MemArea.OnChipFlash,
+        override_block_config=None,
+        block_config_limit=None,
+        global_memory_clock_scale=1.0,
+        max_blockdep=0,
+        softmax_support=True,
+    )
+
+
+def create_elemwise_op(type, name, ifm_shape, ifm2_shape, ofm_shape, datatype=DataType.uint8):
+    # Creates elementwise operation with constant IFM/IFM2
+    if datatype.size_in_bytes() == 1:
+        np_type = np.uint8
+    elif datatype.size_in_bytes() == 2:
+        np_type = np.int16
+    else:
+        np_type = np.int32
+    op = Operation(type, name)
+    op.add_input_tensor(create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type))
+    op.add_input_tensor(create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type))
+    ofm = Tensor(ofm_shape, datatype, name + "_ofm")
+    op.set_output_tensor(ofm)
+    op.attrs["npu_block_type"] = NpuBlockType.ElementWise
+    return op
+
+
+def create_subgraph(op_list):
+    # Creates subgraph using the given list of operations
+    sg = Subgraph()
+    all_inputs = set(tens for op in op_list for tens in op.inputs)
+    # Iterate in reverse so that the resulting subgraph has the same order as op_list
+    for op in op_list[::-1]:
+        for tens in op.outputs:
+            if tens not in all_inputs and tens not in sg.output_tensors:
+                sg.output_tensors.append(tens)
+    return sg