MLBEDSW-7151: MLCE: Difference in model output between x86 & aarch64

 - The issue is caused by undefined behaviour when casting a NumPy
float to a NumPy unsigned integer, which occurs in
create_const_tensor() (illustrated in the sketch below)
 - The fix is to make sure that the values are first converted to
Python floats
 - In addition, the values datatype argument has been removed from
create_const_tensor() to stop the tensor and values datatypes from
getting out of sync
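
A minimal sketch of the problem and the fix (illustrative only: the
input value below is hypothetical, and the exact output of the
undefined cast depends on the architecture and NumPy version):

    import numpy as np

    values = [np.float64(-1.0)]  # hypothetical out-of-range value

    # Before the fix: np.array() casts the NumPy float to the
    # unsigned integer with a C-level conversion, which is undefined
    # behaviour for negative values and so can differ between x86
    # and aarch64
    print(np.array(values, dtype=np.uint32))

    # After the fix: the values are converted to Python floats
    # first, which gives the desired, platform-independent result
    # when casting to unsigned integers
    print(np.array([float(v) for v in values], dtype=np.uint32))

Call sites now pass only the tensor datatype and the raw values; the
values are always stored using dtype.as_numpy_type(), e.g.:

    create_const_tensor("size", [2], DataType.int32, [8, 8])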

Change-Id: I134b9be8c941b361929a5ae7db8cb35f2e9728f2
Signed-off-by: Tim Hall <tim.hall@arm.com>
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index 829cef3..5d0320b 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -110,7 +110,7 @@
             BaseType.Complex: "c",
         }
         assert self.type in numpy_dtype_code, f"Failed to interpret {self} as a numpy dtype"
-        return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes()))
+        return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes())).type
 
     stem_name = {
         BaseType.UnsignedInt: ("uint%s", True),
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index d90c06b..2822feb 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -429,7 +429,7 @@
     quantization = QuantizationParameters(0.0, 255.0)
     quantization.scale_f32 = ifm.quantization.scale_f32
     quantization.zero_point = 0
-    tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
+    tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
     op.add_input_tensor(tens)
     op.ifm_shapes.append(Shape4D(tens.shape))  # TODO no shape?
 
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index fdf9d0f..d0ac970 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -88,8 +88,7 @@
     # address in constant memory, and unnecessary DMA operations can be avoided.
     sz = len(values)
     assert sz in (256, 512)
-    ntype = np.uint8 if dtype.size_in_bytes() == 1 else np.uint32
-    tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, ntype, TensorPurpose.LUT)
+    tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, TensorPurpose.LUT)
     tens.equivalence_id = create_equivalence_id(tuple(values))
     return tens
 
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index a92d0bb..575e1e6 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
@@ -270,9 +270,7 @@
             ifm2_shape=ifm_max_shape,
         )
         sub_op.set_activation_lut(
-            create_const_tensor(
-                f"{sub_op.name}_exp_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
-            )
+            create_const_tensor(f"{sub_op.name}_exp_lut", [1, 1, 1, 256], DataType.int32, exp_lut, TensorPurpose.LUT)
         )
         ifm_exp = add_op_get_ofm(sub_op)
         # Note: activation.min/max are non-quantized values
@@ -281,9 +279,7 @@
 
         # PASS 2 - SHR
         name = f"{self.op.name}_shr{pass_number}"
-        shift = create_const_tensor(
-            f"{name}_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
-        )
+        shift = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [12], quantization=no_scale_quant)
         shr_op = create_shr(name, ifm_exp, shift, no_scale_quant, activation)
         shr_op.rounding_mode = NpuRoundingMode.NATURAL
         rescaled_exp = add_op_get_ofm(shr_op)
@@ -304,7 +300,6 @@
             [1, 1, 1, 1],
             DataType.int32,
             [12 + 31 - 8],
-            np.int32,
             quantization=no_scale_quant,
         )
         right_shift = add_op_get_ofm(
@@ -318,7 +313,7 @@
         )
 
         # PASS 6 - Sub
-        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
+        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], quantization=no_scale_quant)
         headroom = add_op_get_ofm(
             create_sub(f"{self.op.name}_sub{pass_number}", headroom_plus_one, one, no_scale_quant, activation)
         )
@@ -330,7 +325,7 @@
 
         # PASS 8 - Sub
         shifted_one = create_const_tensor(
-            "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
+            "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], quantization=no_scale_quant
         )
         shifted_sum_minus_one = add_op_get_ofm(
             create_sub(f"{self.op.name}_sub{pass_number}", shifted_sum, shifted_one, no_scale_quant, activation)
@@ -349,7 +344,7 @@
 
         # PASS 10 - Add
         f0_one_const = create_const_tensor(
-            "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
+            "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], quantization=no_scale_quant
         )
         add_op = create_add(
             f"{self.op.name}_add{pass_number}",
@@ -363,7 +358,7 @@
 
         # PASS 11 - Multiply
         neg_32_over_17 = create_const_tensor(
-            "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], np.int32, quantization=one_scale_quant
+            "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], quantization=one_scale_quant
         )
         rescaled = add_op_get_ofm(
             create_mul(
@@ -377,7 +372,7 @@
 
         # PASS 12 - Add
         const_48_over_17 = create_const_tensor(
-            "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
+            "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], quantization=no_scale_quant
         )
         rescale_w_offset = add_op_get_ofm(
             create_add(
@@ -392,11 +387,9 @@
         # PASS 13 - 27
         nr_x = rescale_w_offset
         F2_one = create_const_tensor(
-            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
+            "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], quantization=no_scale_quant
         )
-        four = create_const_tensor(
-            "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
-        )
+        four = create_const_tensor("four_const", [1, 1, 1, 1], DataType.int32, [4], quantization=no_scale_quant)
         for _ in range(3):
             # PASS 13, 18, 23 - MUL
             half_denominator_times_x = add_op_get_ofm(
@@ -438,7 +431,7 @@
             )
 
         # PASS 28 - Multiply
-        two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
+        two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], quantization=no_scale_quant)
         scale_factor = add_op_get_ofm(
             create_mul(f"{self.op.name}_mul{pass_number}", nr_x, two, one_scale_quant, activation)
         )
@@ -502,20 +495,18 @@
         mul2_quant = ofm.quantization.clone()
         mul2_quant.scale_f32 = mul2_out_range
         scale = create_const_tensor(
-            f"{name}_scale_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=scale_quant
+            f"{name}_scale_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], quantization=scale_quant
         )
         mul2_ofm = add_op_get_ofm(create_mul(name, sub1_ofm, scale, mul2_quant))
 
         # PASS 3 - Add+LUT(exp)
         name = f"{self.op.name}_add{pass_number}"
         const_add = create_const_tensor(
-            f"{name}_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
+            f"{name}_const", [1, 1, 1, 1], DataType.int32, [32767], quantization=no_scale_quant
         )
         add_op = create_add(name, mul2_ofm, const_add, mul2_ofm.quantization.clone(), dtype=DataType.int16)
         add_op.set_activation_lut(
-            create_const_tensor(
-                f"{name}_exp_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
-            )
+            create_const_tensor(f"{name}_exp_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, TensorPurpose.LUT)
         )
         ifm_exp = add_op_get_ofm(add_op)
 
@@ -529,13 +520,11 @@
 
         # PASS 6 - Sub
         name = f"{self.op.name}_sub{pass_number}"
-        const_31 = create_const_tensor(
-            f"{name}_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
-        )
+        const_31 = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [31], quantization=no_scale_quant)
         reciprocal_right_shift = add_op_get_ofm(create_sub(name, const_31, headroom_plus_one, no_scale_quant))
 
         # PASS 7 - SHL
-        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
+        one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], quantization=no_scale_quant)
         constant_one = add_op_get_ofm(
             create_shl(f"{self.op.name}_shl{pass_number}", one, reciprocal_right_shift, no_scale_quant)
         )
@@ -552,15 +541,13 @@
 
         # PASS 10 - SHR
         name = f"{self.op.name}_shr{pass_number}"
-        shift = create_const_tensor(
-            f"{name}_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
-        )
+        shift = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [15], quantization=no_scale_quant)
         shifted_sum_minus_one_16 = add_op_get_ofm(create_shr(name, shifted_sum_minus_one, shift, no_scale_quant))
 
         # PASS 11 - Sub+LUT(one over one plus x)
         name = f"{self.op.name}_sub{pass_number}"
         sub11_const = create_const_tensor(
-            f"{name}_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
+            f"{name}_const", [1, 1, 1, 1], DataType.int32, [32768], quantization=no_scale_quant
         )
         sub11_op = create_sub(name, shifted_sum_minus_one_16, sub11_const, no_scale_quant, dtype=DataType.int16)
         sub11_op.set_activation_lut(
@@ -569,7 +556,6 @@
                 [1, 1, 1, 512],
                 DataType.int32,
                 self.ONE_OVER_ONE_PLUS_X_LUT,
-                np.uint32,
                 TensorPurpose.LUT,
             )
         )
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 899b1be..6a95bad 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -300,17 +300,31 @@
 def create_const_tensor(
     name: str,
     shape: Shape,
-    dtype: DataType,
-    values: np.ndarray,
-    value_dtype: np.dtype = None,
+    dtype: DataType,  # datatype of the tensor
+    values: Optional[Union[np.ndarray, list]],  # list-like data of some type, or scalar (skip mypy), or None
     purpose: TensorPurpose = TensorPurpose.Unknown,
-    quantization: QuantizationParameters = None,
+    quantization: Optional[QuantizationParameters] = None,
 ):
+    assert isinstance(dtype, DataType)
+
     # Tensor
     const_tensor = Tensor(shape, dtype, name + "_0")
     const_tensor.purpose = purpose
     const_tensor.quantization = quantization
-    const_tensor.values = np.array(values, dtype=value_dtype)
+
+    # if the tensor datatype does not match that of the values then np.array() will perform a cast operation. this can
+    # result in undefined behaviour if casting from a numpy float to a numpy unsigned integer. therefore, we need to
+    # avoid this undefined behaviour by converting the numpy floats to python floats as these give the desired behaviour
+    # when casting to unsigned integers
+    if (
+        values is not None
+        and shape != []  # values are not a scalar
+        and isinstance(values[0], np.floating)
+        and dtype.type == BaseType.Unsigned
+    ):
+        values = [float(v) for v in values]
+
+    const_tensor.values = np.array(values, dtype=dtype.as_numpy_type())
     # Operator
     const_op = Operation(Op.Const, name)
     const_op.set_output_tensor(const_tensor)
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 152669f..54dd70f 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -40,9 +40,9 @@
 def test_convert_batched_fc():
     """Tests shape conversion of batched fully connected"""
     ifm_shape = [4, 8]
-    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
+    ifm = create_const_tensor("test_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
     w_shape = [8, 4]
-    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
+    weights = create_const_tensor("weight_in", w_shape, DataType.uint8, np.zeros(w_shape))
     ofm = Tensor(ifm.shape, np.uint8, "test_out")
     op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
 
@@ -132,7 +132,8 @@
     qp = testutil.default_quant_params()
     in0 = Tensor(in_shape, in_dtype, "in")
     in0.quantization = qp
-    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+    shape = [] if padding == [] else list(np.shape(padding))
+    pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype)
     out = Tensor(out_shape, out_dtype, "out")
     out.quantization = qp.clone()
     op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
@@ -543,9 +544,7 @@
     Tests if the quant value at vela compile time is calculated correctly
     """
 
-    quant_ifm = create_const_tensor(
-        "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
-    )
+    quant_ifm = create_const_tensor("const_quant_ifm", values=np.array(127), shape=[], dtype=DataType.int8)
     quant_ifm.quantization = testutil.default_quant_params()
     quant_ifm.quantization.scale_f32 = 0.15748031
     quant_ifm.quantization.quant_min = -128
@@ -568,9 +567,7 @@
 
     assert op.ofm.values == 127
 
-    quant_ifm = create_const_tensor(
-        "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
-    )
+    quant_ifm = create_const_tensor("const_quant_ifm", values=np.array(127), shape=[], dtype=DataType.int8)
     quant_ifm.quantization = testutil.default_quant_params()
     quant_ifm.quantization.scale_f32 = 0.15748031
     quant_ifm.quantization.quant_min = -128
@@ -600,9 +597,7 @@
     when passing multiple values to quantize node
     """
 
-    quant_ifm = create_const_tensor(
-        "const_quant_ifm", values=np.array([127, 127]), value_dtype=np.int8, shape=[], dtype=DataType.int8
-    )
+    quant_ifm = create_const_tensor("const_quant_ifm", values=np.array([127, 127]), shape=[], dtype=DataType.int8)
     quant_ifm.quantization = testutil.default_quant_params()
     quant_ifm.quantization.scale_f32 = 0.15748031
     quant_ifm.quantization.quant_min = -128
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
index 9073270..712be7a 100644
--- a/ethosu/vela/test/test_lut.py
+++ b/ethosu/vela/test/test_lut.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -18,8 +18,6 @@
 # Unit tests for LUT support
 import random
 
-import numpy as np
-
 from ethosu.vela import lut
 from ethosu.vela import mark_tensors
 from ethosu.vela import pass_packing
@@ -37,9 +35,7 @@
 def set_256_lut(op, key, arch):
     random.seed(key)
     values = random.choices(range(256), k=256)
-    lut_tensor = create_const_tensor(
-        op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT
-    )
+    lut_tensor = create_const_tensor(op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, TensorPurpose.LUT)
     scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
     op.set_activation_lut(scratch_lut_tensor)
 
@@ -47,9 +43,7 @@
 def set_1K_lut(op, key, arch):
     random.seed(key)
     values = random.choices(range(256), k=256)
-    lut_tensor = create_const_tensor(
-        op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT
-    )
+    lut_tensor = create_const_tensor(op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, TensorPurpose.LUT)
     scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
     op.set_activation_lut(scratch_lut_tensor)
 
@@ -57,9 +51,7 @@
 def set_2K_lut(op, key, arch):
     random.seed(key)
     values = random.choices(range(512), k=512)
-    lut_tensor = create_const_tensor(
-        op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT
-    )
+    lut_tensor = create_const_tensor(op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, TensorPurpose.LUT)
     scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
     op.set_activation_lut(scratch_lut_tensor)
 
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index c242063..2e0936d 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -195,11 +195,11 @@
     # SplitV requires a maximum of one inferred shape (-1)
     qp = testutil.default_quant_params()
     op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
-    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
+    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], quantization=qp)
     op.add_input_tensor(sizes)
     assert not semantic_checker.is_operator_semantic_valid(op)
     op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
-    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
+    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], quantization=qp)
     op.add_input_tensor(sizes)
     assert semantic_checker.is_operator_semantic_valid(op)
 
@@ -278,7 +278,8 @@
     qp = testutil.default_quant_params()
     in0 = Tensor(in_shape, in_dtype, "in")
     in0.quantization = qp
-    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+    shape = [] if padding == [] else list(np.shape(padding))
+    pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype)
     out = Tensor(out_shape, out_dtype, "out")
     out.quantization = qp.clone()
     op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
@@ -449,9 +450,9 @@
     ofm = Tensor(output_shape, datatype, "out")
     ofm.quantization = testutil.default_quant_params()
     if type(axis) is list:
-        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis)
     elif type(axis) is int:
-        indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
+        indices = create_const_tensor("indices", [], DataType.int32, axis)
     op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
     return op
 
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index d091531..6a0b58e 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -303,55 +303,55 @@
     for resize_op in Op.op_set(Op.is_resize_op):
         # IFM W and H == 1
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 1, 1, 8], [1, 8, 8, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
         assert support.is_operator_supported(op)
 
         # IFM == OFM
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 8, 8, 8], [1, 8, 8, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
         assert support.is_operator_supported(op)
 
         # IFM x2 == OFM ; align_corners = False
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
         assert support.is_operator_supported(op)
 
         # IFM x4 == OFM ; align_corners = False
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 16, 16, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16]))
         assert support.is_operator_supported(op)
 
         # IFM x8 == OFM ; align_corners = False
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 32, 32, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32]))
         assert support.is_operator_supported(op)
 
         # IFM -1 x2 == OFM -1 ; align_corners = True
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 7, 7, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7]))
         op.attrs["align_corners"] = True
         assert support.is_operator_supported(op)
 
         # IFM -1 x4 == OFM -1 ; align_corners = True
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 13, 13, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13]))
         op.attrs["align_corners"] = True
         assert support.is_operator_supported(op)
 
         # IFM -1 x8 == OFM -1 ; align_corners = True
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 25, 25, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25]))
         op.attrs["align_corners"] = True
         assert support.is_operator_supported(op)
 
         # Invalid case - upscale size
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 17, 17, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17]))
         assert not support.is_operator_supported(op)
 
         # Invalid case - upscale size with align corners
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 15, 15, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15]))
         op.attrs["align_corners"] = True
         assert not support.is_operator_supported(op)
 
@@ -360,7 +360,7 @@
     for resize_op in Op.op_set(Op.is_resize_op):
         # Invalid case - size != ofm size
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7]))
         assert not support.is_operator_supported(op)
 
 
@@ -368,7 +368,7 @@
     for resize_op in Op.op_set(Op.is_resize_op):
         # Invalid case - both align corners and half-pixel centers
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
-        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
         op.attrs["align_corners"] = True
         op.attrs["half_pixel_centers"] = True
         assert not support.is_operator_supported(op)
@@ -395,7 +395,8 @@
     qp = testutil.default_quant_params()
     in0 = Tensor(in_shape, in_dtype, "in")
     in0.quantization = qp
-    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+    shape = [] if padding == [] else list(np.shape(padding))
+    pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype)
     out = Tensor(out_shape, out_dtype, "out")
     out.quantization = qp.clone()
     op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
@@ -587,9 +588,9 @@
     ofm = Tensor(output_shape, datatype, "out")
     ofm.quantization = testutil.default_quant_params()
     if type(axis) is list:
-        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis)
     elif type(axis) is int:
-        indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
+        indices = create_const_tensor("indices", [], DataType.int32, axis)
     op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
     return op
 
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index acf35fe..88fc874 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -53,21 +53,13 @@
     ofm_quant=default_quant_params(),
 ):
     # Creates elementwise operation with constant IFM/IFM2
-    if datatype.size_in_bytes() == 1:
-        np_type = np.uint8
-    elif datatype.size_in_bytes() == 2:
-        np_type = np.int16
-    else:
-        np_type = np.int32
     op = Operation(op_type, name)
     op.add_input_tensor(
-        create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type, quantization=ifm_quant)
+        create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), quantization=ifm_quant)
     )
     if ifm2_shape is not None:
         op.add_input_tensor(
-            create_const_tensor(
-                name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type, quantization=ifm2_quant
-            )
+            create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), quantization=ifm2_quant)
         )
     ofm = Tensor(ofm_shape, datatype, name + "_ofm")
     ofm.quantization = ofm_quant
@@ -89,25 +81,17 @@
     op.set_output_tensor(ofm)
     # Optional weight tensor
     if weights_shape is not None:
-        if datatype.size_in_bytes() == 1:
-            np_type = np.uint8
-        elif datatype.size_in_bytes() == 2:
-            np_type = np.int16
-        else:
-            np_type = np.int32
         qp = default_quant_params()
         if op.type is not Op.FullyConnected:
             qp.zero_point = np.zeros(weights_shape)
-        weights = create_const_tensor(
-            "weights", weights_shape, datatype, np.zeros(weights_shape), np_type, quantization=qp
-        )
+        weights = create_const_tensor("weights", weights_shape, datatype, np.zeros(weights_shape), quantization=qp)
         op.add_input_tensor(weights)
     # Optional bias tensor
     if bias_shape is not None:
         qp = default_quant_params()
         if op.type is not Op.FullyConnected:
             qp.zero_point = np.zeros(bias_shape)
-        bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp)
+        bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), quantization=qp)
         op.add_input_tensor(bias)
 
     if set_ifm_ofm_shapes:
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 242f0ea..ff7b486 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -343,17 +343,10 @@
     weight_quant.zero_point = 0
     weight_quant.quant_dim = 0
     ofm_dtype = ofm.dtype
-    if ofm_dtype == DataType.uint8:
-        weight_value_dtype = np.uint8
+    if ofm_dtype.type == BaseType.UnsignedInt:
         weight_quant.quant_min = 0
         weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
     else:
-        if ofm_dtype == DataType.int8:
-            weight_value_dtype = np.int8
-        else:
-            assert ofm_dtype == DataType.int16
-            weight_value_dtype = np.int16
-
         weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
         weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
 
@@ -376,9 +369,8 @@
         create_const_tensor(
             "weights",
             weight_shape,
-            ofm.dtype,
+            ofm_dtype,
             np.array(weight_values).reshape(weight_shape),
-            value_dtype=weight_value_dtype,
             quantization=weight_quant,
         ),
         1,  # inputs tensor weight index
@@ -586,7 +578,6 @@
                         shape,
                         intermediate_tens.dtype,
                         np.array(kernel).reshape(shape),
-                        value_dtype=np.int8,
                         quantization=quant,
                     ),
                 )
@@ -1227,9 +1218,7 @@
             scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
         else:
             scalar = 1
-    alpha_tens = create_const_tensor(
-        op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
-    )
+    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
     mul_alpha.add_input_tensor(alpha_tens)
     fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
     mul_alpha.set_output_tensor(fm_alpha)
@@ -1256,9 +1245,7 @@
         quantization.max = quantization.quant_max - quantization.quant_min
         quantization.scale_f32 = np.float32(1)
         quantization.zero_point = 0
-        identity_tens = create_const_tensor(
-            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
-        )
+        identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
         mul_identity.add_input_tensor(identity_tens)
         # Make sure that fm_id is allocated to a different address than fm_alpha
         fm_id = ofm.clone(op.name + "_id", set_unique=True)
@@ -1470,7 +1457,6 @@
                 shape,
                 op.ifm.dtype,
                 weights,
-                np.uint8,
                 purpose=TensorPurpose.Weights,
                 quantization=quantization,
             )
@@ -1526,7 +1512,7 @@
     if top > 0:
         shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
         zero_tens = create_const_tensor(
-            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
         )
         # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
@@ -1538,7 +1524,6 @@
             shape.as_list(),
             ofm.dtype,
             shape.elements() * [pad_value],
-            np.uint8,
             quantization=quant,
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
@@ -1548,14 +1533,14 @@
     if left > 0:
         shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
         zero_tens = create_const_tensor(
-            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
         create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
     if right > 0:
         shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
         zero_tens = create_const_tensor(
-            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
         create_avg_pool_for_concat(
@@ -1715,7 +1700,6 @@
                 weight_shape,
                 inp.dtype,
                 np.ones(weight_shape),
-                value_dtype=np.uint8,
                 quantization=weight_quant,
             ),
             1,
@@ -2008,8 +1992,7 @@
                 ofm_clone = ofm.clone()
                 ofm_clone.values = ofm.values
                 ofm.values = None
-                np_dtype = ofm.dtype.as_numpy_type()
-                zero = create_const_tensor("zero", [1], ofm.dtype, [0], np_dtype, quantization=ofm.quantization)
+                zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
                 memcpy = create_add_nop(f"{ofm.name}_copy")
                 memcpy.add_input_tensor(ofm_clone)
                 memcpy.add_input_tensor(zero)
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 25d3dbc..2a599aa 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -164,7 +164,6 @@
         [1],
         copy_tens.dtype,
         [0],
-        copy_tens.dtype.as_numpy_type(),
         quantization=copy_tens.quantization,
     )
     copy_op = create_add_nop(name)
@@ -190,7 +189,6 @@
         [1],
         copy_tens.dtype,
         [0],
-        copy_tens.dtype.as_numpy_type(),
         quantization=copy_tens.quantization,
     )
     copy_op = create_add_nop(name)
@@ -267,9 +265,7 @@
 def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
     """Creates an add op for the given concat op/input feature map"""
     ofm = concat_op.ofm
-    ifm2 = create_const_tensor(
-        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
-    )
+    ifm2 = create_const_tensor(name + "_zero_scalar", [1], ofm.dtype, [0], quantization=ofm.quantization)
     add_op = create_add_nop(name)
 
     add_op.inputs = [ifm, ifm2]
@@ -306,9 +302,7 @@
         else:
             name = op.name + "_add"
             ofm = op.ofm
-            ifm2 = create_const_tensor(
-                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
-            )
+            ifm2 = create_const_tensor(name + "_zero_scalar", [1], ofm.dtype, [0], quantization=ofm.quantization)
             add_op = create_add_nop(name)
             add_op.inputs = [op.ifm, ifm2]
             add_op.outputs = [ofm]
@@ -476,14 +470,14 @@
     if left > 0:
         shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
         zero_tens = create_const_tensor(
-            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
         create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp0)
     if right > 0:
         shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
         zero_tens = create_const_tensor(
-            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
         create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp0.with_width(ofm_shape.width - right))
@@ -816,9 +810,7 @@
                 new_pad_tens = op.inputs[1].clone("_dim_{dim}")
 
                 name = op.inputs[1].name + f"_dim_{dim}"
-                new_pad_tens = create_const_tensor(
-                    name, list(new_pad_input.shape), DataType.int32, new_pad_input, np.int32
-                )
+                new_pad_tens = create_const_tensor(name, list(new_pad_input.shape), DataType.int32, new_pad_input)
                 pad_op.add_input_tensor(new_pad_tens)
 
                 new_ofm_shape = new_ifm_shape.copy()