MLBEDSW-4892: Fix crash affecting biases without quantization.

Remove the quant_values attribute from the Tensor class. A single values
attribute is sufficient, holding either quantized or unquantized values
as appropriate.
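
As an illustration (a minimal standalone sketch, not part of the patch;
the uint8 data and quantization parameters here are made up), the new
single-values convention and the simplified dequantize boil down to:

    import numpy as np

    # values is stored in the tensor's own dtype; for a quantized tensor
    # these are the raw quantized elements (previously quant_values)
    values = np.array([130, 132], dtype=np.uint8)
    zero_point, scale_f32 = 128, 0.5

    # mirrors the new QuantizationParameters.dequantize()
    dequantized = np.subtract(values, zero_point) * scale_f32

    # a scalar consumer (cf. Tensor.get_scalar) takes .item(0)
    print(dequantized.item(0))  # 1.0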

Change-Id: Ie96f80ac58061b6077e0f7048dc60209fdfbcafa
Signed-off-by: James Peet <james.peet@arm.com>
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index 07086d6..470504d 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -18,6 +18,8 @@
 import enum
 from typing import Any
 
+import numpy as np
+
 from .numeric_util import round_up_divide
 
 
@@ -99,6 +101,16 @@
 
     __repr__ = __str__
 
+    def as_numpy_type(self):
+        numpy_dtype_code = {
+            BaseType.UnsignedInt: "u",
+            BaseType.SignedInt: "i",
+            BaseType.Float: "f",
+            BaseType.Complex: "c",
+        }
+        assert self.type in numpy_dtype_code, f"Failed to interpret {self} as a numpy dtype"
+        return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes()))
+
     stem_name = {
         BaseType.UnsignedInt: ("uint%s", True),
         BaseType.SignedInt: ("int%s", True),
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 80d0e47..9b76ec1 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -437,8 +437,7 @@
         npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
         if cmd.ifm2_tensor.shape == []:
             # scalar
-            assert cmd.ifm2_tensor.quant_values.size == 1
-            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
+            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
             npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
         else:
             ifm2_blk = cmd.ifm2_box.get_block()
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 06ea61d..ea35ac6 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -48,7 +48,7 @@
 
 def copy_ifm_values_to_memory_tensor(memory_tensor, src_tensor):
     start_addr = src_tensor.address
-    values = src_tensor.quant_values.flatten() if src_tensor.quant_values is not None else src_tensor.values.flatten()
+    values = src_tensor.values.flatten()
     if src_tensor.dtype.size_in_bytes() > 1:
         values = np.frombuffer(values.tobytes(), dtype=np.uint8)
     end_addr = start_addr + values.size
diff --git a/ethosu/vela/reader_util.py b/ethosu/vela/reader_util.py
index 233286c..476b70a 100644
--- a/ethosu/vela/reader_util.py
+++ b/ethosu/vela/reader_util.py
@@ -34,9 +34,6 @@
     if tens.values is not None:
         tens.values = tens.values.transpose(reorder)
 
-    if tens.quant_values is not None:
-        tens.quant_values = tens.quant_values.transpose(reorder)
-
     op = Operation(Op.Const, tens.name)
     op.set_output_tensor(tens)
     return tens
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index c993da1..663c78f 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -532,7 +532,7 @@
     def constraint_weights_limit(cls, op):
         "The sum of the weights cannot exceed {}"
         weights = op.weights
-        values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
+        values = weights.values.astype(np.int64) - weights.quantization.zero_point
         limit = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
         valid = limit <= cls.weights_limit
         return valid, f"Tensor '{weights.name}' has the sum of weights: {limit}"
@@ -551,8 +551,8 @@
     def constraint_bias_40bit(op):
         "Optional Bias tensor values must fit within 40-bits"
         bias = op.bias
-        if bias and bias.dtype == DataType.int64 and bias.quant_values is not None:
-            valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.quant_values)
+        if bias and bias.dtype == DataType.int64 and bias.values is not None:
+            valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.values)
             return valid, f"Tensor '{bias.name}' has values larger than 40-bits"
         return True, "Op has no bias tensor, or it fits in 40-bit"
 
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 7dbdcdd..677757c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -254,20 +254,8 @@
         res.quant_max = self.quant_max
         return res
 
-    def dequantize(self, values):
-        if self.zero_point.size == 1 and self.scale_f32.size == 1:
-            # same scale is used for all values
-            res = (values.astype(np.float64) - self.zero_point) * self.scale_f32
-        else:
-            # a different scale is used for different sets of values
-            values_as_float = values.astype(np.float64)
-
-            # this is not compatible with the format of depthwise weights,
-            # where input is at index 3 (Output, Kh, Kw, Input)
-            # return the quantized values
-            return np.ndarray((values_as_float.shape))
-
-        return res
+    def dequantize(self, values) -> np.ndarray:
+        return np.subtract(values, self.zero_point) * self.scale_f32
 
     def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
         # quantisation parameter scaling is not equal if 'other' is None because
@@ -300,16 +288,12 @@
     value_dtype: np.dtype = None,
     purpose: TensorPurpose = TensorPurpose.Unknown,
     quantization: QuantizationParameters = None,
-    quant_value_dtype: np.dtype = None,
 ):
     # Tensor
     const_tensor = Tensor(shape, dtype, name + "_0")
     const_tensor.purpose = purpose
     const_tensor.quantization = quantization
     const_tensor.values = np.array(values, dtype=value_dtype)
-    const_tensor.quant_values = np.frombuffer(
-        const_tensor.values.tobytes(), dtype=np.uint8 if not quant_value_dtype else quant_value_dtype
-    )
     # Operator
     const_op = Operation(Op.Const, name)
     const_op.set_output_tensor(const_tensor)
@@ -349,7 +333,6 @@
         "ops",
         "consumer_list",
         "values",
-        "quant_values",
         "compressed_values",
         "compressed_values_substream_offsets",
         "mem_area",
@@ -391,8 +374,7 @@
         self.ops: List[Operation] = []
         self.consumer_list: List[Operation] = []
 
-        self.values: Optional[np.ndarray] = None
-        self.quant_values: Optional[np.ndarray] = None
+        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
         self.compressed_values: Optional[np.ndarray] = None
         self.compressed_values_substream_offsets: Optional[List] = None
         self.mem_area: MemArea = MemArea.Unknown
@@ -816,6 +798,17 @@
 
         return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
 
+    def get_scalar(self):
+        """
+        return: Unquantized or dequantized scalar value
+        rtype: self.dtype (if unquantized) or float (if dequantized)
+        """
+        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
+        if self.is_quantized():
+            return self.quantization.dequantize(self.values).item(0)
+        else:
+            return self.values.item(0)
+
     def __lt__(self, other: "Tensor") -> bool:
         return self.equivalence_id < other.equivalence_id
 
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b37bac8..e0eedd6 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -139,8 +139,7 @@
     conv_out_tens = Tensor(in_shape, in_dtype, "output")
     conv_out_tens.quantization = qp.clone()
     weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
-    weight_tens.values = np.zeros(weight_tens.shape)
-    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+    weight_tens.values = np.zeros(weight_tens.shape, in_dtype.as_numpy_type())
     weight_tens.quantization = qp.clone()
     bias_tens = Tensor(out_shape, pad_dtype, "biases")
     attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
@@ -349,8 +348,7 @@
         conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
         conv_ofm.quantization = quant.clone()
         weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
-        weight_tens.values = np.zeros(weight_tens.shape)
-        weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
+        weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
         weight_tens.quantization = quant.clone()
         bias_tens = Tensor([16], DataType.int32, "biases")
 
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 666a5ec..3830815 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -246,7 +246,7 @@
     op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
     op.attrs = {"stride_w": 1, "stride_h": 1}
     bias = Tensor([1, 1, 1, 1], DataType.int64, "bias")
-    bias.quant_values = np.array([0x01FF_FFFF_FFFF])
+    bias.values = np.array([0x01FF_FFFF_FFFF])
     op.add_input_tensor(bias)
     assert not support.is_operator_supported(op)
 
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 3d9eeb8..9fdff8f 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -354,8 +354,7 @@
     # Create an input tensor filled with zeros
     shape = op.ofm_shapes[0].as_list()
     tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
-    tens.values = np.zeros(shape)
-    tens.quant_values = np.zeros(shape, np.uint8)
+    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
     tens.quantization = QuantizationParameters(0.0, 255.0)
     tens.quantization.scale_f32 = 1.0
     tens.quantization.zero_point = 0
@@ -470,8 +469,8 @@
 
             # Reshape Weights to be 4D. IO becomes HWIO
             weight_tensor = op.inputs[1]
-            weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
-            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
+            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
 
             n = op.ofm_shapes[0].batch
             h, w = batching_split.get(n, (1, n))
@@ -608,8 +607,8 @@
             del op.attrs["channel_multiplier"]
             del op.attrs["depth_multiplier"]
 
-            weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
-            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+            weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
+            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
         else:
             raise UnsupportedFeatureError(
                 f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
@@ -622,8 +621,8 @@
 def reorder_depthwise_weights(op, arch, nng):
     if op.type.is_depthwise_conv2d_op():
         weight_tensor = op.inputs[1]
-        weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
-        weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
+        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
         weight_tensor.weight_transpose_depthwise = True
 
     return op
@@ -654,14 +653,14 @@
             for i in range(weight_shape[0]):
                 padded_array[i] = np.vstack(
                     [
-                        weight_tensor.quant_values[i],
+                        weight_tensor.values[i],
                         np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                     ]
                 )
-            weight_tensor.quant_values = padded_array
+            weight_tensor.values = padded_array
         weight_shape[1] //= 2
         weight_shape[2] *= 2
-        weight_tensor.quant_values = np.reshape(weight_tensor.quant_values, weight_shape)
+        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
         weight_tensor.set_all_shapes(weight_shape)
         # If multiple copies of the weights are used, we could avoid
         # them having the same address by changing the value_id
@@ -692,8 +691,8 @@
             }
             # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
             weight_tensor = op.inputs[1]
-            weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
-            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
+            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
 
             DebugDatabase.add_optimised(op, op)
     return op
@@ -729,11 +728,11 @@
                 ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
             elif diff < 0:
                 ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
-        elif ifm_tensor.shape == [] and ifm_tensor.quant_values is None:
+        elif ifm_tensor.shape == [] and ifm_tensor.values is None:
             # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
             ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
             ifm_tensor.storage_shape = ifm_tensor.shape
-        elif ifm2_tensor.shape == [] and ifm2_tensor.quant_values is None:
+        elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
             # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
             ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
             ifm2_tensor.storage_shape = ifm2_tensor.shape
@@ -811,7 +810,7 @@
             # to produce bit exact results, the alpha is not enough;
             # save additional scaling info in attr "alpha_scale", to be used as input
             # to the LUT construction
-            alpha_scalar = const_tens.quant_values - const_tens.quantization.zero_point
+            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
             mul_ifm_scale = np.double(ifm.quantization.scale_f32)
             mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
             mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
@@ -912,7 +911,7 @@
     alpha_tens = create_const_tensor(
         op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
     )
-    alpha_tens.quant_values = np.array([1])
+    alpha_tens.values = np.array([1])
     mul_alpha.add_input_tensor(alpha_tens)
     fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
     mul_alpha.set_output_tensor(fm_alpha)
@@ -1209,7 +1208,7 @@
                 purpose=TensorPurpose.Weights,
                 quantization=quantization,
             )
-            weight_tens.quant_values = weights
+            weight_tens.values = weights
             op.type = Op.DepthwiseConv2DBias
             op.inputs = []
             op.add_input_tensor(ifm)
@@ -1331,7 +1330,6 @@
         nr_biases = op.inputs[1].shape[-1]
         bias_values = [0] * nr_biases
         bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
-        bias_tensor.quant_values = bias_tensor.values
         op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
 
     return op
@@ -1409,13 +1407,7 @@
                     quant = QuantizationParameters()
                     quant.zero_point = 0
                     bias_term_tens = create_const_tensor(
-                        op.name + "_bias",
-                        [1, 1, 1, 1],
-                        DataType.int16,
-                        [bias_term],
-                        np.int16,
-                        quantization=quant,
-                        quant_value_dtype=np.int16,
+                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                     )
                     add_op.add_input_tensor(bias_term_tens)
                     add_op.set_output_tensor(op.ofm)
@@ -1514,7 +1506,7 @@
             ),
             1,
         )
-        op.weights.quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
 
         # Add None bias tensor
         op.inputs.append(None)
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 30bf32a..fbee793 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -107,9 +107,6 @@
                 tens.values = np.array(buf.view(np_dtype))
             else:
                 tens.values = np.array(buf.view(np_dtype).reshape(shape))
-                if tens.quantization is not None:
-                    tens.quant_values = tens.values
-                    tens.values = tens.quantization.dequantize(tens.quant_values)
         return tens
 
     def parse_operator(self, op_index, op_data):
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index fd3bf42..e6dd85b 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -243,9 +243,7 @@
     def serialise_tensor(self, tens):
         builder = self.builder
         tens_shape = tens.shape
-        values = tens.quant_values
-        if values is None:
-            values = tens.values
+        values = tens.values
 
         if values is None:
             values = np.empty(shape=(0), dtype=np.uint8)
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index e51ead1..364d9a6 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -192,7 +192,6 @@
                 fname = decode_str(tens_data.NpyFilename())
                 tens.values = np.load(os.path.join(file_path, fname))
                 assert list(tens.values.shape) == tens.shape
-                tens.quant_values = tens.values
             except (struct.error, TypeError, RuntimeError) as e:
                 print(f'Error: Invalid npy file. Got "{e}" ')
                 sys.exit(1)
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 7e33e93..6536143 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -100,7 +100,7 @@
 def create_weight_compression_config(weight_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
     # Note: for an ofm block only its depth is used in weight compression.
     # And block depth > ofm depth gives same result as block depth == ofm depth
-    block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
+    block_depth = min(ofm_block_depth, weight_tens.values.shape[-1])
     return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id)
 
 
@@ -214,7 +214,7 @@
 
     # the operator should only have a single output
     assert len(tens.consumer_list[0].outputs) == 1
-    biases = tens.quant_values
+    biases = tens.values
 
     first_consumer_op = tens.consumer_list[0]
     ifm_dtype = first_consumer_op.inputs[0].dtype
@@ -318,7 +318,7 @@
         assert weight_tens.quantization.zero_point is not None
 
         # Early zero-point correction
-        quant_buf = weight_tens.quant_values.astype(np.int16)
+        quant_buf = weight_tens.values.astype(np.int16)
         # the zero point can be either a native or numpy type
         if isinstance(weight_tens.quantization.zero_point, (int, float)):
             zero_point = np.int16(weight_tens.quantization.zero_point)
@@ -363,7 +363,7 @@
         scale_tens.element_size_bytes = 10
 
     # Slice the weight stream up depth-ways into bricks and compress
-    full_ofm_depth = weight_tens.quant_values.shape[-1]
+    full_ofm_depth = weight_tens.values.shape[-1]
     ofm_block_depth = block_config.ofm_block.depth
 
     weight_range_index = 0