MLBEDSW-2804: Added bias data type check

Allows the int64 data type to be used as long as all values can be
packed into an int40 value.
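
For clarity, a minimal standalone sketch of the "fits in int40" condition
(hypothetical helper, not part of this patch; the real check is the
check_bias_restrictions method added below):

    INT40_MIN = -(1 << 39)       # -549755813888
    INT40_MAX = (1 << 39) - 1    # 549755813887

    def fits_in_int40(values):
        # True only if every bias value lies in the signed 40-bit range.
        return all(INT40_MIN <= v <= INT40_MAX for v in values)

    assert fits_in_int40([0, 42, INT40_MIN, INT40_MAX])
    assert not fits_in_int40([1 << 39])  # one past the positive limit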

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: I0e25ec482e3ea765a5fd00bcf7e212a9e65a1461
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index f7a9509..8dec379 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -179,6 +179,27 @@
 
         return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor
 
+    def get_ifm_ifm2_weights_biases_ofm(self):
+        ifm_tensor = None
+        ifm2_tensor = None
+        weight_tensor = None
+        bias_tensor = None
+        ofm_tensor = None
+
+        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
+        if ifm_idx != -1:
+            ifm_tensor = self.inputs[ifm_idx]
+        if ifm2_idx != -1:
+            ifm2_tensor = self.inputs[ifm2_idx]
+        if weight_idx != -1:
+            weight_tensor = self.inputs[weight_idx]
+        if bias_idx != -1:
+            bias_tensor = self.inputs[bias_idx]
+        if ofm_idx != -1:
+            ofm_tensor = self.outputs[ofm_idx]
+
+        return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
+
     def is_concat_op(self):
         return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
 
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index f57cbee..8ec7720 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -201,10 +201,13 @@
             return False
 
         # check data type
-        ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
         if weight_tensor.element_size() > 1:
             return False
 
+        if not self.check_bias_restrictions(bias_tensor):
+            return False
+
         # check kernel size [HWIO]
         dilated_weight_w = weight_tensor.shape[1] + (weight_tensor.shape[1] - 1) * (dilation_w_factor - 1)
         dilated_weight_h = weight_tensor.shape[0] + (weight_tensor.shape[0] - 1) * (dilation_h_factor - 1)
@@ -307,10 +310,13 @@
 
     def check_vector_product_restrictions(self, op):
         # check data type
-        ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+        _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
         if weight_tensor.element_size() > 1:
             return False
 
+        if not self.check_bias_restrictions(bias_tensor):
+            return False
+
         return True
 
     def check_element_wise_restrictions(self, op):
@@ -407,3 +413,16 @@
                 return False
 
         return True
+
+    def check_bias_restrictions(self, bias_tensor):
+        # check data type
+        if bias_tensor.dtype not in (DataType.int32, DataType.int64):
+            return False
+
+        # check that all values fit into 40 bits
+        if bias_tensor.dtype == DataType.int64:
+            for value in bias_tensor.values:
+                if not (-(1 << 39) <= value < (1 << 39)):
+                    return False
+
+        return True
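
Note: a minimal self-contained sketch of how the new check behaves, using
hypothetical stand-in classes rather than Vela's real DataType and Tensor
objects:

    # Hypothetical stand-ins, only to illustrate the check's behaviour.
    class FakeDataType:
        int32 = "int32"
        int64 = "int64"
        float32 = "float32"

    class FakeBiasTensor:
        def __init__(self, dtype, values):
            self.dtype = dtype
            self.values = values

    def check_bias_restrictions(bias_tensor):
        # Mirrors the method above: int32 always passes, int64 passes only
        # if every value fits in a signed 40-bit field, anything else fails.
        if bias_tensor.dtype not in (FakeDataType.int32, FakeDataType.int64):
            return False
        if bias_tensor.dtype == FakeDataType.int64:
            return all(-(1 << 39) <= v < (1 << 39) for v in bias_tensor.values)
        return True

    assert check_bias_restrictions(FakeBiasTensor(FakeDataType.int64, [0, (1 << 39) - 1]))
    assert not check_bias_restrictions(FakeBiasTensor(FakeDataType.int64, [1 << 39]))
    assert not check_bias_restrictions(FakeBiasTensor(FakeDataType.float32, [1.0]))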