COMPMID-3152: Initial Bfloat16 support

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
diff --git a/arm_compute/core/CPP/Validate.h b/arm_compute/core/CPP/Validate.h
index f195a31..dfee9de 100644
--- a/arm_compute/core/CPP/Validate.h
+++ b/arm_compute/core/CPP/Validate.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -48,6 +48,26 @@
     return Status {};
 }
 
+/** Return an error if the data type of the passed tensor info is BFLOAT16 and BFLOAT16 support is not compiled in.
+ *
+ * @param[in] function    Function in which the error occurred.
+ * @param[in] file        Name of the file where the error occurred.
+ * @param[in] line        Line on which the error occurred.
+ * @param[in] tensor_info Tensor info to validate.
+ *
+ * @return Status
+ */
+inline Status error_on_unsupported_cpu_bf16(const char *function, const char *file, const int line,
+                                            const ITensorInfo *tensor_info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line);
+#if !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16))
+    ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(tensor_info->data_type() == DataType::BFLOAT16,
+                                        function, file, line, "This CPU architecture does not support the BFloat16 data type; Armv8.6-A or above is required");
+#endif /* !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)) */
+    return Status {};
+}
+
 /** Return an error if the data type of the passed tensor is FP16 and FP16 support is not compiled in.
  *
  * @param[in] function Function in which the error occurred.
@@ -65,10 +85,33 @@
     return Status{};
 }
 
+/** Return an error if the data type of the passed tensor is BFLOAT16 and BFLOAT16 support is not compiled in.
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file     Name of the file where the error occurred.
+ * @param[in] line     Line on which the error occurred.
+ * @param[in] tensor   Tensor to validate.
+ *
+ * @return Status
+ */
+inline Status error_on_unsupported_cpu_bf16(const char *function, const char *file, const int line,
+                                            const ITensor *tensor)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line);
+    ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_unsupported_cpu_bf16(function, file, line, tensor->info()));
+    return Status{};
+}
+
 #define ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(tensor) \
     ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_unsupported_cpu_fp16(__func__, __FILE__, __LINE__, tensor))
 
 #define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor) \
     ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_unsupported_cpu_fp16(__func__, __FILE__, __LINE__, tensor))
+
+#define ARM_COMPUTE_ERROR_ON_CPU_BF16_UNSUPPORTED(tensor) \
+    ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_unsupported_cpu_bf16(__func__, __FILE__, __LINE__, tensor))
+
+#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(tensor) \
+    ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_unsupported_cpu_bf16(__func__, __FILE__, __LINE__, tensor))
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CPP_VALIDATE_H */
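
For orientation, a minimal sketch of how a CPU kernel's validation step would consume the new macro pair; example_validate() is a hypothetical helper, while the macro and Status type come from the hunk above.

    #include "arm_compute/core/CPP/Validate.h"

    namespace arm_compute
    {
    // Hypothetical helper: rejects BFLOAT16 tensors on builds without
    // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC or ARM_COMPUTE_FORCE_BF16,
    // mirroring the existing F16 macro pair.
    inline Status example_validate(const ITensorInfo *input)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(input);
        return Status{};
    }
    } // namespace arm_compute
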
diff --git a/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h b/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h
index df4102c..5cda320 100644
--- a/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -55,24 +55,25 @@
      * Valid conversions Input -> Output :
      *
      *   - QASYMM8_SIGNED -> S16, S32, F32, F16
-     *   - QASYMM8 -> U16, S16, S32, F32, F16
-     *   - U8 -> U16, S16, S32, F32, F16
-     *   - U16 -> U8, U32
-     *   - S16 -> QASYMM8_SIGNED, U8, S32
-     *   - F16 -> QASYMM8_SIGNED, QASYMM8, F32, S32, U8
-     *   - S32 -> QASYMM8_SIGNED, QASYMM8, F16, F32, U8
-     *   - F32 -> QASYMM8_SIGNED, QASYMM8, F16, S32, U8
+     *   - QASYMM8        -> U16, S16, S32, F32, F16
+     *   - U8             -> U16, S16, S32, F32, F16
+     *   - U16            -> U8, U32
+     *   - S16            -> QASYMM8_SIGNED, U8, S32
+     *   - BFLOAT16       -> F32
+     *   - F16            -> QASYMM8_SIGNED, QASYMM8, F32, S32, U8
+     *   - S32            -> QASYMM8_SIGNED, QASYMM8, F16, F32, U8
+     *   - F32            -> QASYMM8_SIGNED, QASYMM8, BFLOAT16, F16, S32, U8
      *
-     * @param[in]  input  The input tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/F16/F32.
-     * @param[out] output The output tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/F16/F32.
+     * @param[in]  input  The input tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+     * @param[out] output The output tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
      * @param[in]  policy Conversion policy.
      * @param[in]  shift  (Optional) Value for down/up conversions. Must be 0 <= shift < 8.
      */
     void configure(const ITensor *input, ITensor *output, ConvertPolicy policy, uint32_t shift = 0);
     /** Static function to check if given info will lead to a valid configuration of @ref NEDepthConvertLayerKernel
      *
-     * @param[in] input  Source tensor info. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/F16/F32.
-     * @param[in] output Destination tensor info. Data type supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/F16/F32.
+     * @param[in] input  Source tensor info. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+     * @param[in] output Destination tensor info. Data type supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
      * @param[in] policy Conversion policy
      * @param[in] shift  (Optional) Value for down/up conversions. Must be 0 <= shift < 8.
      *
diff --git a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
index 1f22e09..5ea9a5d 100644
--- a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
+++ b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
@@ -56,6 +56,25 @@
     return vcvtq_s32_f32(a);
 }
 
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+/** Convert 2x128-bit floating-point vectors into 1x128-bit bfloat16 vector
+ *
+ * @param[in]  inptr  Pointer to the input memory to load values from
+ * @param[out] outptr Pointer to the output memory to store values to
+ */
+inline void vcvt_bf16_f32(const float *inptr, uint16_t *outptr)
+{
+    __asm __volatile(
+        "ldp    q0, q1, [%[inptr]]\n"
+        ".inst  0xea16800\n"  // BFCVTN v0, v0
+        ".inst  0x4ea16820\n" // BFCVTN2 v0, v1
+        "str    q0, [%[outptr]]\n"
+        : [inptr] "+r"(inptr)
+        : [outptr] "r"(outptr)
+        : "v0", "v1", "memory");
+}
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+
 } // namespace wrapper
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_WRAPPER_CVT_H */
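
For readers without BF16 hardware, a portable scalar sketch of the narrowing the BFCVTN/BFCVTN2 pair performs: round an FP32 value to the nearest-even BF16 (the top 16 bits of the word), quieting NaNs. This mirrors the common software emulation, not any code in this patch.

    #include <cstdint>
    #include <cstring>

    inline uint16_t f32_to_bf16_rne(float v)
    {
        uint32_t bits = 0;
        std::memcpy(&bits, &v, sizeof(bits));
        if((bits & 0x7FFFFFFFU) > 0x7F800000U)
        {
            // NaN: keep the sign, force a quiet-NaN payload.
            return static_cast<uint16_t>((bits >> 16) | 0x0040U);
        }
        // Round to nearest even: bias by 0x7FFF plus the LSB of the kept half.
        const uint32_t lsb = (bits >> 16) & 1U;
        return static_cast<uint16_t>((bits + 0x7FFFU + lsb) >> 16);
    }
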
diff --git a/arm_compute/core/PixelValue.h b/arm_compute/core/PixelValue.h
index 31bc550..337ccbc 100644
--- a/arm_compute/core/PixelValue.h
+++ b/arm_compute/core/PixelValue.h
@@ -89,6 +89,9 @@
             case DataType::S64:
                 value.s64 = static_cast<int64_t>(v);
                 break;
+            case DataType::BFLOAT16:
+                value.bf16 = static_cast<bfloat16>(v);
+                break;
             case DataType::F16:
                 value.f16 = static_cast<half>(v);
                 break;
@@ -174,6 +177,15 @@
     {
         value.s64 = v;
     }
+    /** Initialize the union with a BFLOAT16 pixel value
+     *
+     * @param[in] v BFLOAT16 value.
+     */
+    PixelValue(bfloat16 v)
+        : PixelValue()
+    {
+        value.bf16 = v;
+    }
     /** Initialize the union with a F16 pixel value
      *
      * @param[in] v F16 value.
@@ -214,6 +226,7 @@
             double   f64;     /**< Single channel double */
             float    f32;     /**< Single channel float 32 */
             half     f16;     /**< Single channel F16 */
+            bfloat16 bf16;    /**< Single channel BFLOAT16 */
             uint8_t  u8;      /**< Single channel U8 */
             int8_t   s8;      /**< Single channel S8 */
             uint16_t u16;     /**< Single channel U16 */
@@ -285,6 +298,14 @@
     {
         v = value.s64;
     }
+    /** Interpret the pixel value as a BFLOAT16
+     *
+     * @param[out] v Returned value
+     */
+    void get(bfloat16 &v) const
+    {
+        v = value.bf16;
+    }
     /** Interpret the pixel value as a F16
      *
      * @param[out] v Returned value
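
A small round-trip sketch of the plumbing added above, assuming bfloat16 (from support/Bfloat16.h) is constructible from float, as the static_cast in this hunk implies.

    #include "arm_compute/core/PixelValue.h"

    using namespace arm_compute;

    void pixel_value_bf16_roundtrip()
    {
        PixelValue pv(bfloat16(1.5f)); // picks the new bfloat16 constructor

        bfloat16 read_back{};
        pv.get(read_back);             // picks the new bfloat16 get() overload
    }
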
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index cf689d7..b640987 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -30,6 +30,7 @@
 #include "arm_compute/core/Strides.h"
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/utils/misc/Macros.h"
+#include "support/Bfloat16.h"
 #include "support/Half.h"
 
 #include <cmath>
@@ -58,6 +59,7 @@
     U16,      /**< 1 channel, 1 U16 per channel */
     S32,      /**< 1 channel, 1 S32 per channel */
     U32,      /**< 1 channel, 1 U32 per channel */
+    BFLOAT16, /**< 1 channel, 1 BFLOAT16 per channel */
     F16,      /**< 1 channel, 1 F16 per channel */
     F32,      /**< 1 channel, 1 F32 per channel */
     UV88,     /**< 2 channel, 1 U8 per channel */
@@ -89,6 +91,7 @@
     S32,                /**< signed 32-bit number */
     U64,                /**< unsigned 64-bit number */
     S64,                /**< signed 64-bit number */
+    BFLOAT16,           /**< 16-bit brain floating-point number */
     F16,                /**< 16-bit floating-point number */
     F32,                /**< 32-bit floating-point number */
     F64,                /**< 64-bit floating-point number */
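
As background for the new enum entries (general BFLOAT16 facts, not library specifics): BFLOAT16 keeps binary32's 8-bit exponent but truncates the mantissa to 7 bits, so a BF16 value is exactly the top half of the corresponding F32 word, unlike IEEE F16 with its 5-bit exponent and 10-bit mantissa. A quick self-contained check:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    void bf16_is_top_half_of_f32()
    {
        const float f = 3.0f; // 0x40400000 in binary32, exact in 7 mantissa bits
        uint32_t    u = 0;
        std::memcpy(&u, &f, sizeof(u));
        assert(static_cast<uint16_t>(u >> 16) == 0x4040); // sign 0, exponent 0x80, mantissa 0x40
    }
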
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 4a3b01d..8577046 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -114,6 +114,7 @@
         case DataType::S16:
         case DataType::QSYMM16:
         case DataType::QASYMM16:
+        case DataType::BFLOAT16:
         case DataType::F16:
             return 2;
         case DataType::F32:
@@ -146,6 +147,7 @@
             return 1;
         case Format::U16:
         case Format::S16:
+        case Format::BFLOAT16:
         case Format::F16:
         case Format::UV88:
         case Format::YUYV422:
@@ -191,6 +193,7 @@
         case DataType::S16:
         case DataType::QSYMM16:
         case DataType::QASYMM16:
+        case DataType::BFLOAT16:
         case DataType::F16:
             return 2;
         case DataType::U32:
@@ -228,6 +231,8 @@
             return DataType::U32;
         case Format::S32:
             return DataType::S32;
+        case Format::BFLOAT16:
+            return DataType::BFLOAT16;
         case Format::F16:
             return DataType::F16;
         case Format::F32:
@@ -260,6 +265,7 @@
         case Format::S16:
         case Format::U32:
         case Format::S32:
+        case Format::BFLOAT16:
         case Format::F16:
         case Format::F32:
         case Format::UV88:
@@ -447,6 +453,7 @@
         case Format::U16:
         case Format::S32:
         case Format::U32:
+        case Format::BFLOAT16:
         case Format::F16:
         case Format::F32:
         case Format::RGB888:
@@ -481,6 +488,7 @@
         case Format::S16:
         case Format::U32:
         case Format::S32:
+        case Format::BFLOAT16:
         case Format::F16:
         case Format::F32:
             return 1;
@@ -531,6 +539,7 @@
         case DataType::QSYMM8_PER_CHANNEL:
         case DataType::QSYMM16:
         case DataType::QASYMM16:
+        case DataType::BFLOAT16:
         case DataType::F16:
         case DataType::U32:
         case DataType::S32:
@@ -596,6 +605,12 @@
             max = PixelValue(std::numeric_limits<int32_t>::max());
             break;
         }
+        case DataType::BFLOAT16:
+        {
+            min = PixelValue(bfloat16::lowest());
+            max = PixelValue(bfloat16::max());
+            break;
+        }
         case DataType::F16:
         {
             min = PixelValue(std::numeric_limits<half>::lowest());
@@ -1284,6 +1299,8 @@
             const auto val_s32 = static_cast<int32_t>(val);
             return ((val_s32 == val) && val_s32 >= std::numeric_limits<int32_t>::lowest() && val_s32 <= std::numeric_limits<int32_t>::max());
         }
+        case DataType::BFLOAT16:
+            return (val >= bfloat16::lowest() && val <= bfloat16::max());
         case DataType::F16:
             return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
         case DataType::F32:
@@ -1323,6 +1340,11 @@
             // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
             s << std::right << static_cast<T>(ptr[i]) << element_delim;
         }
+        else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
+        {
+            // We use float instead of print_type here because std::is_floating_point<bfloat16> returns false, which would make print_type int.
+            s << std::right << float(ptr[i]) << element_delim;
+        }
         else
         {
             s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
@@ -1357,6 +1379,11 @@
             // We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
             ss << static_cast<T>(ptr[i]);
         }
+        else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
+        {
+            // We use float instead of print_type here because std::is_floating_point<bfloat16> returns false, which would make print_type int.
+            ss << float(ptr[i]);
+        }
         else
         {
             ss << static_cast<print_type>(ptr[i]);
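
Finally, a hedged sketch of querying the BFLOAT16 range these Utils.h changes wire up; get_min_max() returning a tuple of PixelValue matches the function the earlier hunk sits in.

    #include "arm_compute/core/Utils.h"

    #include <tuple>

    using namespace arm_compute;

    void bf16_range()
    {
        PixelValue min_val{}, max_val{};
        std::tie(min_val, max_val) = get_min_max(DataType::BFLOAT16);

        bfloat16 lo{}, hi{};
        min_val.get(lo); // bfloat16::lowest()
        max_val.get(hi); // bfloat16::max()
    }
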