MLBEDSW-3424: Expose API through separate file

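All external APIs are now exposed by api.py. The public functions take the new NpuAccelerator enum instead of the internal architecture_features.Accelerator enum.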

Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Change-Id: I33f480e424692ac30e9c7d791f583199f31164a7
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index 0799ab1..f64a38f 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 # Description:
-# Contains data types used in the external API for code generation
+# Contains external APIs
 from enum import auto
 from enum import Enum
 from typing import List
@@ -23,11 +23,26 @@
 from typing import Optional
 from typing import Tuple
 
+import numpy
+
 API_version_major = 1
 API_version_minor = 0
 api_version = f"{API_version_major}.{API_version_minor}"
 
 
+class NpuAccelerator(Enum):
+    """
+    Supported Ethos-U accelerator configurations
+    """
+
+    Ethos_U55_32 = auto()
+    Ethos_U55_64 = auto()
+    Ethos_U55_128 = auto()
+    Ethos_U55_256 = auto()
+    Ethos_U65_256 = auto()
+    Ethos_U65_512 = auto()
+
+
 class NpuElementWiseOp(Enum):
     """
     Elementwise operation
@@ -381,3 +396,63 @@
     """
     version = (API_version_major << 16) | (API_version_minor & 0xFFFF)
     return version
+
+
+def npu_encode_weights(
+    accelerator: NpuAccelerator,
+    weights_volume: numpy.ndarray,
+    dilation_xy: Tuple[int, int],
+    ifm_bitdepth: int,
+    ofm_block_depth: int,
+    is_depthwise: bool,
+    block_traversal: NpuBlockTraversal,
+) -> bytearray:
+    """
+    Public facing API to use the Ethos-U weight encoding.
+
+    :param accelerator: NpuAccelerator enum to pick the correct accelerator
+    :param weights_volume: 4-dimensional numpy.ndarray in OHWI layout
+    :param dilation_xy: a two element tuple of dilation attributes in x,y dimension
+    :param ifm_bitdepth: the bitdepth of input feature map
+    :param ofm_block_depth: the depth of blocks for processing
+    :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
+    :param block_traversal: indicates how these weights are traversed on sub-kernel basis
+    :return: a bytearray of compressed weights
+    """
+    # Imported lazily: architecture_features imports from this module
+    from .architecture_features import Accelerator
+    from . import weight_compressor
+
+    acc = Accelerator.from_npu_accelerator(accelerator)
+    return weight_compressor.encode_weights(
+        acc, weights_volume, dilation_xy, ifm_bitdepth, ofm_block_depth, is_depthwise, block_traversal
+    )
+
+
+def npu_encode_bias(bias: numpy.int64, scale: int, shift: int) -> bytearray:
+    """
+    Public facing API to pack bias and scale values as required by the hardware
+
+    :param bias: 64-bit signed number that includes 40-bit signed bias
+    :param scale: 32-bit scale value
+    :param shift: 6-bit shift value
+    :return: a 10-byte (80-bit) bytearray packed as [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)]
+    """
+    from . import weight_compressor
+
+    return weight_compressor.encode_bias(bias, scale, shift)
+
+
+def npu_generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: NpuAccelerator) -> List[int]:
+    """
+    Public facing API for generating an Ethos-U register command stream.
+    Calculates dependencies between commands and inserts wait operations if needed.
+
+    :param npu_op_list: List[NpuOperation] list of high level NPU operations
+    :param accelerator: NpuAccelerator enum to pick the correct accelerator
+    :return: register commands, as a list of 32-bit integers
+    """
+    # Imported lazily: register_command_stream_generator imports from this module
+    from . import register_command_stream_generator
+
+    return register_command_stream_generator.generate_register_command_stream(npu_op_list, accelerator)
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 7b6c3be..18846cf 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -21,6 +21,7 @@
 
 import numpy as np
 
+from .api import NpuAccelerator
 from .errors import CliOptionError
 from .errors import ConfigOptionError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
@@ -131,6 +132,28 @@
     def member_list(cls):
         return [e.value for e in cls]
 
+    @classmethod
+    def from_npu_accelerator(cls, npu_accelerator: NpuAccelerator) -> "Accelerator":
+        """
+        Converts the given public API object to Accelerator (used internally)
+
+        :param npu_accelerator: the public NpuAccelerator value to convert
+        :return: the Accelerator member with the same name
+        :raises ValueError: if npu_accelerator is not one of the mapped NpuAccelerator members
+        """
+        accelerator_map = {
+            NpuAccelerator.Ethos_U55_32: cls.Ethos_U55_32,
+            NpuAccelerator.Ethos_U55_64: cls.Ethos_U55_64,
+            NpuAccelerator.Ethos_U55_128: cls.Ethos_U55_128,
+            NpuAccelerator.Ethos_U55_256: cls.Ethos_U55_256,
+            NpuAccelerator.Ethos_U65_256: cls.Ethos_U65_256,
+            NpuAccelerator.Ethos_U65_512: cls.Ethos_U65_512,
+        }
+        # Raise instead of assert: reachable from the public API, so the check must survive python -O
+        if npu_accelerator not in accelerator_map:
+            raise ValueError(f"Unsupported accelerator {npu_accelerator}")
+        return accelerator_map[npu_accelerator]
+
 
 @enum.unique
 class MemPort(enum.Enum):
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index e612c30..04f7072 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -28,6 +28,7 @@
 
 from . import numeric_util
 from . import scaling
+from .api import NpuAccelerator
 from .api import NpuActivation
 from .api import NpuActivationOp
 from .api import NpuAddressRange
@@ -1270,15 +1271,16 @@
         print("command stream length in words", len(sg.register_command_stream))
 
 
-def generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: Accelerator) -> List[int]:
+def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]:
     """
-    Public facing API for generating an Ethos-U register command stream.
+    Internal implementation of the public facing API for generating an Ethos-U register command stream.
     Calculates dependencies between commands and inserts wait operations if needed.
 
     :param npu_op_list: List[NpuOperation] list of high level NPU operations
-    :param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator
-    :return Ethos-U instructions, as a list of 32-bit integers
+    :param npu_accelerator: NpuAccelerator enum to pick the correct Ethos-U accelerator
+    :return: Ethos-U instructions, as a list of 32-bit integers
     """
+    accelerator = Accelerator.from_npu_accelerator(npu_accelerator)
     emit = CommandStreamEmitter()
     arch = ArchitectureFeatures(
         vela_config_files=None,
diff --git a/ethosu/vela/test/extapi/test_extapi_encode_bias.py b/ethosu/vela/test/extapi/test_extapi_encode_bias.py
index ffdd3b0..c0a4a9a 100644
--- a/ethosu/vela/test/extapi/test_extapi_encode_bias.py
+++ b/ethosu/vela/test/extapi/test_extapi_encode_bias.py
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Description:
-# Contains unit tests for encode_biases API for an external consumer
+# Contains unit tests for the npu_encode_bias API as used by an external consumer
 import random
 
 import numpy as np
 
-from ethosu.vela.weight_compressor import encode_bias
+from ethosu.vela.api import npu_encode_bias
 
 
 def test_encode_bias():
@@ -34,6 +34,6 @@
         bias = np.int64(random.randint(bias_lower_limit, bias_upper_limit))
         scale = int(random.randint(scale_lower_limit, scale_upper_limit))
         shift = int(random.randint(shift_lower_limit, shift_upper_limit))
-        biases_enc = encode_bias(bias, scale, shift)
+        biases_enc = npu_encode_bias(bias, scale, shift)
         assert isinstance(biases_enc, bytearray)
         assert len(biases_enc) == 10
diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
index 854d14c..6367cb3 100644
--- a/ethosu/vela/test/extapi/test_extapi_encode_weights.py
+++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
@@ -14,25 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Description:
-# Contains unit tests for encode_weights API for an external consumer
+# Contains unit tests for the npu_encode_weights API as used by an external consumer
 import numpy as np
 import pytest
 
-from ethosu.vela import weight_compressor
+from ethosu.vela.api import npu_encode_weights
+from ethosu.vela.api import NpuAccelerator
 from ethosu.vela.api import NpuBlockTraversal
-from ethosu.vela.architecture_features import Accelerator
 
 
 @pytest.mark.parametrize(
-    "arch",
-    [
-        Accelerator.Ethos_U55_32,
-        Accelerator.Ethos_U55_64,
-        Accelerator.Ethos_U55_128,
-        Accelerator.Ethos_U55_256,
-        Accelerator.Ethos_U65_256,
-        Accelerator.Ethos_U65_512,
-    ],
+    "arch", list(NpuAccelerator),
 )
 @pytest.mark.parametrize("dilation_x", [1, 2])
 @pytest.mark.parametrize("dilation_y", [1, 2])
@@ -56,7 +48,7 @@
     block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST if depth_control == 3 else NpuBlockTraversal.DEPTH_FIRST
     dilation_xy = (dilation_x, dilation_y)
 
-    encoded_stream = weight_compressor.encode_weights(
+    encoded_stream = npu_encode_weights(
         accelerator=arch,
         weights_volume=weights_ohwi,
         dilation_xy=dilation_xy,
diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
index 49b24b2..86ef804 100644
--- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py
+++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 #
 # Description:
-# Contains unit tests for generate_register_command_stream API for an external consumer
+# Contains unit tests for the npu_generate_register_command_stream API as used by an external consumer
+from ethosu.vela.api import npu_generate_register_command_stream
+from ethosu.vela.api import NpuAccelerator
 from ethosu.vela.api import NpuActivation
 from ethosu.vela.api import NpuActivationOp
 from ethosu.vela.api import NpuAddressRange
@@ -35,11 +37,9 @@
 from ethosu.vela.api import NpuQuantization
 from ethosu.vela.api import NpuShape3D
 from ethosu.vela.api import NpuTileBox
-from ethosu.vela.architecture_features import Accelerator
 from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0
 from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd1
 from ethosu.vela.register_command_stream_generator import CmdMode
-from ethosu.vela.register_command_stream_generator import generate_register_command_stream
 from ethosu.vela.register_command_stream_generator import get_address_ranges
 
 
@@ -109,7 +109,7 @@
     # In this example we assume that the weights were compressed with ofm depth 16;
     # let vela choose suitable block width and height by setting these to -1
     op.block_config = NpuShape3D(height=-1, width=-1, depth=16)
-    cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128)
+    cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128)
     check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1)
     check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 512)
     check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE1, 0)
@@ -203,7 +203,7 @@
 def test_fully_connected():
     """Tests command stream generation for a fully connected operation"""
     op = create_fully_connected_op()
-    cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128)
+    cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128)
     check_cmd0(cmds, cmd0.NPU_OP_CONV, 0)
     assert len(cmds) > 20
 
@@ -223,7 +223,7 @@
     op.weights = [weights_dest]
     op.biases = [NpuAddressRange(region=0, address=0, length=80)]
     op.block_config = NpuShape3D(height=-1, width=-1, depth=8)
-    cmds = generate_register_command_stream([dma_op, op], Accelerator.Ethos_U55_128)
+    cmds = npu_generate_register_command_stream([dma_op, op], NpuAccelerator.Ethos_U55_128)
     check_cmd0(cmds, cmd0.NPU_SET_DMA0_SRC_REGION, 0)
     check_cmd1(cmds, cmd1.NPU_SET_DMA0_SRC, 0x40)
     check_cmd0(cmds, cmd0.NPU_SET_DMA0_DST_REGION, 1)
@@ -248,7 +248,7 @@
     op.activation = NpuActivation(NpuActivationOp.NONE_OR_RELU)
     op.activation.min = 0  # RELU
     # Do not set a block config, let vela choose one
-    cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_32)
+    cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_32)
     check_cmd1(cmds, cmd1.NPU_SET_OFM_SCALE, 1073741824, 30)
     check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1)
     check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 32)
@@ -337,7 +337,7 @@
 def test_avg_pool():
     """Tests average pool operation"""
     op = create_avg_pool_op()
-    cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128)
+    cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128)
     check_cmd0(cmds, cmd0.NPU_OP_POOL, 1)
     assert len(cmds) > 10
 
@@ -346,7 +346,7 @@
     """Tests code generation with 2 operations"""
     op1 = create_fully_connected_op()
     op2 = create_avg_pool_op()
-    cmds = generate_register_command_stream([op1, op2], Accelerator.Ethos_U55_64)
+    cmds = npu_generate_register_command_stream([op1, op2], NpuAccelerator.Ethos_U55_64)
     check_cmd0(cmds, cmd0.NPU_OP_POOL, 1)
     check_cmd0(cmds, cmd0.NPU_OP_CONV, 0)
     check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0)
@@ -363,7 +363,7 @@
     assert dest is not None
     src = NpuAddressRange(0, 0x24000, dest.length)
     dma_op = NpuDmaOperation(src, dest)
-    cmds = generate_register_command_stream([dma_op, pool_op], Accelerator.Ethos_U55_64)
+    cmds = npu_generate_register_command_stream([dma_op, pool_op], NpuAccelerator.Ethos_U55_64)
     check_cmd0(cmds, cmd0.NPU_OP_DMA_START, 0)
     # A DMA WAIT should have been inserted
     check_cmd0(cmds, cmd0.NPU_OP_DMA_WAIT, 0)
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 0eab185..40ebcd0 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -17,6 +17,7 @@
-# Compresses and pads the weigths. It also calculates the scales and packs with the biases.
+# Compresses and pads the weights. It also calculates the scales and packs with the biases.
 import math
 from collections import namedtuple
+from typing import Tuple
 
 import numpy as np
 
@@ -50,14 +51,14 @@
 def encode_weights(
     accelerator: Accelerator,
     weights_volume: np.ndarray,
-    dilation_xy: tuple,
+    dilation_xy: Tuple[int, int],
     ifm_bitdepth: int,
     ofm_block_depth: int,
     is_depthwise: bool,
     block_traversal: NpuBlockTraversal,
 ):
     """
-    Public facing API to use the Ethos-U weight encoding.
+    Internal implementation of the public facing API to use weight encoding.
 
     :param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator
     :param weights_volume: numpy.ndarray in OHWI layout with a shape of four
@@ -65,10 +66,10 @@
     :param ifm_bitdepth: the bitdepth of input feature map
     :param ofm_block_depth: the depth of blocks for Ethos-U processing
     :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
-    :param block_traversal: indicates how these weights are traversed on sub-kernal basis
+    :param block_traversal: indicates how these weights are traversed on sub-kernel basis
     :return: a bytearray of compressed weights
     """
 
     # Check arg types
     assert isinstance(accelerator, Accelerator)
     assert isinstance(weights_volume, np.ndarray)
@@ -108,7 +109,7 @@
 
 def encode_bias(bias: np.int64, scale: int, shift: int):
     """
-    Public facing API to pack bias and scale values as required by the Ethos-U
+    Internal implementation of the public facing API to pack bias and scale values as required by the Ethos-U
 
     :param bias: 64bit signed number that includes 40bit signed bias
     :param scale: 32bit scale value