MLBEDSW-3019: Add profiling debug database

 - Added mechanism to track input to output graph transforms for
   debugging the resultant command stream.
 - Provides base implementation for MLBEDSW-2661

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I2dfe8a409fbde7ad0282bfab5acb11ba1c8b82d8
diff --git a/OPTIONS.md b/OPTIONS.md
index 9220151..9aaf67b 100644
--- a/OPTIONS.md
+++ b/OPTIONS.md
@@ -204,6 +204,19 @@
 vela network.tflite --recursion-limit 50000
 ```
 
+### Enable Debug DB
+
+The neural network debug database allows tracking of optimisations from the
+input network graph to the output command stream.  Set this option to enable the
+calculation and writing of an XML file that contains the network debug database
+tables to the output directory.  
+**Type: Boolean**  
+**Default: Disabled**  
+
+```bash
+vela network.tflite --enable-debug-db
+```
+
 ### Max Block Dependency
 
 Set the maximum value that can be used for the block dependency delay between
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 9263305..e089b70 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -31,10 +31,13 @@
 from . import scheduler
 from . import tensor_allocation
 from . import weight_compressor
+from .debug_database import DebugDatabase
 from .errors import VelaError
 from .nn_graph import PassPlacement
 from .nn_graph import TensorAllocator
+from .operation import Op
 from .rewrite_graph import verify_graph_health
+from .rewrite_graph import visit_graph_post_order
 from .tensor import MemType
 from .tensor import Tensor
 
@@ -127,8 +130,18 @@
     return ((lower + upper) / 2, True)
 
 
+def _record_operator(op, arch):
+    if op.type != Op.Const:
+        DebugDatabase.add_source(op)
+
+
 def compiler_driver(nng, arch, options, scheduler_options):
     assert verify_graph_health(nng)
+
+    # Pre-optimisation operator tracking
+    for sg in nng.subgraphs:
+        visit_graph_post_order(sg.output_tensors, arch, [], [_record_operator])
+
     nng = graph_optimiser.optimise_graph_a(nng, arch, options.verbose_graph)
     assert verify_graph_health(nng)
 
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
new file mode 100644
index 0000000..b5852cd
--- /dev/null
+++ b/ethosu/vela/debug_database.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import csv
+import io
+
+import lxml.etree as xml
+
+from . import numeric_util
+from .operation import Operation
+
+
+class DebugDatabase:
+    NULLREF = -1
+    show_warnings = False
+
+    SOURCE_TABLE = "source"
+    _sourceUID = {}
+    _sourceHeaders = ["id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
+    _sourceTable = []
+
+    OPTIMISED_TABLE = "optimised"
+    _optimisedUID = {}
+    _optimisedHeaders = ["id", "source_id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
+    _optimisedTable = []
+
+    QUEUE_TABLE = "queue"
+    _queueHeaders = ["offset", "cmdstream_id", "optimised_id"]
+    _queueTable = []
+
+    STREAM_TABLE = "cmdstream"
+    _streamUID = {}
+    _streamHeaders = ["id", "file_offset"]
+    _streamTable = []
+
+    @classmethod
+    def add_source(cls, op: Operation):
+        assert isinstance(op, Operation)
+        uid = len(cls._sourceUID)
+        cls._sourceUID[op] = uid
+        ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1)
+        cls._sourceTable.append(
+            [uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
+        )
+
+    @classmethod
+    def add_optimised(cls, parent: Operation, op: Operation):
+        assert isinstance(parent, Operation) and isinstance(op, Operation)
+        if op not in cls._optimisedUID:
+            if parent not in cls._sourceUID:
+                # The the parent wasn't in the source network try to look it
+                # up in the optimised network and use that op's source parent.
+                if parent in cls._optimisedUID:
+                    src_uid = cls._optimisedUID[parent][1]
+                else:
+                    if DebugDatabase.show_warnings:
+                        print("Debug Database: Associated parent '{0}' not in network".format(parent.type))
+                    src_uid = DebugDatabase.NULLREF
+            else:
+                src_uid = cls._sourceUID[parent]
+            uid = len(cls._optimisedUID)
+            cls._optimisedUID[op] = (uid, src_uid)
+            ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1)
+            cls._optimisedTable.append(
+                [uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
+            )
+
+    @classmethod
+    def add_stream(cls, key):
+        if key not in cls._streamUID:
+            uid = len(cls._streamUID)
+            cls._streamUID[key] = uid
+        return uid
+
+    @classmethod
+    def set_stream_offset(cls, key, file_offset):
+        assert key in cls._streamUID
+        uid = cls._streamUID[key]
+        cls._streamTable.append([uid, file_offset])
+
+    @classmethod
+    def add_command(cls, stream_id, offset, op: Operation):
+        assert stream_id < len(cls._streamUID)
+        assert op in cls._optimisedUID, "Optimised operator must exist before code generation"
+        optimised_id = cls._optimisedUID[op][0]
+        cls._queueTable.append([offset, stream_id, optimised_id])
+
+    @classmethod
+    def _write_table(cls, root, name, headers, table):
+        # Convert table to CSV
+        out = io.StringIO()
+        writer = csv.writer(out, quoting=csv.QUOTE_NONNUMERIC)
+        writer.writerow(headers)
+        writer.writerows(table)
+
+        # Package table into XML output
+        table = xml.SubElement(root, "table", {"name": name})
+        table.text = xml.CDATA(out.getvalue())
+
+    @classmethod
+    def write(cls, file_path, input_file, output_file):
+        root = xml.Element("debug", {"source": input_file, "optimised": output_file})
+
+        cls._write_table(root, cls.SOURCE_TABLE, cls._sourceHeaders, cls._sourceTable)
+        cls._write_table(root, cls.OPTIMISED_TABLE, cls._optimisedHeaders, cls._optimisedTable)
+        cls._write_table(root, cls.QUEUE_TABLE, cls._queueHeaders, cls._queueTable)
+        cls._write_table(root, cls.STREAM_TABLE, cls._streamHeaders, cls._streamTable)
+
+        xml.ElementTree(root).write(file_path, encoding="utf-8", xml_declaration=True, pretty_print=True)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e31348b..7304630 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -25,6 +25,7 @@
 from . import rewrite_graph
 from . import scaling
 from .data_type import DataType
+from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .numeric_util import clamp_sigmoid
@@ -77,6 +78,7 @@
             new_op.attrs["concat_end"] = offset
             new_op.run_on_npu = True
             tens.ops.append(new_op)
+            DebugDatabase.add_optimised(concat_op, new_op)
         assert tens.shape[axis] == offset
 
         # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -128,6 +130,7 @@
         new_op.attrs["split_end"] = offset_end
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
+        DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
 
@@ -399,6 +402,7 @@
             reshape_op.attrs["new_shape"] = desired_shape
             reshape_op.inputs = [inp, new_shape_tens]
             reshape_op.set_output_tensor(reshape_out)
+            DebugDatabase.add_optimised(op, reshape_op)
 
             op.inputs[idx] = reshape_out
 
@@ -492,6 +496,7 @@
             reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
             reshape_op.set_output_tensor(out_tens)
+            DebugDatabase.add_optimised(op, reshape_op)
 
             op.outputs[idx] = reshape_in
 
@@ -568,6 +573,7 @@
                     op.attrs["depth_multiplier"], ifm_tensor.shape[3], ofm_tensor.shape[3]
                 )
             )
+        DebugDatabase.add_optimised(op, op)
     return op
 
 
@@ -616,6 +622,9 @@
             reshape_op.set_output_tensor(orig_ofm_tensor)
             # Replace this ops OFM to point to the 2D tensor
             op.outputs[0] = fc_ofm_tensor
+            # Record optimisation in debug database
+            DebugDatabase.add_optimised(op, reshape_op)
+            DebugDatabase.add_optimised(op, op)
     return op
 
 
@@ -670,6 +679,10 @@
 
             # Mark the op so that it will be removed as passthrough later on
             op.type = Op.Identity
+
+            # Record optimisation in debug database
+            DebugDatabase.add_optimised(op, act_op)
+            DebugDatabase.add_optimised(op, op)
     return op
 
 
@@ -788,6 +801,10 @@
         op.name = op.name.replace("Maximum", new_op.name)
         op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
         op.inputs = [shared_in]
+
+        # Record optimisation in debug database
+        DebugDatabase.add_optimised(op, op)
+
     return op
 
 
@@ -812,6 +829,7 @@
     mul_alpha.add_input_tensor(alpha_tens)
     fm_alpha = ofm.clone(op.name + "_alpha")
     mul_alpha.set_output_tensor(fm_alpha)
+    DebugDatabase.add_optimised(op, mul_alpha)
 
     if check_quantized_tens_scaling_equal(ifm, ofm):
         # No identity multiplication is needed
@@ -832,6 +850,7 @@
         mul_identity.add_input_tensor(identity_tens)
         fm_id = ofm.clone(op.name + "_id")
         mul_identity.set_output_tensor(fm_id)
+        DebugDatabase.add_optimised(op, mul_alpha)
 
     # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
     op.type = Op.Maximum
@@ -840,6 +859,8 @@
     ifm.consumer_list.remove(op)
     op.add_input_tensor(fm_alpha)
     op.add_input_tensor(fm_id)
+
+    DebugDatabase.add_optimised(op, op)
     return op
 
 
@@ -1012,6 +1033,7 @@
         prev_op.set_activation_lut(op.activation_lut)
     # Bypass op
     prev_op.set_output_tensor(ofm)
+    DebugDatabase.add_optimised(op, prev_op)
     return op
 
 
@@ -1052,6 +1074,11 @@
     return op
 
 
+def _record_optimised(op, arch):
+    if op.type != Op.Const:
+        DebugDatabase.add_optimised(op, op)
+
+
 def optimise_graph_a(nng, arch, verbose_graph=False):
     if verbose_graph:
         nng.print_graph()
@@ -1093,6 +1120,10 @@
             nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
         )
 
+    # Post-optimisation operator debug tracing
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [_record_optimised])
+
     if verbose_graph:
         nng.print_graph()
     return nng
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 5673c2d..59376a8 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -18,6 +18,7 @@
 import collections
 import enum
 
+from .debug_database import DebugDatabase
 from .nn_graph import Pass
 from .nn_graph import PassPlacement
 from .operation import create_avgpool_nop
@@ -430,7 +431,6 @@
             # Configure a 1x1 AvgPool and attach the op onto it
             op = op_list[0]
             inp = op.inputs[0]
-
             avgpool_op = create_avgpool_nop(op.name + "_avgpool")
             avgpool_op.add_input_tensor(inp)
             avgpool_out = inp.clone("_avgpooled")
@@ -440,6 +440,7 @@
             op.inputs[0] = avgpool_out
             op_list.insert(0, avgpool_op)
 
+            DebugDatabase.add_optimised(op, avgpool_op)
             return avgpool_op
 
         return None
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index e5e4fb1..e3fedfc 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -32,6 +32,7 @@
 from .architecture_features import SHRAMElements
 from .data_type import BaseType
 from .data_type import DataType
+from .debug_database import DebugDatabase
 from .ethos_u55_regs.ethos_u55_regs import acc_format
 from .ethos_u55_regs.ethos_u55_regs import activation
 from .ethos_u55_regs.ethos_u55_regs import cmd0
@@ -96,10 +97,13 @@
 
 
 class CommandStreamEmitter:
+    WORD_SIZE = 4
+
     def __init__(self):
         self.cmd_stream = []
         self.reg_machine = [RegisterMachine(), RegisterMachine()]
         self.last_absolute_wait = defaultdict(int)
+        self.offset = 0
 
     def get_reg_machine(self, cmd):
         if "DMA" in cmd.name:
@@ -110,7 +114,7 @@
     def size_in_bytes(self):
         sz = 0
         for cmd in self.cmd_stream:
-            sz += len(cmd) * 4
+            sz += len(cmd) * CommandStreamEmitter.WORD_SIZE
         return sz
 
     def to_list(self):
@@ -154,6 +158,7 @@
 
         # This is not a redundant command, actually write it
         self.cmd_stream.append((command,))
+        self.offset += CommandStreamEmitter.WORD_SIZE
 
     def cmd1_with_offset(self, cmd, offset, param=0x0):
         offset = int(offset) & 0xFFFFFFFFF
@@ -164,17 +169,20 @@
 
         # This is not a redundant command, actually write it
         self.cmd_stream.append((command, offset))
+        self.offset += CommandStreamEmitter.WORD_SIZE * 2
 
     def cmd_wait(self, cmd, channel, outstanding_count):
         param = (16 * channel) + outstanding_count
         command = ((param & 0xFFFF) << 16) | cmd.value
         self.cmd_stream.append((command,))
+        self.offset += CommandStreamEmitter.WORD_SIZE
 
     def cmd_do_operation(self, cmd, param=0):
         param = int(param)
         command = ((param & 0xFFFF) << 16) | cmd.value
 
         self.cmd_stream.append((command,))
+        self.offset += CommandStreamEmitter.WORD_SIZE
         self.get_reg_machine(cmd).switch_bank()
 
 
@@ -378,6 +386,9 @@
 
     dep_watermark = Watermark(0, 0)
 
+    stream_id = DebugDatabase.add_stream(sg)
+    DebugDatabase.set_stream_offset(sg, 0)  # Default to zero, can only set during file writing
+
     for cmd_index, cmd in enumerate(cmd_stream):
         dep_watermark, cmd_waits = get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, dep_watermark)
 
@@ -1077,6 +1088,7 @@
             prev_cmd = cmd
 
             emit_cmd_waits(cmd_waits)
+            DebugDatabase.add_command(stream_id, emit.offset, primary_op)
 
             if npu_block_type == NpuBlockType.ConvolutionMxN:
                 emit.cmd_do_operation(cmd0.NPU_OP_CONV)
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 12c2016..efd91a3 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -25,6 +25,7 @@
 from . import fp_math
 from . import scaling
 from .data_type import DataType
+from .debug_database import DebugDatabase
 from .operation import Op
 from .operation import Operation
 from .tensor import create_const_tensor
@@ -220,6 +221,9 @@
 
     def get_graph_8bit(self, ifm, ofm):
         exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
+        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
+        DebugDatabase.add_optimised(self.op, ifm.ops[0])
+        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None
         no_scale_quant.zero_point = 0
@@ -245,6 +249,7 @@
         ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
         ifm_max.quantization = no_scale_quant
         maxpool_op.set_output_tensor(ifm_max)
+        DebugDatabase.add_optimised(self.op, maxpool_op)
 
         # PASS 1 - Sub+LUT(exp)
         sub_op = Operation(Op.Sub, self.op.name + "_sub1")
@@ -261,6 +266,7 @@
         ifm_exp.quantization.quant_min = -128
         ifm_exp.quantization.quant_max = 127
         sub_op.set_output_tensor(ifm_exp)
+        DebugDatabase.add_optimised(self.op, sub_op)
 
         # PASS 2 - SHR
         shr2_op = Operation(Op.SHR, self.op.name + "_shr2")
@@ -274,6 +280,7 @@
         rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0")
         rescaled_exp.quantization = no_scale_quant
         shr2_op.set_output_tensor(rescaled_exp)
+        DebugDatabase.add_optimised(self.op, shr2_op)
 
         # PASS 3 - Reduce sum
         reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum3")
@@ -290,6 +297,7 @@
         sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
         sum_of_exp.quantization = no_scale_quant
         reduce_sum_op.set_output_tensor(sum_of_exp)
+        DebugDatabase.add_optimised(self.op, reduce_sum_op)
 
         # PASS 4 - CLZ
         clz_op = Operation(Op.CLZ, self.op.name + "_clz4")
@@ -297,6 +305,7 @@
         headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
         headroom_plus_one.quantization = no_scale_quant
         clz_op.set_output_tensor(headroom_plus_one)
+        DebugDatabase.add_optimised(self.op, clz_op)
 
         # PASS 5 - Sub
         sub5_op = Operation(Op.Sub, self.op.name + "_sub5")
@@ -314,6 +323,7 @@
         right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0")
         right_shift.quantization = no_scale_quant
         sub5_op.set_output_tensor(right_shift)
+        DebugDatabase.add_optimised(self.op, sub5_op)
 
         # PASS 6 - Sub
         one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
@@ -323,6 +333,7 @@
         headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
         headroom.quantization = no_scale_quant
         sub6_op.set_output_tensor(headroom)
+        DebugDatabase.add_optimised(self.op, sub6_op)
 
         # PASS 7 - SHL
         shl7_op = Operation(Op.SHL, self.op.name + "_shl7")
@@ -331,6 +342,7 @@
         shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
         shifted_sum.quantization = no_scale_quant
         shl7_op.set_output_tensor(shifted_sum)
+        DebugDatabase.add_optimised(self.op, shl7_op)
 
         # PASS 8 - Sub
         sub8_op = Operation(Op.Sub, self.op.name + "_sub8")
@@ -343,6 +355,7 @@
         shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
         shifted_sum_minus_one.quantization = no_scale_quant
         sub8_op.set_output_tensor(shifted_sum_minus_one)
+        DebugDatabase.add_optimised(self.op, sub8_op)
 
         # PASS 9 - SHL
         shl9_op = Operation(Op.SHL, self.op.name + "_shl9")
@@ -351,6 +364,7 @@
         shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
         shifted_sum_minus_one.quantization = no_scale_quant
         shl9_op.set_output_tensor(shifted_sum_minus_one)
+        DebugDatabase.add_optimised(self.op, shl9_op)
 
         # PASS 10 - Add
         add10_op = Operation(Op.Add, self.op.name + "_add10")
@@ -364,6 +378,7 @@
         half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0")
         half_denominator.quantization = one_scale_quant
         add10_op.set_output_tensor(half_denominator)
+        DebugDatabase.add_optimised(self.op, add10_op)
 
         # PASS 11 - Multiply
         mul11_op = Operation(Op.Mul, self.op.name + "_mul11")
@@ -382,6 +397,7 @@
         rescaled.quantization = one_scale_quant.clone()
         rescaled.quantization.scale_f32 = 2.0
         mul11_op.set_output_tensor(rescaled)
+        DebugDatabase.add_optimised(self.op, mul11_op)
 
         # PASS 12 - Add
         add12_op = Operation(Op.Add, self.op.name + "_add12")
@@ -394,6 +410,7 @@
         rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0")
         rescale_w_offset.quantization = one_scale_quant
         add12_op.set_output_tensor(rescale_w_offset)
+        DebugDatabase.add_optimised(self.op, add12_op)
 
         nr_x = rescale_w_offset
         F2_one = create_const_tensor(
@@ -411,6 +428,7 @@
             half_denominator_times_x.quantization = one_scale_quant.clone()
             half_denominator_times_x.quantization.scale_f32 = 2.0
             mul_op.set_output_tensor(half_denominator_times_x)
+            DebugDatabase.add_optimised(self.op, mul_op)
             # PASS 14, 19, 24 - SUB
             sub_op = Operation(Op.Sub, self.op.name + "_sub%d" % (14 + i * 5))
             sub_op.add_input_tensor(F2_one)
@@ -418,6 +436,7 @@
             one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
             one_minus_half_denominator_times_x.quantization = one_scale_quant
             sub_op.set_output_tensor(one_minus_half_denominator_times_x)
+            DebugDatabase.add_optimised(self.op, sub_op)
             # PASS 15, 20, 25 - MUL
             mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (15 + i * 5))
             mul_op.add_input_tensor(nr_x)
@@ -426,6 +445,7 @@
             to_rescale.quantization = one_scale_quant.clone()
             to_rescale.quantization.scale_f32 = 2.0
             mul_op.set_output_tensor(to_rescale)
+            DebugDatabase.add_optimised(self.op, mul_op)
             # PASS 16, 21, 26 - MUL
             shl_op = Operation(Op.Mul, self.op.name + "_mul%d" % (16 + i * 5))
             shl_op.add_input_tensor(to_rescale)
@@ -433,6 +453,7 @@
             to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
             to_add.quantization = no_scale_quant
             shl_op.set_output_tensor(to_add)
+            DebugDatabase.add_optimised(self.op, shl_op)
             # PASS 17, 22, 27 - ADD
             add_op = Operation(Op.Add, self.op.name + "_add%d" % (17 + i * 5))
             add_op.add_input_tensor(nr_x)
@@ -440,6 +461,7 @@
             nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
             nr_x.quantization = one_scale_quant
             add_op.set_output_tensor(nr_x)
+            DebugDatabase.add_optimised(self.op, add_op)
 
         # PASS 28 - Multiply
         mul28_op = Operation(Op.Mul, self.op.name + "_mul28")
@@ -450,6 +472,7 @@
         scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0")
         scale_factor.quantization = one_scale_quant
         mul28_op.set_output_tensor(scale_factor)
+        DebugDatabase.add_optimised(self.op, mul28_op)
 
         # PASS 29 - Multiply
         mul_op = Operation(Op.Mul, self.op.name + "_mul29")
@@ -459,6 +482,7 @@
         scaled_exp.quantization = one_scale_quant.clone()
         scaled_exp.quantization.scale_f32 = 2.0
         mul_op.set_output_tensor(scaled_exp)
+        DebugDatabase.add_optimised(self.op, mul_op)
 
         # PASS 30 - SHR
         shr30_op = Operation(Op.SHR, self.op.name + "_shr30")
@@ -466,6 +490,7 @@
         shr30_op.add_input_tensor(scaled_exp)
         shr30_op.add_input_tensor(right_shift)
         shr30_op.set_output_tensor(ofm)
+        DebugDatabase.add_optimised(self.op, shr30_op)
 
         return shr30_op
 
@@ -476,6 +501,7 @@
         # PASS 0 - Depthwise Maxpool
         maxpool_op = self.op.clone("_maxpool0")
         maxpool_op.type = Op.MaxPool
+        DebugDatabase.add_optimised(self.op, maxpool_op)
         maxpool_h = ifm.shape[1] * ifm.shape[2]
         maxpool_w = ifm.shape[3]
         maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
@@ -490,6 +516,7 @@
         maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0")
         maxpool_ofm.quantization = no_scale_quant
         maxpool_op.set_output_tensor(maxpool_ofm)
+        DebugDatabase.add_optimised(self.op, maxpool_op)
 
         # PASS 1 - Sub
         sub1_op = Operation(Op.Sub, self.op.name + "_sub1")
@@ -498,6 +525,7 @@
         sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
         sub1_ofm.quantization = ifm.quantization.clone()
         sub1_op.set_output_tensor(sub1_ofm)
+        DebugDatabase.add_optimised(self.op, sub1_op)
 
         # PASS 2 - Mul
         beta = self.op.attrs.get("beta", 1.0)
@@ -516,6 +544,7 @@
         mul2_ofm.quantization = ofm.quantization.clone()
         mul2_ofm.quantization.scale_f32 = mul2_out_range
         mul2_op.set_output_tensor(mul2_ofm)
+        DebugDatabase.add_optimised(self.op, mul2_op)
 
         # PASS 3 - Add+LUT(exp)
         add_op = Operation(Op.Add, self.op.name + "_add3")
@@ -533,6 +562,7 @@
         exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0")
         exp_ofm.quantization = mul2_ofm.quantization.clone()
         add_op.set_output_tensor(exp_ofm)
+        DebugDatabase.add_optimised(self.op, add_op)
 
         # PASS 4 - Reduce sum
         reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum4")
@@ -549,6 +579,7 @@
         sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0")
         sum_of_exp.quantization = no_scale_quant
         reduce_sum_op.set_output_tensor(sum_of_exp)
+        DebugDatabase.add_optimised(self.op, reduce_sum_op)
 
         # PASS 5 - CLZ
         clz_op = Operation(Op.CLZ, self.op.name + "_clz5")
@@ -556,6 +587,7 @@
         headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
         headroom_plus_one.quantization = no_scale_quant
         clz_op.set_output_tensor(headroom_plus_one)
+        DebugDatabase.add_optimised(self.op, clz_op)
 
         # PASS 6 - Sub
         sub6_op = Operation(Op.Sub, self.op.name + "_sub6")
@@ -568,6 +600,7 @@
         reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
         reciprocal_right_shift.quantization = no_scale_quant
         sub6_op.set_output_tensor(reciprocal_right_shift)
+        DebugDatabase.add_optimised(self.op, sub6_op)
 
         # PASS 7 - SHL
         shl7_op = Operation(Op.SHL, self.op.name + "_shl7")
@@ -580,6 +613,7 @@
         constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
         constant_one.quantization = no_scale_quant
         shl7_op.set_output_tensor(constant_one)
+        DebugDatabase.add_optimised(self.op, shl7_op)
 
         # PASS 8 - Sub
         sub8_op = Operation(Op.Sub, self.op.name + "_sub8")
@@ -588,6 +622,7 @@
         sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
         sum_of_exps_minus_one.quantization = no_scale_quant
         sub8_op.set_output_tensor(sum_of_exps_minus_one)
+        DebugDatabase.add_optimised(self.op, sub8_op)
 
         # PASS 9 - SHL
         shl9_op = Operation(Op.SHL, self.op.name + "_shl9")
@@ -596,6 +631,7 @@
         shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
         shifted_sum_minus_one.quantization = no_scale_quant
         shl9_op.set_output_tensor(shifted_sum_minus_one)
+        DebugDatabase.add_optimised(self.op, shl9_op)
 
         # PASS 10 - SHR
         shr10_op = Operation(Op.SHR, self.op.name + "_shr10")
@@ -608,6 +644,7 @@
         shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0")
         shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone()
         shr10_op.set_output_tensor(shifted_sum_minus_one_16)
+        DebugDatabase.add_optimised(self.op, shr10_op)
 
         # PASS 11 - Sub+LUT(one over one plus x)
         sub11_op = Operation(Op.Sub, self.op.name + "_sub11")
@@ -630,6 +667,7 @@
         reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0")
         reciprocal_scale.quantization = no_scale_quant
         sub11_op.set_output_tensor(reciprocal_scale)
+        DebugDatabase.add_optimised(self.op, sub11_op)
 
         # PASS 12 - Multiply
         mul_op = Operation(Op.Mul, self.op.name + "_mul12")
@@ -638,11 +676,13 @@
         mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
         mul_ofm.quantization = no_scale_quant
         mul_op.set_output_tensor(mul_ofm)
+        DebugDatabase.add_optimised(self.op, mul_op)
 
         # PASS 13 - SHR
         shr13_op = Operation(Op.SHR, self.op.name + "_shr13")
         shr13_op.add_input_tensor(mul_ofm)
         shr13_op.add_input_tensor(reciprocal_right_shift)
         shr13_op.set_output_tensor(ofm)
+        DebugDatabase.add_optimised(self.op, shr13_op)
 
         return shr13_op
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 4b43751..5df20d2 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -31,6 +31,7 @@
 from . import stats_writer
 from . import tflite_writer
 from ._version import __version__
+from .debug_database import DebugDatabase
 from .errors import InputFileError
 from .nn_graph import PassPlacement
 from .nn_graph import TensorAllocator
@@ -39,14 +40,18 @@
 from .tensor import Tensor
 
 
-def process(fname, arch, model_reader_options, compiler_options, scheduler_options):
+def process(input_name, enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options):
     if compiler_options.timing:
         start = time.time()
 
-    nng = model_reader.read_model(fname, model_reader_options)
+    os.makedirs(compiler_options.output_dir, exist_ok=True)
+    output_basename = os.path.join(compiler_options.output_dir, os.path.splitext(os.path.basename(input_name))[0])
+    DebugDatabase.show_warnings = enable_debug_db
+
+    nng = model_reader.read_model(input_name, model_reader_options)
 
     if not nng:
-        raise InputFileError(fname, "input file could not be read")
+        raise InputFileError(input_name, "input file could not be read")
 
     if compiler_options.verbose_operators:
         nng.print_operators()
@@ -58,16 +63,21 @@
 
     compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options)
 
-    passes_csv_file = "%s/%s_pass-breakdown_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config)
+    passes_csv_file = "{0}_pass-breakdown_{1}.csv".format(output_basename, arch.system_config)
     stats_writer.write_pass_metrics_csv(nng, passes_csv_file)
 
-    summary_csv_file = "%s/%s_summary_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config)
+    summary_csv_file = "{0}_summary_{1}.csv".format(output_basename, arch.system_config)
     stats_writer.write_summary_metrics_csv(nng, summary_csv_file, arch)
 
     stats_writer.print_performance_metrics(nng, show_cpu_operations=compiler_options.show_cpu_operations, arch=arch)
 
-    if fname.endswith(".tflite"):
-        tflite_writer.write_tflite(nng, "%s/%s_vela.tflite" % (compiler_options.output_dir, nng.name))
+    output_filename = output_basename + "_vela.tflite"
+    if input_name.endswith(".tflite"):
+        tflite_writer.write_tflite(nng, output_filename)
+
+    if enable_debug_db:
+        debug_filename = output_basename + "_debug.xml"
+        DebugDatabase.write(debug_filename, input_name, output_filename)
 
     if compiler_options.timing:
         stop = time.time()
@@ -123,6 +133,13 @@
     parser.add_argument(
         "--output-dir", type=str, default="output", help="Output directory to write files to (default: %(default)s)"
     )
+    parser.add_argument(
+        "--enable-debug-db",
+        action="store_true",
+        default=None,
+        help="Enables the calculation and writing of a network debug database to output directory",
+    )
+
     parser.add_argument("--config", type=str, help="Location of vela configuration file")
 
     parser.add_argument("--verbose-graph", action="store_true", help="Verbose graph rewriter")
@@ -319,9 +336,7 @@
 
     model_reader_options = model_reader.ModelReaderOptions()
 
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    nng = process(args.network, arch, model_reader_options, compiler_options, scheduler_options)
+    nng = process(args.network, args.enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options)
 
     if args.show_subgraph_io_summary:
         print_subgraph_io_summary(nng)
diff --git a/setup.py b/setup.py
index 07ab2d1..cc30636 100644
--- a/setup.py
+++ b/setup.py
@@ -56,7 +56,7 @@
     keywords=["ethos-u", "vela compiler", "tflite", "npu"],
     packages=find_namespace_packages(include=["ethosu.*"]),
     python_requires="~=3.6",  # We support only 3.6+
-    install_requires=["flatbuffers==1.11.0", "numpy>=1.16.6"],
+    install_requires=["flatbuffers==1.11.0", "numpy>=1.16.6", "lxml>=4.6.1"],
     entry_points={"console_scripts": ["vela = ethosu.vela.vela:main"]},
     ext_modules=[mlw_module],
     setup_requires=["setuptools_scm"],