COMPMID-3160: Add Bfloat16 support in NEGEMMConvolutionLayer

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I0e449306c138a562ffc1455e76ec44b2fd059d85
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2860
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 91b91d66..a3080e7 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -505,17 +505,17 @@
 
 Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info)
 {
-    ARM_COMPUTE_UNUSED(gemm_info);
-    ARM_COMPUTE_UNUSED(c);
+    ARM_COMPUTE_UNUSED(gemm_info, c);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
 #ifndef __aarch64__
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
 #endif /* __aarch64__ */
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8,
-                                                         DataType::F16, DataType::F32);
+                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8,
-                                                         DataType::F16, DataType::F32);
+                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
     if(is_data_type_quantized_per_channel(b->data_type()))
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
@@ -526,6 +526,7 @@
     }
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
@@ -578,6 +579,11 @@
             }
             break;
 #endif /* __aarch64__ */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+        case DataType::BFLOAT16:
+            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+            break;
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
             create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);