MLBEDSW-2528: MLCE-219: Custom operator pass through

 - Fixed custom operator pass through
 - Added error printing functions for operators and tensors
 - Minor cleanup of custom exception handling

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Idf295df1e4c544381dc480244d880c32fb285e38
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index efe64d5..2c93fbc 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -15,6 +15,10 @@
 # limitations under the License.
 # Description:
 # Defines custom exceptions.
+import sys
+
+from .operation import Operation
+from .tensor import Tensor
 
 
 class VelaError(Exception):
@@ -31,7 +35,7 @@
     """Raised when reading the input file results in errors"""
 
     def __init__(self, file_name, msg):
-        self.data = "Error reading {}: {}".format(file_name, msg)
+        self.data = "Error reading input file {}: {}".format(file_name, msg)
 
 
 class UnsupportedFeatureError(VelaError):
@@ -45,4 +49,75 @@
     """Raised when an incorrect command line option is used"""
 
     def __init__(self, option, option_value, msg):
-        self.data = "Incorrect argument: {} {}: {}".format(option, option_value, msg)
+        self.data = "Incorrect argument to CLI option: {} {}: {}".format(option, option_value, msg)
+
+
+def OperatorError(op, msg):
+    """Print a report for an invalid operator (plus its input/output tensor names) and exit with status 1"""
+
+    assert isinstance(op, Operation)  # programming-error guard only; asserts are skipped when Python runs with -O
+
+    if op.op_index is None:  # no operator index recorded, so identify the operator by its name instead
+        data = "Invalid {} (name = {}) operator in the internal representation.".format(op.type, op.name)
+    else:
+        data = "Invalid {} (op_index = {}) operator in the input network.".format(op.type, op.op_index)
+
+    data += " {}\n".format(msg)
+
+    data += "   Input tensors:\n"
+    for idx, tens in enumerate(op.inputs):
+        if isinstance(tens, Tensor):
+            tens_name = tens.name
+        else:
+            tens_name = "Not a Tensor"  # entry is listed with a placeholder instead of failing on .name
+
+        data += "      {} = {}\n".format(idx, tens_name)
+
+    data += "   Output tensors:\n"
+    for idx, tens in enumerate(op.outputs):
+        if isinstance(tens, Tensor):
+            tens_name = tens.name
+        else:
+            tens_name = "Not a Tensor"  # entry is listed with a placeholder instead of failing on .name
+
+        data += "      {} = {}\n".format(idx, tens_name)
+
+    data = data[:-1]  # remove last newline
+
+    print("Error: {}".format(data))
+    sys.exit(1)  # raises SystemExit, so control never returns to the caller
+
+
+def TensorError(tens, msg):
+    """Print a report for an invalid tensor (plus its driving/consuming operators) and exit with status 1"""
+
+    assert isinstance(tens, Tensor)  # programming-error guard only; asserts are skipped when Python runs with -O
+
+    data = "Invalid {} tensor. {}\n".format(tens.name, msg)
+
+    data += "   Driving operators:\n"
+    for idx, op in enumerate(tens.ops):
+        if isinstance(op, Operation):
+            op_type = op.type
+            op_id = op.op_index
+        else:
+            op_type = "Not an Operation"
+            op_id = ""  # nothing to show between the parentheses of the report line
+
+        data += "      {} = {} ({})\n".format(idx, op_type, op_id)
+
+    data += "   Consuming operators:\n"
+    for idx, op in enumerate(tens.consumer_list):
+        if isinstance(op, Operation):
+            op_type = op.type
+            op_id = op.op_index
+        else:
+            op_type = "Not an Operation"
+            op_id = ""  # nothing to show between the parentheses of the report line
+
+        data += "      {} = {} ({})\n".format(idx, op_type, op_id)
+
+    data = data[:-1]  # remove last newline
+
+    print("Error: {}".format(data))
+    sys.exit(1)  # raises SystemExit, so control never returns to the caller
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 72ab8cf..c4f2bae 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -17,8 +17,10 @@
 # Mark purpose and select formats for Tensors. Also compresses the weights.
 from . import rewrite_graph
 from . import weight_compressor
+from .errors import OperatorError
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
+from .tflite_mapping import custom_prefix
 
 
 def purpose_from_list(lst):
@@ -268,18 +270,33 @@
             if ops is None or op.type in ops:
                 if ops is None:
                     print(
-                        "warning: don't know how to mark up purpose for",
+                        "Warning: Don't know how to mark up purpose for",
                         op.type,
                         op.inputs,
                         "triggering all feature map fallback",
                     )
+
                 for idx, tens in enumerate(op.inputs):
                     purpose = input_purpose(op, idx)
                     mark_tensor_helper(tens, purpose)
+
                 if op.type == "Reshape":
                     # Reshape's input and output point to same data
                     op.outputs[0].mem_area = op.inputs[0].mem_area
+
+                if op.type.startswith(custom_prefix) and op.attrs.get("custom_type", "") == "ExistingNpuOp":
+                    scratch_tensor = None
+
+                    if len(op.inputs) >= 3:
+                        scratch_tensor = op.inputs[2]  # should be existing scratch tensor
+                        if scratch_tensor.name.endswith("_scratch"):
+                            scratch_tensor.purpose = TensorPurpose.Scratch
+
+                    if scratch_tensor is None:
+                        raise OperatorError(op, "Scratch tensor not found.")
+
                 break
+
         return op
 
     for sg in nng.subgraphs:
@@ -316,6 +333,8 @@
             fmt = arch.default_feature_map_format
         elif tens.purpose == TensorPurpose.Weights:
             fmt = arch.default_weight_format
+        elif tens.purpose == TensorPurpose.Scratch:
+            fmt = arch.default_feature_map_format
         elif tens.purpose == TensorPurpose.Unknown:
             fmt = TensorFormat.Unknown
         else:
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index 6deb253..0f79f9b 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -17,7 +17,6 @@
 # Dispatcher for reading a neural network model.
 from . import tflite_reader
 from .errors import InputFileError
-from .errors import VelaError
 
 
 class ModelReaderOptions:
@@ -32,17 +31,12 @@
 
 def read_model(fname, options, feed_dict={}, output_node_names=[], initialisation_nodes=[]):
     if fname.endswith(".tflite"):
-        try:
-            return tflite_reader.read_tflite(
-                fname,
-                options.batch_size,
-                feed_dict=feed_dict,
-                output_node_names=output_node_names,
-                initialisation_nodes=initialisation_nodes,
-            )
-        except VelaError as e:
-            raise e
-        except Exception as e:
-            raise InputFileError(fname, str(e))
+        return tflite_reader.read_tflite(
+            fname,
+            options.batch_size,
+            feed_dict=feed_dict,
+            output_node_names=output_node_names,
+            initialisation_nodes=initialisation_nodes,
+        )
     else:
-        raise InputFileError(fname, "Unknown input file format. Only .tflite files are supported")
+        raise InputFileError(fname, "Unsupported file extension. Only .tflite files are supported")
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 08dc0d3..18d38f3 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -141,7 +141,7 @@
             for op in ps.ops:
                 if op.type == "NpuOp":
                     callee = op.attrs["subgraph"]
-                    op.attrs["custom_options"] = {"type": op.type}
+                    op.attrs["custom_type"] = op.type
 
                     sz = 0
                     for tens in [callee.scratch_tensor, callee.flash_tensor, callee.command_stream_tensor]:
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 51311ef..448d838 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -31,7 +31,7 @@
     """Class representing a Neural Network operation. Has a name, a type,
 input and output tensors, as well as an attribute dictionary."""
 
-    __slots__ = "type", "name", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"
+    __slots__ = "type", "name", "op_index", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"
 
     def __init__(self, op_type, name):
         self.type = op_type
@@ -42,6 +42,7 @@
         self.flops = 0
         self.run_on_npu = True
         self.scheduled_pass = None
+        self.op_index = None  # input network operator index
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -51,6 +52,7 @@
         res.outputs = list(self.outputs)
         res.flops = self.flops
         res.scheduled_pass = self.scheduled_pass
+        res.op_index = None  # not relevant as not part of input network
 
         return res
 
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 426a710..42d9526 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -226,7 +226,6 @@
         "weight_compressed_offsets",
         "element_size_bytes",
         "block_traversal",
-        "offset",
         "cpu_tensor",
         "npu_tensor",
         "equivalence_id",
diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py
index ee9a51e..23e7e90 100644
--- a/ethosu/vela/test/test_model_reader.py
+++ b/ethosu/vela/test/test_model_reader.py
@@ -26,15 +26,7 @@
         model_reader.read_model("no_tflite_file.txt", model_reader.ModelReaderOptions())
 
 
-def test_read_model_corrupt_contents(tmpdir):
-    # Tests read_model with a corrupt .tflite file
-    fname = tmpdir.join("corrupt.tflite")
-    fname.write("abcde1234")
-    with pytest.raises(InputFileError):
-        model_reader.read_model(fname.strpath, model_reader.ModelReaderOptions())
-
-
 def test_read_model_file_not_found(tmpdir):
     # Tests read_model with a .tflite file that does not exist
-    with pytest.raises(InputFileError):
+    with pytest.raises(FileNotFoundError):
         model_reader.read_model("non_existing.tflite", model_reader.ModelReaderOptions())
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index d077768..7952168 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -328,7 +328,6 @@
         self.module = globals()[self.name]
         self.cls = getattr(self.module, self.name)
         self.builtin_opt_type = builtin_options_inv_map[self.cls]
-        self.custom_opt_format = 0
         self.members = []
         for mem in members:
             deserialize = identity
@@ -347,11 +346,12 @@
             camelcase_mem = underscore_to_camel_case(mem)
             self.members.append((underscore_mem, camelcase_mem, deserialize, serialize, is_vector))
 
-    def deserialize(self, builtin_data, custom_data):
+    def deserialize(self, op_data):
+        builtin_options = op_data.BuiltinOptions()
         attrs = {}
-        if builtin_data:
+        if builtin_options:
             tfattrs = self.cls()
-            tfattrs.Init(builtin_data.Bytes, builtin_data.Pos)
+            tfattrs.Init(builtin_options.Bytes, builtin_options.Pos)
             for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members:
                 fun = camelcase_mem
                 if is_vector:
@@ -376,26 +376,35 @@
 
 
 class CustomOptionsSerializer:
+    CUSTOM_OPTIONS_NPU_OP = [0x01, 0x04, 0x01]  # NpuOp=1, FlexbufferFormat.UINT8=4, byte length=1
+    CUSTOM_OPTIONS_FORMAT_DEFAULT = 0
+
     def __init__(self):
-        self.builtin_opt_type = 0
         self.custom_opt_format = 0
 
-    def deserialize(self, builtin_data, custom_data):
+    def deserialize(self, op_data):
         attrs = {}
-        attrs["custom_options"] = custom_data
+        custom_options = op_data.CustomOptionsAsNumpy()
+        attrs["custom_options"] = custom_options
+        attrs["custom_options_format"] = op_data.CustomOptionsFormat()
+
+        if np.array_equal(custom_options, self.CUSTOM_OPTIONS_NPU_OP):
+            attrs["custom_type"] = "ExistingNpuOp"
+
         return attrs
 
     def serialize(self, builder, attrs):
-
-        custom_opts = attrs.get("custom_options", [])
-        custom_data = []
+        custom_type = attrs.get("custom_type", "")
+        self.custom_opt_format = attrs.get("custom_options_format", self.CUSTOM_OPTIONS_FORMAT_DEFAULT)
 
         # Set NPU op custom options for the TensorFlow Lite custom operator
-        if custom_opts["type"] == "NpuOp":
-            custom_data = [0x01, 0x04, 0x01]  # NpuOp=1, FlexbufferFormat.UINT8=4, byte length=1
+        if custom_type == "NpuOp":
+            custom_options = self.CUSTOM_OPTIONS_NPU_OP
+        else:
+            custom_options = attrs.get("custom_options", [])
 
-        custom_data_bytes = struct.pack("<{0}B".format(len(custom_data)), *custom_data)
-        custom_offset = write_byte_vector(builder, custom_data_bytes)
+        custom_options_bytes = struct.pack("<{0}B".format(len(custom_options)), *custom_options)
+        custom_offset = write_byte_vector(builder, custom_options_bytes)
 
         return None, custom_offset
 
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 5667aff..9d312e5 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -20,6 +20,7 @@
 import numpy as np
 
 from .errors import InputFileError
+from .errors import TensorError
 from .nn_graph import Graph
 from .nn_graph import Subgraph
 from .operation import Operation
@@ -69,14 +70,16 @@
             self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))
 
         for idx in range(subgraph.OperatorsLength()):
-            self.parse_operator(subgraph.Operators(idx))
+            self.parse_operator(idx, subgraph.Operators(idx))
 
-        self.outputs = [self.tensors[idx] for idx in subgraph.OutputsAsNumpy()]
-        self.inputs = [self.tensors[idx] for idx in subgraph.InputsAsNumpy()]
+        self.outputs = self.get_tensors_from_indices_remove_duplicates(subgraph.OutputsAsNumpy(), "output")
+        self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
 
         # Fix up tensors without operations. Generate either Placeholder or Constant ops
         for tens in self.inputs:
-            assert not tens.ops
+            if tens.ops != []:
+                TensorError(tens, "This subgraph input tensor has unexpected driving operators.")
+
             op = Operation("Placeholder", tens.name)
             op.outputs = [tens]
             tens.ops = [op]
@@ -87,6 +90,21 @@
                 op.outputs = [tens]
                 tens.ops = [op]
 
+    def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
+        tensors = []  # result: entries of self.tensors selected by indices, first-seen order, no repeats
+        for idx in indices:
+            tensor = self.tensors[idx]
+            if tensor not in tensors:
+                tensors.append(tensor)
+            else:  # repeated tensor: warn and drop it rather than fail
+                print(
+                    "Warning: Subgraph {0} tensor ({1}) with idx = {2} already seen. Removing the duplicate.".format(
+                        warning_str, tensor, idx  # warning_str names the list being read ("input"/"output")
+                    )
+                )
+
+        return tensors
+
     def parse_tensor(self, tens_data):
         np_shape = tens_data.ShapeAsNumpy()
         shape = list(np_shape) if type(np_shape) is np.ndarray else []
@@ -121,7 +139,7 @@
                 tens.values = tens.quantization.dequantize(tens.quant_values)
         return tens
 
-    def parse_operator(self, op_data):
+    def parse_operator(self, op_index, op_data):
         op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()]
         inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()]
         outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()]
@@ -129,6 +147,7 @@
         if len(outputs):
             name = outputs[0].name
         op = Operation(op_type, name)
+        op.op_index = op_index
         op.inputs = inputs
         op.outputs = outputs
         for out in op.outputs:
@@ -143,7 +162,7 @@
             inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
 
         if opt_serializer is not None:
-            op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy())
+            op.attrs = opt_serializer.deserialize(op_data)
 
             if "stride_w" in op.attrs:
                 op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1)
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 675b698..8db3e5b 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -133,10 +133,9 @@
             builder.PrependUOffsetTRelative(e)
         return builder.EndVector(len(v))
 
-    def assign_buffers_to_tensors(self, tensors):
-        scratch_tensors = [tens for tens in tensors if tens.purpose == TensorPurpose.Scratch]
-        if len(scratch_tensors) > 0:
-            scratch_tensor_mem_area = scratch_tensors[0].mem_area
+    def assign_buffers_to_tensors(self, tensors, scratch_tensor):
+        if scratch_tensor is not None:
+            scratch_tensor_mem_area = scratch_tensor.mem_area
         else:
             scratch_tensor_mem_area = None  # all tensors are initialised to MemArea.Unknown
 
@@ -150,7 +149,7 @@
                 buffer_map[tens] = buf_idx
                 buf_idx += 1
 
-        # Initialize buffers_to_write to a length equal to numer of buffers so
+        # Initialize buffers_to_write to a length equal to number of buffers so
         # they can be appended at the correct index during tensor serialization
         self.buffers_to_write = [None] * (buf_idx)
 
@@ -176,7 +175,7 @@
                 assert code == "NpuOp"  # Currently only support serialising NPU operators as a custom op
                 custom_code_offset = builder.CreateString("ethos-u")
 
-            self.operator_code_map[code] = (idx, tf_code, opt_serializer)
+        self.operator_code_map[code] = (idx, tf_code, opt_serializer)
 
         OperatorCode.OperatorCodeStart(builder)
         OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -311,19 +310,29 @@
 
         all_tensors = [tens for nm, idx, tens in sorted((tens.name, idx, tens) for idx, tens in enumerate(tensor_set))]
 
+        scratch_tensors = [tens for tens in all_tensors if tens.purpose == TensorPurpose.Scratch]
+
+        if len(scratch_tensors) == 0:
+            scratch_tensor = None
+        else:
+            assert len(scratch_tensors) == 1, "Multiple scratch tensors"
+            scratch_tensor = scratch_tensors[0]
+
         self.tensor_map = {tens: idx for idx, tens in enumerate(all_tensors)}
-        self.buffer_map = self.assign_buffers_to_tensors(all_tensors)
+        self.buffer_map = self.assign_buffers_to_tensors(all_tensors, scratch_tensor)
 
         tensors_offset = self.write_offset_vector([self.serialise_tensor(tens) for tens in all_tensors])
 
-        # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro
-        scratch_tensor_idx = [v for k, v in self.tensor_map.items() if k.name.endswith("scratch")]
-
         # Make sure the input_tensors haven't been modified
         assert all(inp in sg.original_inputs for inp in sg.input_tensors)
-        inputs_offset = self.write_int_vector(
-            [self.tensor_map[tens] for tens in sg.original_inputs] + scratch_tensor_idx
-        )
+        inputs = [self.tensor_map[tens] for tens in sg.original_inputs]
+
+        # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro
+        scratch_tensor_idx = self.tensor_map.get(scratch_tensor, None)
+        if scratch_tensor_idx is not None and scratch_tensor_idx not in inputs:
+            inputs.append(scratch_tensor_idx)
+
+        inputs_offset = self.write_int_vector(inputs)
         outputs_offset = self.write_int_vector([self.tensor_map[tens] for tens in sg.output_tensors])
 
         operators_offset = self.write_offset_vector([self.serialise_operator(op) for op in all_ops])
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index c5f4ce1..77220a9 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -381,15 +381,6 @@
 
 
 def update_pass_weight_and_scale_tensors(nng, arch):
-    def find_npu_usage_of_tensor(tens):
-        # TODO: This function is identical to the one in mark_tensors.py. A common version should be used.
-        for op in tens.consumers():
-            if op.type == "DMA":
-                return find_npu_usage_of_tensor(op.outputs[0])
-            if "npu_block_type" in op.attrs:
-                return op.attrs["npu_block_type"]
-            return NpuBlockType.Default
-
     for sg in nng.subgraphs:
         for ps in sg.passes:
             tens = ps.weight_tensor