COMPMID-3152: Initial Bfloat16 support

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
diff --git a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
index f824f7a..79dc2cb 100644
--- a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
@@ -33,7 +33,7 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/SaturateCast.h"
 
-#include <arm_neon.h>
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
 
 using namespace arm_compute;
 
@@ -43,11 +43,16 @@
 {
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(output);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(input);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(output);
     ARM_COMPUTE_UNUSED(policy);
     ARM_COMPUTE_RETURN_ERROR_ON(input == output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::F16, DataType::F32, DataType::S32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16,
-                                                         DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+                                                         DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+                                                         DataType::F32, DataType::S32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+                                                         DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+                                                         DataType::U32, DataType::S32, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON(shift >= 8);
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QASYMM8_SIGNED && (output->data_type() != DataType::S16 && output->data_type() != DataType::S32
@@ -68,15 +73,18 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::U8 && output->data_type() != DataType::S32),
                                     "Only data_types supported [in] S16 ->  [out] U8, S32");
 
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::BFLOAT16 && output->data_type() != DataType::F32,
+                                    "Only data_types supported [in] BFLOAT16 ->  [out] F32");
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
                                                                             && output->data_type() != DataType::U8
                                                                             && output->data_type() != DataType::F32 && output->data_type() != DataType::S32),
                                     "Only data_types supported [in] F16 ->  [out] QASYMM8, F32, S32, U8");
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
-                                                                            && output->data_type() != DataType::F16
+                                                                            && output->data_type() != DataType::F16 && output->data_type() != DataType::BFLOAT16
                                                                             && output->data_type() != DataType::S32 && output->data_type() != DataType::U8),
-                                    "Only data_types supported [in] F32 ->  [out] QASYMM8, F16, S32, U8");
+                                    "Only data_types supported [in] F32 ->  [out] QASYMM8, BFLOAT16, F16, S32, U8");
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
                                                                             && output->data_type() != DataType::F16
@@ -786,6 +794,52 @@
             }
             break;
         }
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+        case DataType::BFLOAT16:
+            switch(_output->info()->data_type())
+            {
+                case DataType::F32:
+                {
+                    /* Up-conversion BFLOAT16 -> F32: each 16-bit input is widened to 32 bits and shifted left by 16 */
+                    execute_window_loop(win, [&](const Coordinates &)
+                    {
+                        const auto input_ptr  = reinterpret_cast<const bfloat16 *>(input.ptr());
+                        const auto output_ptr = reinterpret_cast<float *>(output.ptr());
+                        /* x is an element offset from the iterator pointers; the vector loop and the scalar tail must both apply it */
+                        int x = window_start_x;
+                        for(; x <= (window_end_x - window_step_x); x += window_step_x)
+                        {
+                            const uint16x8x2_t texels =
+                            {
+                                {
+                                    vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + x),
+                                    vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + x + 8)
+                                }
+                            };
+
+                            vst1q_f32(reinterpret_cast<float *>(output.ptr()) + x,
+                                      vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16)));
+                            vst1q_f32(reinterpret_cast<float *>(output.ptr()) + x + 4,
+                                      vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16)));
+                            vst1q_f32(reinterpret_cast<float *>(output.ptr()) + x + 8,
+                                      vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16)));
+                            vst1q_f32(reinterpret_cast<float *>(output.ptr()) + x + 12,
+                                      vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16)));
+                        }
+                        /* Scalar tail: convert the elements left over after the 16-element vector iterations */
+                        for(; x < window_end_x; ++x)
+                        {
+                            *(output_ptr + x) = float(*(input_ptr + x));
+                        }
+                    },
+                    input, output);
+                    break;
+                }
+                default:
+                    ARM_COMPUTE_ERROR("Output data type unsupported");
+            }
+            break;
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
             switch(_output->info()->data_type())
@@ -980,6 +1034,33 @@
                     break;
                 }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+                case DataType::BFLOAT16:
+                {
+                    /* Down-conversion F32 -> BFLOAT16 */
+                    execute_window_loop(win, [&](const Coordinates &)
+                    {
+                        const auto input_ptr  = reinterpret_cast<const float *>(input.ptr());
+                        const auto output_ptr = reinterpret_cast<bfloat16 *>(output.ptr());
+                        /* x is an element offset from the iterator pointers; the vector loop and the scalar tail must both apply it */
+                        int x = window_start_x;
+                        for(; x <= (window_end_x - window_step_x); x += window_step_x)
+                        {
+                            wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()) + x,
+                                                   reinterpret_cast<uint16_t *>(output.ptr()) + x);
+                            wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()) + x + 8,
+                                                   reinterpret_cast<uint16_t *>(output.ptr()) + x + 8);
+                        }
+                        /* Scalar tail: convert the elements the vector loop did not cover (relies on float -> bfloat16 conversion) */
+                        for(; x < window_end_x; ++x)
+                        {
+                            *(output_ptr + x) = *(input_ptr + x);
+                        }
+                    },
+                    input, output);
+                    break;
+                }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
                 case DataType::S32:
                 {
                     const float       scale_s = 1.f / (1 << _shift);