TOSA: Added support for ADD, SUB and MUL

Added support for ADD, SUB and MUL

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I52acdc126b16e2cf4096bcf7a77023ea7d204998
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index c5d0646..67d1cd9 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -86,6 +86,7 @@
 # Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
 elementwise_op_map = {
     Op.Mul: NpuElementWiseOp.MUL,
+    Op.RescaleMul: NpuElementWiseOp.MUL,
     Op.Add: NpuElementWiseOp.ADD,
     Op.RescaleAdd: NpuElementWiseOp.ADD,
     Op.Sub: NpuElementWiseOp.SUB,
@@ -460,7 +461,7 @@
         output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
     if op.type == Op.LeakyRelu:
         output_scale = op.attrs["alpha"]
-    if op.type == Op.RescaleAdd:
+    if op.type in (Op.RescaleAdd, Op.RescaleMul):
         assert op.rescale is not None, f"{op.type} must have rescale"
         npu_op.rescale = op.rescale
     if op.type in (Op.Add, Op.Mul, Op.Sub):
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 80be228..681f498 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -244,6 +244,7 @@
     ReluN = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
     Rescale = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
     RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
+    RescaleMul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
     ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
     ResizeNearestNeighbor = OperatorInfo()
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 6ee0005..d74f4d2 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -729,7 +729,9 @@
             output_scale = 1 / 0x3000
 
         if npu_op.sub_op_type == NpuElementWiseOp.MUL:
-            if None in (input_scale, input2_scale, output_scale):
+            if npu_op.rescale:
+                ofm_scale, shift = npu_op.rescale
+            elif None in (input_scale, input2_scale, output_scale):
                 ofm_scale = 1
                 shift = 0
             else:
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 5d0dd33..377f455 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -166,7 +166,7 @@
     "RescaleAttribute",
     ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"),
 )
-mul_attrs = AttrSerializer("MulAttribute", ("shift"))
+mul_attrs = AttrSerializer("MulAttribute", ("shift",))
 ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
 condif_attrs = AttrSerializer("CondIfAttribute", (("then_branch"), ("else_branch")))  # TODO these are references
 while_attrs = AttrSerializer("WhileLoopAttribute", (("cond_branch"), ("body_branch")))  # TODO these are references
@@ -195,7 +195,6 @@
     TosaOp.LOGICAL_XOR,
     TosaOp.MAXIMUM,
     TosaOp.MINIMUM,
-    TosaOp.MUL,
     TosaOp.POW,
     TosaOp.TABLE,
     TosaOp.ABS,
@@ -275,7 +274,7 @@
     # TODO TosaOp.LOGICAL_XOR
     # TODO TosaOp.MAXIMUM
     # TODO TosaOp.MINIMUM
-    # TODO TosaOp.MUL
+    TosaOp.MUL: (Op.Mul, mul_attrs, None, TOSA_IFM_IFM2_INDICES),
     # TODO TosaOp.POW
     TosaOp.SUB: (Op.Sub, None, None, TOSA_IFM_IFM2_INDICES),
     # TODO TosaOp.TABLE
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 268d43c..2925ab4 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -179,6 +179,11 @@
                     # TODO CONV3D more to be done....
                     print("Unsupported kernel dimensions: ", len(kernel))
                     assert False
+            if "shift" in op.attrs and op.type == Op.Mul:
+                shift = op.attrs["shift"]
+                if shift != 0:
+                    op.type = Op.RescaleMul
+                    op.rescale = [1, shift]
             if op.type.is_depthwise_conv2d_op():
                 op.attrs["depth_multiplier"] = op.weights.shape[3]
 
@@ -213,7 +218,6 @@
         # Initialize quantization parameters
         tens.quantization = QuantizationParameters()
 
-        tens.quantization.scale_f32 = 1.0
         if dtype == DataType.uint8:
             tens.quantization.quant_min = 0
             tens.quantization.quant_max = (1 << dtype.bits) - 1
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 90d5468..d7a1ebc 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -40,12 +40,16 @@
     mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
     memory_only_ops = set((Op.Reshape, Op.Transpose,))
 
+    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
+
     type_conversion_ops = set((Op.Rescale,))
     relu_ops = set((Op.Clamp, Op.ReluN,))
     activation_ops = relu_ops
 
     npu_post_ops = activation_ops
-    supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops
+    supported_operators = (
+        mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub
+    )
 
     # Supported data types
     # TODO will differ compared to TensorFlow Lite, currently set to the same
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 6c9fbce..9448749 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -85,7 +85,7 @@
     )
 
     output_tfl_filename = output_basename + "_vela.tflite"
-    if input_name.endswith(".tflite") or input_name.endswith(".tosa"):
+    if input_name.endswith(".tflite"):
         tflite_writer.write_tflite(nng, output_tfl_filename)
     if input_name.endswith(".tosa"):
         rawdata_writer.write_rawdata_output(nng, arch, output_basename)
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 10a1a6d..6881703 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -314,7 +314,7 @@
     # No cache hit, need to perform the encoding
     if do_weights:
         assert weight_tens.quantization is not None
-        assert weight_tens.quantization.scale_f32 is not None
+        assert weight_tens.quantization.scale_f32 is not None or op.explicit_scaling
         assert weight_tens.quantization.zero_point is not None
 
         # Early zero-point correction