TOSA: Support for AVGPOOL, MAXPOOL and CONV2D

Added support for:
- AVGPOOL and CONV2D with TFLite correspondence
- MAXPOOL
- replacing RESCALE ops with AVGPOOL

There is no support for breaking down tensors that exceed the
size supported by the NPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I1d2aa50ac30a26283b3e6f1fe88cba1544b7c189
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index 69c6040..d516b8d 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -25,6 +25,7 @@
 
 import numpy
 
+
 API_VERSION_MAJOR = 1
 API_VERSION_MINOR = 1
 API_VERSION = f"{API_VERSION_MAJOR}.{API_VERSION_MINOR}"
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 5e676f1..570c724 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 # Description:
 # Common functions and definitions used during the graph optimization.
+from typing import Tuple
+
 from .data_type import DataType
 from .debug_database import DebugDatabase
 from .errors import VelaError
@@ -132,6 +134,21 @@
     tens.needs_linear_format = False
 
 
+def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
+    """
+    Translates the explicit padding given by a PAD operation into hardware padding with equivalent
+    results: the leading padding is returned unchanged, the trailing padding may be reduced.
+    """
+    # Remainder (modulo stride) that the trailing padding has to match
+    wanted_remainder = (needed_total_padding(input_size, stride, filter_size) - pad_before) % stride
+
+    # The bottom/right padding might need downward adjustment depending on stride/input size
+    trailing_pad = pad_after
+    while trailing_pad > 0 and trailing_pad % stride != wanted_remainder:
+        trailing_pad -= 1
+    return pad_before, trailing_pad
+
+
 def needed_total_padding(input_size, stride, filter_size):
     out_size = (input_size + stride - 1) // stride
     needed_input = (out_size - 1) * stride + filter_size
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index f8c9de3..c5d0646 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -204,6 +204,8 @@
         return True
     if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
         return False
+    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
+        return False
     fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
     forced_ofm_quantization = ps.primary_op.forced_output_quantization
     use_0 = (
@@ -413,6 +415,10 @@
     set_common_op_fields(npu_op, cmd, arch)
     # Pooling specific info
     npu_op.rescale = op.rescale
+    if op.explicit_scaling:
+        # Note: reuse of rescale for explicit scaling to not expose this in the external API
+        assert npu_op.rescale is None
+        npu_op.rescale = op.explicit_scaling
     return npu_op
 
 
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index c51a6b5..4a4fd33 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -39,7 +39,7 @@
     op.attrs["strides"] = [1, 1, 1, 1]
     op.attrs["ksize"] = [1, 1, 1, 1]
     op.attrs["skirt"] = [0, 0, 0, 0]
-    op.attrs["explicit_padding"] = [0, 0, 0, 0]
+    op.attrs["explicit_padding"] = [0, 0, 0, 0]  # [top, left, bottom, right]
     op.run_on_npu = True
     return op
 
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index d61e571..6ee0005 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -71,6 +71,7 @@
 from .numeric_util import quantise_float32
 from .numeric_util import round_away_zero
 from .numeric_util import round_up_to_int
+from .operation import ExplicitScaling
 from .operation import NpuBlockType
 from .range_set import MemoryAccessSet
 from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
@@ -676,11 +677,18 @@
         ofm_scale_f64 = np.double(ofm_quant.scale_f32)
         scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
     elif pool_op.rescale is not None:
-        # for ResizeBilinear operations with rescale
-        rescale = pool_op.rescale
-        rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
-        scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
-        scale = int(round_away_zero(scale * rescale))
+        if type(pool_op.rescale) == ExplicitScaling:
+            # Note: reuse of rescale for explicit scaling to not expose this in the external API
+            explicit_scaling = pool_op.rescale
+            assert explicit_scaling.per_channel is False
+            scale = explicit_scaling.multiplier[0]
+            shift = explicit_scaling.shift[0]
+        else:
+            # for ResizeBilinear operations with rescale
+            rescale = pool_op.rescale
+            rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
+            scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
+            scale = int(round_away_zero(scale * rescale))
     else:
         # In case avg pool fused with concat or other memory operation, rescaling might be needed.
         # kernel height == kernel width == 1 is always true in this case
@@ -896,6 +904,9 @@
     use_global_scale = (
         npu_op.sub_op_type in (NpuPoolingOp.AVERAGE, NpuPoolingOp.REDUCE_SUM) and sum(npu_op.padding) == 0
     )
+    # Note: reuse of rescale for explicit scaling to not expose this in the external API
+    if npu_op.rescale is not None and type(npu_op.rescale) == ExplicitScaling:
+        use_global_scale = not npu_op.rescale.per_channel
     generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch, use_global_scale=use_global_scale)
     # Pooling op specific
     if use_global_scale:
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ff2f5a0..3f743e4 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -18,7 +18,6 @@
 # to do the traversal of the graph.
 import math
 import uuid
-from typing import Tuple
 
 import numpy as np
 
@@ -31,6 +30,7 @@
 from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import needed_total_padding
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
@@ -270,21 +270,6 @@
     return op
 
 
-def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
-    """
-    Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
-    that provides equivalent results.
-    """
-    total_padding = needed_total_padding(input_size, stride, filter_size)
-    # The top/left padding can be taken as is from the PAD
-    output_pad_before = pad_before
-    # The bottom/right padding might need downward adjustment depending on stride/input size
-    output_pad_after = pad_after
-    while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
-        output_pad_after -= 1
-    return output_pad_before, output_pad_after
-
-
 def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
     k_w, k_h = kernel.dilated_wh()
     s_x, s_y = kernel.stride
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index fe18ce3..44e0f8e 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,21 +19,38 @@
 from .api import NpuRoundingMode
 from .data_type import DataType
 from .debug_database import DebugDatabase
+from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import needed_total_padding
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
 from .operation import ExplicitScaling
 from .operation import NpuBlockType
 from .operation import Op
-from .operation import Padding
+from .operation_util import create_avgpool_nop
 
 
-def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
+def replace_rescale_with_avg_pool(rescale_op):
+    """Turn the given Rescale op, in place, into an AvgPool using create_avgpool_nop attrs; return it."""
+    assert rescale_op.type == Op.Rescale
+    nop_template = create_avgpool_nop(rescale_op.name + "_avgpool")
+    original_snapshot = rescale_op.clone()  # taken before mutation, handed to the debug database below
+    # Only attrs and type are reassigned here; the op object itself is reused
+    rescale_op.attrs = nop_template.attrs.copy()
+    rescale_op.type = Op.AvgPool
+    DebugDatabase.add_optimised(original_snapshot, rescale_op)
+
+    return rescale_op
+
+
+def calc_skirt(kernel, input_shape, explicit_padding):
     k_w, k_h = kernel.dilated_wh()
     s_x, s_y = kernel.stride
     ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
     xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
-    left_pad, right_pad, top_pad, bottom_pad = explicit_padding
+
+    top, left, bottom, right = explicit_padding
+    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
+    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
 
     padding = (top_pad, left_pad, bottom_pad, right_pad)
     skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
@@ -42,16 +59,14 @@
 
 def add_padding_fields(op, arch, nng):
     if op.run_on_npu:
-        if "padding" in op.attrs:
+        if "explicit_padding" in op.attrs:
             input_shape = op.ifm_shapes[0]
 
             if op.type == Op.Conv2DBackpropInputSwitchedBias:
                 # TODO not yet supported, but there will be need for separate handling
                 assert False
             else:
-                padding, skirt = calc_padding_and_skirt(
-                    Padding.EXPLICIT, op.kernel, input_shape, op.attrs.get("padding"),
-                )
+                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))
 
             op.attrs["explicit_padding"] = padding
             op.attrs["skirt"] = skirt
@@ -104,7 +119,6 @@
         prev_op = ifm.ops[0]
 
         # TODO currently not supported
-        assert prev_op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const)
         assert len(ifm.consumer_list) == 1
 
         input_zp = op.attrs["input_zp"]
@@ -126,27 +140,26 @@
             print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedidly ")
             assert False
         ifm.quantization.zero_point = input_zp
+        ofm.quantization.zero_point = output_zp
+        for s, m in zip(shift, multiplier):
+            # TODO these are the TOSA limitations
+            assert m >= 0
+            assert 2 <= s <= 62
+            # TODO these are the HW limitations
+            assert 0 <= s < (1 << 6)
+        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
 
-        if not scale32:
-            double_round = False
+        if double_round and scale32:
+            rounding_mode = NpuRoundingMode.TFL
+        else:
+            rounding_mode = NpuRoundingMode.NATURAL
 
         if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
             assert len(multiplier) == len(shift) == len(prev_op.bias.values)
 
             if ifm.dtype == DataType.int32 and per_channel:
-                for s, m in zip(shift, multiplier):
-                    # TODO these are the TOSA limitations
-                    assert m >= 0
-                    assert 2 <= s <= 62
-                    # TODO these are the HW limitations
-                    assert 0 <= s < (1 << 6)
-                prev_op.explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
-                ofm.quantization.zero_point = output_zp
-
-                if double_round:
-                    prev_op.rounding_mode = NpuRoundingMode.TFL
-                else:
-                    prev_op.rounding_mode = NpuRoundingMode.NATURAL
+                prev_op.explicit_scaling = explicit_scaling
+                prev_op.rounding_mode = rounding_mode
 
                 # Bypass op
                 prev_op.set_output_tensor(ofm)
@@ -155,13 +168,42 @@
             else:
                 print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
                 assert False
+        # TODO which are the cases we need to and can do standalone Rescale?
+        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
+        # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
+        # limited to these at the moment:
+        elif (
+            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
+            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
+            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
+        ):
+            # Create  NOP performing the RESCALE
+            avgpool_op = replace_rescale_with_avg_pool(op)
+            avgpool_op.rounding_mode = rounding_mode
 
+            if per_channel:
+                # TODO
+                avgpool_op.explicit_scaling = explicit_scaling
+                print("Warning, unsupported TOSA Rescale")
+                assert False
+            else:
+                avgpool_op.explicit_scaling = explicit_scaling
         else:
             print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
             assert False
     return op
 
 
+def fixup_quantization(op, arch, nng):  # graph rewrite callback: default missing zero points to 0
+    if op.ifm and op.ifm.quantization.zero_point is None:
+        op.ifm.quantization.zero_point = 0
+    if op.ifm2 and op.ifm2.quantization.zero_point is None:
+        op.ifm2.quantization.zero_point = 0  # set ifm2's own zero point, leaving ifm's value alone
+    if op.ofm and op.ofm.quantization.zero_point is None:
+        op.ofm.quantization.zero_point = 0
+    return op  # same op object, as expected by the rewrite pass that calls this
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
     return op
@@ -187,10 +229,14 @@
             nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
         )
 
-    # Post-processing step
+    # Post-processing step 1
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
             nng, sg, arch, [], [rewrite_activation, add_padding_fields],
         )
 
+    # Post-processing step 2
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)
+
     return nng
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index dfed035..eb31716 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -30,6 +30,7 @@
 from .reader_util import decode_str
 from .reader_util import fixup_tensors
 from .tensor import QuantizationParameters
+from .tensor import shape_num_elements
 from .tensor import Tensor
 from .tflite_mapping import DataType
 from .tosa.TosaGraph import TosaGraph as TG
@@ -135,6 +136,22 @@
         if attr_serializer is not None:
             op.attrs = attr_serializer.deserialize(op_data)
 
+            if "padding" in op.attrs:
+                padding = op.attrs["padding"]  # [top, bottom, left, right]
+                op.attrs["explicit_padding"] = (
+                    padding[0],
+                    padding[2],
+                    padding[1],
+                    padding[3],
+                )  # [top, left, bottom, right]
+            if "stride" in op.attrs:
+                stride = op.attrs["stride"]
+                if len(stride) == 2:
+                    op.attrs["strides"] = (1, stride[0], stride[1], 1)
+                else:
+                    # TODO CONV3D more to be done....
+                    print("Unsupported kernel dimensions: ", len(stride))
+                    assert False
             if "dilation" in op.attrs:
                 dilation = op.attrs["dilation"]
                 if len(dilation) == 2:
@@ -160,7 +177,7 @@
                 self.set_tensor_zp(op.ifm, quant_info["input_zp"])
             if "weight_zp" in quant_info:
                 self.set_tensor_zp(op.weights, quant_info["weight_zp"])
-            if "ouput_zp" in quant_info:
+            if "output_zp" in quant_info:
                 self.set_tensor_zp(op.ofm, quant_info["output_zp"])
             if "a_zp" in quant_info:
                 self.set_tensor_zp(op.ifm, quant_info["a_zp"])
@@ -194,7 +211,12 @@
             data_as_numpy = tens_data.DataAsNumpy()
             if tens_dtype in datatype_map_numpy:
                 np_dtype = datatype_map_numpy[tens_dtype]
-                tens.values = np.array(data_as_numpy.view(np_dtype).reshape(shape))
+
+                # TOSA pads the tensor data
+                shape_elements = shape_num_elements(shape)
+                values = np.array(data_as_numpy.view(np_dtype))
+                values = values[0:shape_elements]
+                tens.values = values.reshape(shape)
             else:
                 # int48 is only expected as an accumulated data/output format, int4 not supported
                 print(f"Error: unsupported/unexpected Tensor type {dtype}, with data")
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 51f80eb..3b0e6b3 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -29,7 +29,11 @@
     # Categorised lists of supported operators
     convolution_ops = set((Op.Conv2DBias,))
     convolution_like_ops = convolution_ops
-    mac_main_ops = convolution_like_ops
+    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
+    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
+    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
+
+    mac_main_ops = convolution_like_ops | pooling_ops
 
     type_conversion_ops = set((Op.Rescale,))
     relu_ops = set((Op.Clamp, Op.ReluN,))
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 7400b8e..9448749 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -87,7 +87,7 @@
     output_tfl_filename = output_basename + "_vela.tflite"
     if input_name.endswith(".tflite"):
         tflite_writer.write_tflite(nng, output_tfl_filename)
-    elif input_name.endswith(".tosa"):
+    if input_name.endswith(".tosa"):
         rawdata_writer.write_rawdata_output(nng, arch, output_basename)
 
     if enable_debug_db: