Add test for len1_array_to_scalar function
Move len1_array_to_scalar from a nested function to a staticmethod
of TFLiteSubgraph so that it can be unit tested directly.
Change-Id: I182f0b70f03070855c1a4478d26644892c1ebb15
Signed-off-by: Diego Russo <diego.russo@arm.com>
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 4f9bd7d..7e158aa 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -91,28 +91,15 @@
shape = list(np_shape) if type(np_shape) is np.ndarray else []
name = decode_str(tens_data.Name())
dtype = datatype_map[tens_data.Type()]
-
tens = Tensor(shape, dtype, name)
-
quant = tens_data.Quantization()
- def len1_array_to_scalar(arr):
- # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in
- # the input buffer. This is represented in Vela by using None.
- # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays
- # are converted to scalars
- if isinstance(arr, int) and arr == 0:
- return None
- if len(arr) == 1:
- return arr[0]
- return arr
-
tens.quantization = QuantizationParameters()
if quant is not None:
- tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy())
- tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy())
- tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy())
- tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy())
+ tens.quantization.min = self.len1_array_to_scalar(quant.MinAsNumpy())
+ tens.quantization.max = self.len1_array_to_scalar(quant.MaxAsNumpy())
+ tens.quantization.scale_f32 = self.len1_array_to_scalar(quant.ScaleAsNumpy())
+ tens.quantization.zero_point = self.len1_array_to_scalar(quant.ZeroPointAsNumpy())
if dtype == DataType.uint8:
tens.quantization.quant_min = 0
@@ -199,6 +186,18 @@
op.outputs[0] = intermediate_tens
act_op.inputs = [intermediate_tens]
+    @staticmethod
+    def len1_array_to_scalar(arr):
+        """Convert a flatbuffer quantisation field (min/max/scale/zero_point) to Vela's representation.
+
+        The flatbuffer accessors return the scalar 0 when a field is not defined in the input buffer;
+        Vela represents that as None. Otherwise the field is an array: single-element arrays are
+        converted to scalars and multi-element arrays are returned unchanged.
+        """
+        if isinstance(arr, int) and arr == 0:
+            return None
+        return arr[0] if len(arr) == 1 else arr
+
class TFLiteGraph:
def __init__(