[EXTAPI] refactor weight compression to be used by an external consumer

*lint
*added unit tests
*added typecheck
*added docstring for the api
Change-Id: Ibd4bc40d4381ac40ad2ea3d500b26c4ec565ab07
Signed-off-by: Manupa Karunaratne <manupa.karunaratne@arm.com>
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 6460c52..43b3210 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -120,6 +120,19 @@
     Size = Accumulators + 1
 
 
+class Accelerator(enum.Enum):
+    Ethos_U55_32 = "ethos-u55-32"
+    Ethos_U55_64 = "ethos-u55-64"
+    Ethos_U55_128 = "ethos-u55-128"
+    Ethos_U55_256 = "ethos-u55-256"
+    Yoda_256 = "yoda-256"
+    Yoda_512 = "yoda-512"
+
+    @classmethod
+    def member_list(cls):
+        return [e.value for e in cls]
+
+
 class ArchitectureFeatures:
     """This class is a container for various parameters of the Ethos-U55 core
 and system configuration that can be tuned, either by command line
@@ -136,15 +149,28 @@
         "ArchitectureConfig", "macs cores ofm_ublock ifm_ublock shram_banks shram_granules elem_units"
     )
     accelerator_configs = {
-        "yoda-512": ArchitectureConfig(256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8),
-        "yoda-256": ArchitectureConfig(256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8),
-        "ethos-u55-256": ArchitectureConfig(256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8),
-        "ethos-u55-128": ArchitectureConfig(128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4),
-        "ethos-u55-64": ArchitectureConfig(64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2),
-        "ethos-u55-32": ArchitectureConfig(32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1),
+        Accelerator.Yoda_512: ArchitectureConfig(
+            256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+        ),
+        Accelerator.Yoda_256: ArchitectureConfig(
+            256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+        ),
+        Accelerator.Ethos_U55_256: ArchitectureConfig(
+            256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+        ),
+        Accelerator.Ethos_U55_128: ArchitectureConfig(
+            128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4
+        ),
+        Accelerator.Ethos_U55_64: ArchitectureConfig(
+            64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2
+        ),
+        Accelerator.Ethos_U55_32: ArchitectureConfig(
+            32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1
+        ),
     }
 
     OFMSplitDepth = 16
+    SubKernelMax = Block(8, 8, 65536)
 
     def __init__(
         self,
@@ -159,20 +185,18 @@
     ):
         accelerator_config = accelerator_config.lower()
         self.vela_config = vela_config
-        self.accelerator_config = accelerator_config
-        if self.accelerator_config not in ArchitectureFeatures.accelerator_configs:
+        if accelerator_config not in Accelerator.member_list():
             raise OptionError("--accelerator-config", self.accelerator_config, "Unknown accelerator configuration")
+        self.accelerator_config = Accelerator(accelerator_config)
         accel_config = ArchitectureFeatures.accelerator_configs[self.accelerator_config]
         self.config = accel_config
 
         self.system_config = system_config
-
-        self.is_yoda_system = "yoda-" in self.accelerator_config
+        self.is_yoda_system = self.accelerator_config in (Accelerator.Yoda_256, Accelerator.Yoda_512)
 
         self.ncores = accel_config.cores
         self.ofm_ublock = accel_config.ofm_ublock
         self.ifm_ublock = accel_config.ifm_ublock
-        self.subkernel_max = Block(8, 8, 65536)
         self.ofm_block_max = Block(64, 32, 128)
         self.override_block_config = override_block_config
         self.block_config_limit = block_config_limit
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index 2c93fbc..59740aa 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 # Description:
 # Defines custom exceptions.
+import inspect
 import sys
 
 from .operation import Operation
@@ -121,3 +122,20 @@
 
     print("Error: {}".format(data))
     sys.exit(1)
+
+
+def typecheck(func):
+    def wrapper(*args, **kwargs):
+        fsig = inspect.signature(func)
+        args_zipped = zip(kwargs.values(), fsig.parameters.keys())
+        for actual, expected in args_zipped:
+            expected_type = fsig.parameters[expected].annotation
+            actual_type = type(actual)
+            if expected_type is inspect.Parameter.empty:
+                raise TypeError("Please provide type info for {}, hint = {}".format(expected, actual_type))
+            if expected_type is not actual_type:
+                raise TypeError("expected : {}, but got {}".format(expected_type, actual_type))
+        # Actual execution
+        return func(*args, **kwargs)
+
+    return wrapper
diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
new file mode 100644
index 0000000..47ca02b
--- /dev/null
+++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
@@ -0,0 +1,73 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Contains unit tests for encode_weights API for an external consumer
+import numpy as np
+import pytest
+
+from ethosu.vela import weight_compressor
+from ethosu.vela.architecture_features import Accelerator
+
+
+@pytest.mark.parametrize(
+    "arch",
+    [
+        Accelerator.Ethos_U55_32,
+        Accelerator.Ethos_U55_64,
+        Accelerator.Ethos_U55_128,
+        Accelerator.Ethos_U55_256,
+        Accelerator.Yoda_256,
+        Accelerator.Yoda_512,
+    ],
+)
+@pytest.mark.parametrize("dilation_x", [1, 2])
+@pytest.mark.parametrize("dilation_y", [1, 2])
+@pytest.mark.parametrize("ifm_bitdepth", [8, 16])
+@pytest.mark.parametrize("depth_control", [1, 2, 3])
+@pytest.mark.parametrize("weights_shape_and_block_depth", [((16, 16, 16, 16), 8), ((3, 3, 25, 16), 8)])
+def test_encode_weights(
+    arch, weights_shape_and_block_depth, dilation_x, dilation_y, ifm_bitdepth, depth_control,
+):
+    """
+    This unit test checks the interface of the API function but not the functionality.
+    Functional correctness is tested at a system level.
+    """
+
+    weights_shape = weights_shape_and_block_depth[0]
+    ofm_block_depth = weights_shape_and_block_depth[1]
+    val_max = np.iinfo(np.uint8).max
+    weights_hwio = np.random.randint(val_max, size=weights_shape, dtype=np.uint8)
+    weights_ohwi = np.transpose(weights_hwio, (3, 0, 1, 2))
+    is_depthwise = True if depth_control == 2 else False
+    is_partkernel = True if depth_control == 3 else False
+    dilation_xy = (dilation_x, dilation_y)
+
+    encoded_stream = weight_compressor.encode_weights(
+        accelerator=arch,
+        weights_volume=weights_ohwi,
+        dilation_xy=dilation_xy,
+        ifm_bitdepth=ifm_bitdepth,
+        ofm_block_depth=ofm_block_depth,
+        is_depthwise=is_depthwise,
+        is_partkernel=is_partkernel,
+    )
+    assert type(encoded_stream) == bytearray
+
+
+if __name__ == "__main__":
+    # two test candidates for debugging purposes
+    test_encode_weights(Accelerator.Ethos_U55_256, ((3, 3, 25, 16), 8), 1, 1, 8, 0)
+    test_encode_weights(Accelerator.Ethos_U55_256, ((16, 16, 16, 16), 8), 1, 1, 8, 0)
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 20bc525..1766750 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -170,7 +170,7 @@
         "--accelerator-config",
         type=str,
         default="ethos-u55-256",
-        choices=list(architecture_features.ArchitectureFeatures.accelerator_configs.keys()),
+        choices=list(architecture_features.Accelerator.member_list()),
         help="Accelerator configuration to use (default: %(default)s)",
     )
     parser.add_argument(
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 8ebd751..687a080 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -20,7 +20,10 @@
 
 import numpy as np
 
+from .architecture_features import Accelerator
+from .architecture_features import ArchitectureFeatures
 from .data_type import DataType
+from .errors import typecheck
 from .errors import UnsupportedFeatureError
 from .nn_graph import SchedulingStrategy
 from .numeric_util import round_up
@@ -42,6 +45,55 @@
 )
 
 
+@typecheck
+def encode_weights(
+    accelerator: Accelerator,
+    weights_volume: np.ndarray,
+    dilation_xy: tuple,
+    ifm_bitdepth: int,
+    ofm_block_depth: int,
+    is_depthwise: bool,
+    is_partkernel: bool,
+):
+    """
+    Public facing API to use the ethosu weight encoding.
+
+    :param accelerator: architecture_features.Accelerator enum to pick the correct ethosu accelerator
+    :param weights_volume: numpy.ndarray in OHWI layout with a shape of four
+    :param dilation_xy: a two element tuple of dilation attributes in x,y dimension
+    :param ifm_bitdepth: the bitdepth of input feature map
+    :param ofm_block_depth: the depth of blocks for ethosu processing
+    :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
+    :param is_partkernel: a boolean indicating these weights are traversed on sub-kernal basis
+    :return: a bytearray of compressed weights
+    """
+
+    # Checks for weight layout
+    assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"
+
+    # It cannot be both partkernel and depthwise
+    assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive"
+
+    # Check valid values for dilation
+    assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0])
+    assert dilation_xy[1] in (1, 2), "encode_weights :: dilation y should be 1 or 2 not {}".format(dilation_xy[1])
+
+    ifm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ifm_ublock
+    ofm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ofm_ublock
+    raw_stream = generate_brick(
+        ifm_ublock=ifm_ublock,
+        ofm_ublock=ofm_ublock,
+        brick_weights=weights_volume,
+        ofm_block_depth=ofm_block_depth,
+        is_depthwise=is_depthwise,
+        is_partkernel=is_partkernel,
+        ifm_bitdepth=ifm_bitdepth,
+        dilation=dilation_xy,
+    )
+    encoded_stream = encode(raw_stream)
+    return encoded_stream
+
+
 def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
     # Note: for an ofm block only its depth is used in weight compression.
     # And block depth > ofm depth gives same result as block depth == ofm depth
@@ -93,13 +145,12 @@
     return compressed
 
 
-def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth, dilation):
-    is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
-    is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
-    decomp_h = arch.subkernel_max.height // dilation[0]
-    decomp_w = arch.subkernel_max.width // dilation[1]
-    ofm_ublock = arch.ofm_ublock
-    ifm_ublock = arch.ifm_ublock
+def generate_brick(
+    ifm_ublock, ofm_ublock, brick_weights, ofm_block_depth, is_depthwise, is_partkernel, ifm_bitdepth, dilation
+):
+
+    decomp_h = ArchitectureFeatures.SubKernelMax.height // dilation[0]
+    decomp_w = ArchitectureFeatures.SubKernelMax.width // dilation[1]
     # Expect weights formatted OHWI
     ofm_depth = brick_weights.shape[-4]
     ifm_depth = brick_weights.shape[-1]
@@ -245,6 +296,9 @@
         else:
             tens.block_traversal = TensorBlockTraversal.DepthFirst
 
+    is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
+    is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst
+
     if tens.consumer_list[0].type == "Conv2DBackpropInputSwitchedBias":
         # Transpose Convoluion, reverse weights in H and W axes
         weights = np.flip(weights, axis=(0, 1))
@@ -262,7 +316,6 @@
 
         substream_offsets = [0]
         encoded_stream = []
-        raw_size = 0
 
         # For each core, deinterleave weights from the larger volume
         # and generate separate compressed streams.
@@ -270,15 +323,17 @@
             core_weights = core_deinterleave(brick_weights, core, arch.ncores)
 
             block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores
+            encoded_substream = []
             if block_depth != 0:
-                raw_stream = generate_brick(
-                    arch, core_weights, block_depth, tens.block_traversal, ifm_bitdepth, dilation
+                encoded_substream = encode_weights(
+                    accelerator=arch.accelerator_config,
+                    weights_volume=core_weights,
+                    dilation_xy=dilation,
+                    ifm_bitdepth=ifm_bitdepth,
+                    ofm_block_depth=block_depth,
+                    is_depthwise=is_depthwise,
+                    is_partkernel=is_partkernel,
                 )
-            else:
-                raw_stream = []
-
-            raw_size += len(raw_stream)
-            encoded_substream = encode(raw_stream)
             encoded_stream.extend(encoded_substream)
             substream_offsets.append(len(encoded_stream))