MLBEDSW-3550 Only use simple scaling when bitexact with TFLite

For 8 bit arithmetic we cannot guarantee reproducibility in the general
case since precision differs, affecting rounding near half integers.
It should be safe when the ratio between output and input scales has
its 12 LSBs all set to 0, however.
For 16 bit arithmetic it should be sufficient to adjust the input and
output scalings with a factor of 2 to get the same rounding.

Signed-off-by: Henrik G Olsson <henrik.olsson@arm.com>
Change-Id: I809c0042615d16c5488d61f0c7d88e1a1315e6eb
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index fb705b9..3b552e0 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -718,19 +718,39 @@
                 ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
             emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
         else:  # Add/Sub
+            bitdepth = npu_op.ifm.data_type.size_in_bits()
+            use_advanced_scaling = False
             if None in (input_scale, input2_scale, output_scale):
                 opa_scale = opb_scale = ofm_scale = 1
                 opa_shift = shift = 0
                 if npu_op.rescale is not None:
                     ofm_scale, shift = npu_op.rescale
+            elif input_scale == input2_scale and bitdepth == 16:
+                opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
+                    input_scale, input2_scale, output_scale
+                )
+                # align the double rounding with that of advanced scaling
+                opa_scale /= 2
+                opb_scale /= 2
+                shift -= 1
+                opa_shift = 0  # Unused for this case
             elif input_scale == input2_scale:
                 opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
                     input_scale, input2_scale, output_scale
                 )
                 opa_shift = 0  # Unused for this case
+                # For 8 bit we can't guarantee double rounding with simplified scaling will always be
+                # the same as with advanced scaling due to different shifts. When the ofm scale fulfils
+                # the following we know that double rounding will have no effect for advanced scaling
+                # no matter the input, so we can safely use simplified scaling with double rounding disabled.
+                use_advanced_scaling = int(ofm_scale) & 0xFFF != 0
+                if not use_advanced_scaling:
+                    npu_op.rounding_mode = NpuRoundingMode.NATURAL
             else:
-                # Use advanced implementation only when input scales differ
-                bitdepth = npu_op.ifm.data_type.size_in_bits()
+                use_advanced_scaling = True
+            if use_advanced_scaling:
+                # Use advanced implementation only when input/output scales differ,
+                # or when we can't guarantee the absence of rounding errors
                 (opa_scale, opa_shift, ofm_scale, shift, op_to_scale,) = scaling.advanced_elementwise_add_sub_scale(
                     input_scale, input2_scale, output_scale, bitdepth
                 )