MLBEDSW-2067: added custom exceptions

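Added custom exceptions (VelaError and its subclasses InputFileError, UnsupportedFeatureError and OptionError) to report different types of input errors, replacing assertions and generic exceptions.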

Also performed minor formatting changes using flake8/black.

Change-Id: Ie5b05361507d5e569aff045757aec0a4a755ae98
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index c712588..1bf9d95 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -21,6 +21,7 @@
 
 import numpy as np
 
+from .errors import OptionError
 from .numeric_util import round_up
 from .numeric_util import round_up_divide
 from .operation import NpuBlockType
@@ -158,7 +159,7 @@
         self.vela_config = vela_config
         self.accelerator_config = accelerator_config
         if self.accelerator_config not in ArchitectureFeatures.accelerator_configs:
-            raise Exception("Unknown accelerator configuration " + self.accelerator_config)
+            raise OptionError("--accelerator-config", self.accelerator_config, "Unknown accelerator configuration")
         accel_config = ArchitectureFeatures.accelerator_configs[self.accelerator_config]
         self.config = accel_config
 
@@ -564,7 +565,8 @@
         else:
             section_key = "SysConfig." + self.system_config
             if section_key not in self.vela_config:
-                raise Exception("Unknown system configuration " + self.system_config)
+                # The --system-config value must name a "SysConfig.<name>" section of the Vela config
+                raise OptionError("--system-config", self.system_config, "Unknown system configuration")
 
         try:
             self.npu_clock = float(self.__sys_config("npu_freq", "500e6"))
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
new file mode 100644
index 0000000..efe64d5
--- /dev/null
+++ b/ethosu/vela/errors.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Defines the custom exceptions raised by Vela: a common VelaError base class and
+# subclasses for input file, unsupported feature and command line option errors.
+
+
+class VelaError(Exception):
+    """Base class for vela exceptions.
+
+    The error message is kept in ``data``; subclasses format it themselves.
+    ``str()`` returns ``repr(data)``, so a string message is shown in quotes.
+    """
+
+    def __init__(self, data):
+        self.data = data
+
+    def __str__(self):
+        return repr(self.data)
+
+
+class InputFileError(VelaError):
+    """Raised when reading the input file results in errors.
+
+    The message has the form "Error reading <file_name>: <msg>".
+    """
+
+    def __init__(self, file_name, msg):
+        self.data = "Error reading {}: {}".format(file_name, msg)
+
+
+class UnsupportedFeatureError(VelaError):
+    """Raised when the input file uses non-supported features that cannot be handled.
+
+    ``data`` describes the feature and is prefixed with a fixed explanation.
+    """
+
+    def __init__(self, data):
+        self.data = "The input file uses a feature that is currently not supported: {}".format(data)
+
+
+class OptionError(VelaError):
+    """Raised when an incorrect command line option is used.
+
+    ``option`` is the option name (e.g. "--system-config"), ``option_value`` the rejected
+    value and ``msg`` the reason it was rejected.
+    """
+
+    def __init__(self, option, option_value, msg):
+        self.data = "Incorrect argument: {} {}: {}".format(option, option_value, msg)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 351716e..72bb486 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -22,6 +22,7 @@
 
 from . import rewrite_graph
 from .data_type import DataType
+from .errors import UnsupportedFeatureError
 from .operation import NpuBlockType
 from .operation import Operation
 from .tensor import Tensor
@@ -124,7 +125,7 @@
         top_pad = 0
         bottom_pad = 0
     else:
-        assert 0, "Unknown padding"
+        raise UnsupportedFeatureError("Unknown padding {}".format(str(padding_type)))
     padding = (top_pad, left_pad, bottom_pad, right_pad)
     skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
     return padding, skirt
@@ -214,7 +215,7 @@
         if op.type == "StridedSlice":
             new_axis_mask = op.attrs["new_axis_mask"]
             shrink_axis_mask = op.attrs["shrink_axis_mask"]
-            ellipsis_mask =  op.attrs["ellipsis_mask"]
+            ellipsis_mask = op.attrs["ellipsis_mask"]
 
             if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
                 # Not supported, will be put on CPU
@@ -243,7 +244,7 @@
                     n += 1
                     new_axis_mask &= new_axis_mask - 1
                     axis = int(math.log2(prev_mask - new_axis_mask))
-                    reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1):]
+                    reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
                     new_axis_mask >>= 1
 
                 assert len(tens.shape) == (len(op.inputs[0].shape) + n)
@@ -288,7 +289,7 @@
             kernel_size = op.attrs["ksizes"][1:3]
             input_shape = op.inputs[0].shape
         else:
-            assert 0, "Unknown operation that uses padding"
+            raise UnsupportedFeatureError("Unknown operation that uses padding: {}".format(op.type))
 
         padding, skirt = calc_padding_and_skirt(op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape)
         op.attrs["explicit_padding"] = padding
@@ -312,7 +313,9 @@
     )
 )
 depthwise_op = set(("DepthwiseConv2dNative", "DepthwiseConv2dBiasAct",))
-pool_op = set(("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear",))
+pool_op = set(
+    ("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear",)
+)
 elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum", "LeakyRelu", "Abs"))
 binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
 activation_ops = set(("Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh"))
@@ -373,13 +376,11 @@
                 weight_tensor.quant_values.shape
             )
         else:
-            print(
-                "Error: Unsupported DepthwiseConv2d with depth_multiplier = {0}, "
-                "ifm channels = {1}, ofm channels = {2}".format(
+            raise UnsupportedFeatureError(
+                "Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format(
                     op.attrs["depth_multiplier"], ifm_tensor.shape[3], ofm_tensor.shape[3]
                 )
             )
-            assert False
     return op
 
 
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index d1cdc9b..6deb253 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 # Description:
 # Dispatcher for reading a neural network model.
+from . import tflite_reader
+from .errors import InputFileError
+from .errors import VelaError
 
 
 class ModelReaderOptions:
@@ -29,15 +32,25 @@
 
 def read_model(fname, options, feed_dict={}, output_node_names=[], initialisation_nodes=[]):
+    """Read the model in fname and return the graph produced by tflite_reader.read_tflite.
+
+    Only .tflite files are supported. Raises InputFileError if fname has another extension or
+    cannot be read; VelaError subclasses raised while reading are propagated unchanged.
+    """
     if fname.endswith(".tflite"):
-        from . import tflite_reader
-
-        nng = tflite_reader.read_tflite(
-            fname,
-            options.batch_size,
-            feed_dict=feed_dict,
-            output_node_names=output_node_names,
-            initialisation_nodes=initialisation_nodes,
-        )
+        try:
+            return tflite_reader.read_tflite(
+                fname,
+                options.batch_size,
+                feed_dict=feed_dict,
+                output_node_names=output_node_names,
+                initialisation_nodes=initialisation_nodes,
+            )
+        except VelaError:
+            # Already a descriptive Vela error (e.g. UnsupportedFeatureError); re-raise it unchanged
+            raise
+        except Exception as e:
+            # Any other failure is reported as an input file error; chain the original exception
+            # so its traceback is kept as the direct cause for debugging
+            raise InputFileError(fname, str(e)) from e
     else:
-        assert 0, "Unknown model format"
-    return nng
+        raise InputFileError(fname, "Unknown input file format. Only .tflite files are supported")
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 335b863..2bfe594 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -22,6 +22,7 @@
 from .architecture_features import Kernel
 from .architecture_features import SharedBufferArea
 from .architecture_features import SHRAMElements
+from .errors import OptionError
 from .operation import NpuBlockType
 
 
@@ -163,8 +164,14 @@
 
     if arch.override_block_config:
         config = alloc.try_block(arch.override_block_config)
-        assert config, "Block config override cannot be used"
+        if not config:
+            # try_block returned no config, so the forced block config cannot be used
+            raise OptionError(
+                "--force-block-config",
+                str(arch.override_block_config),
+                "This forced block config value cannot be used; it is not compatible",
+            )
         return [config]
 
     # Constrain the search space if the OFM is smaller than the max block size
     # - Add other block search constraints here if required
diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py
new file mode 100644
index 0000000..ee9a51e
--- /dev/null
+++ b/ethosu/vela/test/test_model_reader.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Unit tests for model_reader.
+import pytest
+from ethosu.vela import model_reader
+from ethosu.vela.errors import InputFileError
+
+
+def test_read_model_incorrect_extension():
+    # Tests read_model with a file name that does not end with .tflite
+    with pytest.raises(InputFileError):
+        model_reader.read_model("no_tflite_file.txt", model_reader.ModelReaderOptions())
+
+
+def test_read_model_corrupt_contents(tmpdir):
+    # Tests read_model with a corrupt .tflite file
+    fname = tmpdir.join("corrupt.tflite")
+    fname.write("abcde1234")
+    with pytest.raises(InputFileError):
+        model_reader.read_model(fname.strpath, model_reader.ModelReaderOptions())
+
+
+def test_read_model_file_not_found(tmpdir):
+    # Tests read_model with a .tflite file that does not exist (in an empty temporary directory)
+    with pytest.raises(InputFileError):
+        model_reader.read_model(tmpdir.join("non_existing.tflite").strpath, model_reader.ModelReaderOptions())
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 7e158aa..850690f 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 
+from .errors import UnsupportedFeatureError
 from .nn_graph import Graph
 from .nn_graph import Subgraph
 from .operation import Operation
@@ -147,18 +148,18 @@
             if op_type.startswith("ResizeBilinear"):
                 upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
                 out_shape = op.outputs[0].shape[1:3]
-                if not op.attrs['align_corners'] and out_shape == upscaled_shape:
+                if not op.attrs["align_corners"] and out_shape == upscaled_shape:
                     # this means the output is supposed to be a x2 upscale,
                     # so we need to do SAME padding
-                    op.attrs.update({'padding': b'SAME'})
-                elif (op.attrs['align_corners']
-                    and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]):
+                    op.attrs.update({"padding": b"SAME"})
+                elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
                     # here we can just run the avg pool without padding and
                     # produce a (M * 2 - 1, N * 2 - 1) sized output
-                    op.attrs.update({'padding': b'VALID'})
+                    op.attrs.update({"padding": b"VALID"})
                 else:
-                    assert False, "Only 2x upscaling is supported"
-                op.attrs.update({'filter_width': 2, 'filter_height': 2, 'stride_w': 1, 'stride_h': 1,})
+                    raise UnsupportedFeatureError("ResizeBilinear: Only 2x upscaling is supported")
+                # Both supported cases are run as a 2x2, stride 1 average pool
+                op.attrs.update({"filter_width": 2, "filter_height": 2, "stride_w": 1, "stride_h": 1})
 
             if "stride_w" in op.attrs:
                 op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1)
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 49f8c26..bd5409c 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -31,6 +31,7 @@
 from . import stats_writer
 from . import tflite_writer
 from ._version import __version__
+from .errors import InputFileError
 from .nn_graph import PassPlacement
 from .nn_graph import TensorAllocator
 from .scheduler import ParetoMetric
@@ -44,8 +45,8 @@
     nng = model_reader.read_model(fname, model_reader_options)
 
     if not nng:
-        print("reading of", fname, "failed")
-        assert False
+        # read_model raises on read/parse errors; this only guards against a falsy result
+        raise InputFileError(fname, "input file could not be read")
 
     if compiler_options.verbose_operators:
         nng.print_operators()
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 04d684e..a81b1fb 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -23,6 +23,7 @@
 
 from .architecture_features import Block
 from .data_type import DataType
+from .errors import UnsupportedFeatureError
 from .nn_graph import SchedulingStrategy
 from .numeric_util import round_up
 from .operation import NpuBlockType
@@ -292,14 +293,18 @@
                 for weight_scale in weight_scales
             ]
         else:
-            assert False, str(ifm_dtype) + " not implemented"
+            raise UnsupportedFeatureError(
+                "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
+            )
     else:
         if ifm_dtype == DataType.uint8:
             scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
         elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
             scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
         else:
-            assert False, str(ifm_dtype) + " not implemented"
+            raise UnsupportedFeatureError(  # only uint8, int8 and int16 IFMs are handled above
+                "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
+            )
 
     # quantise all of the weight scales into (scale_factor, shift)
     if ifm_dtype == DataType.int16: