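IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests

Also add BFloat16 reference workloads for Pad, Permute and Transpose,
and create them from RefWorkloadFactory.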

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index cb94955..9dc576c 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -79,6 +79,7 @@
 
     // Define supported types.
     std::array<DataType,6> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmS8,
@@ -145,6 +146,7 @@
     bool supported = true;
 
     std::array<DataType,6> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmS8,
@@ -179,8 +181,9 @@
 {
     IgnoreUnused(descriptor);
 
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::QAsymmU8,
         DataType::QSymmS16,
@@ -208,8 +211,9 @@
 {
     IgnoreUnused(descriptor);
 
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -256,12 +260,13 @@
     std::string outputTensorStr = "output";
 
     // Define supported types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16,
-            DataType::QAsymmU8,
-            DataType::QSymmS16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
     };
 
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -298,8 +303,9 @@
 {
     IgnoreUnused(descriptor);
 
-    std::array<DataType, 4> supportedInputTypes =
+    std::array<DataType, 5> supportedInputTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -327,13 +333,14 @@
     IgnoreUnused(descriptor);
 
     bool supported = true;
-    std::array<DataType,5> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16,
-            DataType::QAsymmU8,
-            DataType::QAsymmS8,
-            DataType::QSymmS16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmU8,
+        DataType::QAsymmS8,
+        DataType::QSymmS16
     };
 
     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
@@ -354,8 +361,9 @@
 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported) const
 {
-    std::array<DataType,6> supportedTypes =
+    std::array<DataType,7> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Signed32,
         DataType::QAsymmU8,
@@ -418,8 +426,9 @@
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,6> supportedTypes =
+    std::array<DataType,7> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -464,8 +473,9 @@
 
     if (biases.has_value())
     {
-        std::array<DataType,3> biasesSupportedTypes =
+        std::array<DataType,4> biasesSupportedTypes =
         {
+            DataType::BFloat16,
             DataType::Float32,
             DataType::Float16,
             DataType::Signed32
@@ -516,8 +526,9 @@
     IgnoreUnused(descriptor);
     bool supported = true;
 
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -546,8 +557,9 @@
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,6> supportedTypes =
+    std::array<DataType,7> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QSymmS8,
@@ -592,8 +604,9 @@
 
     if (biases.has_value())
     {
-        std::array<DataType,3> biasesSupportedTypes =
+        std::array<DataType,4> biasesSupportedTypes =
         {
+            DataType::BFloat16,
             DataType::Float32,
             DataType::Float16,
             DataType::Signed32
@@ -629,7 +642,8 @@
     supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
                                   "Reference dequantize: per-axis quantized input not support .");
 
-    std::array<DataType,2> supportedOutputTypes = {
+    std::array<DataType,3> supportedOutputTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16
     };
@@ -658,8 +672,9 @@
 
     bool supported = true;
 
-    std::array<DataType,3> supportedInputTypes =
+    std::array<DataType,4> supportedInputTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::QAsymmU8,
         DataType::QSymmS16
@@ -691,7 +706,8 @@
 {
     bool supported = true;
 
-    std::array<DataType,4> supportedTypes = {
+    std::array<DataType,5> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -726,8 +742,9 @@
 {
     IgnoreUnused(descriptor);
 
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -789,8 +806,9 @@
     IgnoreUnused(output);
     bool supported = true;
 
-    std::array<DataType,3> supportedTypes =
+    std::array<DataType,4> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QSymmS16
@@ -815,13 +833,14 @@
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,5> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16,
-            DataType::QAsymmU8,
-            DataType::QAsymmS8,
-            DataType::QSymmS16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmU8,
+        DataType::QAsymmS8,
+        DataType::QSymmS16
     };
 
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -863,9 +882,10 @@
     if (descriptor.m_BiasEnabled)
     {
         // Defined supported types for bias
-        std::array<DataType, 3>
+        std::array<DataType, 4>
         supportedBiasTypes =
         {
+            DataType::BFloat16,
             DataType::Float32,
             DataType::Float16,
             DataType::Signed32
@@ -891,8 +911,9 @@
                                         armnn::Optional<std::string&> reasonIfUnsupported) const
 {
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -939,8 +960,9 @@
 {
     IgnoreUnused(descriptor);
     // Define supported types
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 3> supportedTypes =
         {
+            DataType::BFloat16,
             DataType::Float32,
             DataType::Float16
         };
@@ -970,8 +992,9 @@
 {
     IgnoreUnused(descriptor);
     // Define supported types
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1003,10 +1026,11 @@
 {
     IgnoreUnused(descriptor);
 
-    std::array<DataType, 2> supportedTypes =
+    std::array<DataType, 3> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16
     };
 
     bool supported = true;
@@ -1038,7 +1062,8 @@
 
     bool supported = true;
 
-    std::array<DataType,2> supportedTypes = {
+    std::array<DataType,3> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::QSymmS16
     };
@@ -1139,7 +1164,8 @@
 {
     bool supported = true;
 
-    std::array<DataType,5> supportedTypes = {
+    std::array<DataType,6> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmS8,
@@ -1177,8 +1203,9 @@
     std::string meanLayerStr = "Mean";
     std::string outputTensorStr = "output";
 
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1243,8 +1270,9 @@
 {
     bool supported = true;
 
-    std::array<DataType,5> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1271,7 +1299,8 @@
 {
     bool supported = true;
 
-    std::array<DataType,4> supportedTypes = {
+    std::array<DataType,5> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1307,6 +1336,7 @@
     bool supported = true;
 
     std::array<DataType,6> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1343,8 +1373,9 @@
     IgnoreUnused(descriptor);
 
     // Define supported types
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float16,
         DataType::Float32,
         DataType::QAsymmU8,
@@ -1381,8 +1412,9 @@
     bool supported = true;
 
     // Define supported output and inputs types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1410,8 +1442,9 @@
     bool supported = true;
 
     // Define supported output and inputs types.
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1439,8 +1472,9 @@
     bool supported = true;
 
     // Define supported output and inputs types.
-    std::array<DataType,5> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmS8,
@@ -1467,7 +1501,8 @@
    bool supported = true;
 
     // Define supported input types.
-    std::array<DataType,6> supportedInputTypes = {
+    std::array<DataType,7> supportedInputTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmS8,
@@ -1505,6 +1540,7 @@
     // Define supported output types.
     std::array<DataType,7> supportedOutputTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::Signed32,
@@ -1522,8 +1558,9 @@
                                                 Optional<std::string&> reasonIfUnsupported) const
 {
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1549,8 +1586,9 @@
 {
     IgnoreUnused(descriptor);
     bool supported = true;
-    std::array<DataType,5> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1588,8 +1626,9 @@
     IgnoreUnused(descriptor);
     bool supported = true;
 
-    std::array<DataType, 3> supportedTypes =
+    std::array<DataType, 4> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::QAsymmU8,
         DataType::QSymmS16
@@ -1614,14 +1653,15 @@
 {
     IgnoreUnused(descriptor);
     bool supported = true;
-    std::array<DataType,6> supportedTypes =
+    std::array<DataType,7> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16,
-            DataType::QSymmS8,
-            DataType::QAsymmS8,
-            DataType::QAsymmU8,
-            DataType::QSymmS16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QSymmS8,
+        DataType::QAsymmS8,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
     };
 
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -1643,12 +1683,13 @@
 {
     IgnoreUnused(descriptor);
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16,
-            DataType::QAsymmU8,
-            DataType::QSymmS16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
     };
 
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -1672,8 +1713,9 @@
     IgnoreUnused(descriptor);
     bool supported = true;
 
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1698,8 +1740,9 @@
 {
     IgnoreUnused(descriptor);
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1719,8 +1762,9 @@
 {
     IgnoreUnused(descriptor);
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1749,8 +1793,9 @@
     IgnoreUnused(descriptor);
 
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1780,8 +1825,9 @@
     IgnoreUnused(descriptor);
     bool supported = true;
 
-    std::array<DataType,3> supportedTypes =
+    std::array<DataType,4> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::QAsymmU8,
         DataType::QSymmS16
@@ -1806,7 +1852,8 @@
 {
     bool supported = true;
 
-    std::array<DataType,4> supportedTypes = {
+    std::array<DataType,5> supportedTypes = {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1841,8 +1888,9 @@
 {
     bool supported = true;
 
-    std::array<DataType, 4> supportedTypes
+    std::array<DataType, 5> supportedTypes
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
@@ -1877,12 +1925,13 @@
     IgnoreUnused(descriptor);
     bool supported = true;
 
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
-            DataType::Float32,
-            DataType::Float16,
-            DataType::QAsymmU8,
-            DataType::QSymmS16
+        DataType::BFloat16,
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
     };
 
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
@@ -1922,11 +1971,12 @@
 
     if (biases.has_value())
     {
-        std::array<DataType,3> biasesSupportedTypes =
+        std::array<DataType,4> biasesSupportedTypes =
         {
-                DataType::Float32,
-                DataType::Float16,
-                DataType::Signed32
+            DataType::BFloat16,
+            DataType::Float32,
+            DataType::Float16,
+            DataType::Signed32
         };
         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                       "Reference TransposeConvolution2d: biases is not a supported type.");
@@ -1944,8 +1994,9 @@
     bool supported = true;
 
     // Define supported output and inputs types.
-    std::array<DataType, 4> supportedTypes =
+    std::array<DataType, 5> supportedTypes =
     {
+        DataType::BFloat16,
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 52d71df..1d82421 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -50,6 +50,11 @@
     return IsDataType<DataType::Signed32>(info);
 }
 
+bool IsBFloat16(const WorkloadInfo& info)
+{
+    return IsDataType<DataType::BFloat16>(info);
+}
+
 bool IsFloat16(const WorkloadInfo& info)
 {
     return IsDataType<DataType::Float16>(info);
@@ -441,6 +446,10 @@
     {
         return std::make_unique<RefPadFloat16Workload>(descriptor, info);
     }
+    else if (IsBFloat16(info))
+    {
+        return std::make_unique<RefPadBFloat16Workload>(descriptor, info);
+    }
     return MakeWorkload<RefPadFloat32Workload, RefPadQAsymm8Workload>(descriptor, info);
 }
 
@@ -451,6 +460,10 @@
     {
         return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info);
     }
+    else if (IsBFloat16(info))
+    {
+        return std::make_unique<RefPermuteBFloat16Workload>(descriptor, info);
+    }
     return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteQAsymm8Workload,
         NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
 }
@@ -568,6 +581,10 @@
     {
         return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info);
     }
+    else if (IsBFloat16(info))
+    {
+        return std::make_unique<RefTransposeBFloat16Workload>(descriptor, info);
+    }
     return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload,
             NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
 }
diff --git a/src/backends/reference/test/RefLayerSupportTests.cpp b/src/backends/reference/test/RefLayerSupportTests.cpp
index f0c49ac..ab749c1 100644
--- a/src/backends/reference/test/RefLayerSupportTests.cpp
+++ b/src/backends/reference/test/RefLayerSupportTests.cpp
@@ -48,6 +48,12 @@
     BOOST_CHECK(supportChecker.IsAdditionSupported(in0, in1, out, reasonNotSupported));
 }
 
+BOOST_AUTO_TEST_CASE(IsLayerSupportedBFloat16Reference)
+{
+    armnn::RefWorkloadFactory factory;
+    IsLayerSupportedTests<armnn::RefWorkloadFactory, armnn::DataType::BFloat16>(&factory);
+}
+
 BOOST_AUTO_TEST_CASE(IsLayerSupportedFloat16Reference)
 {
     armnn::RefWorkloadFactory factory;
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 40bf600..a6bfe35 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -70,6 +70,14 @@
 
 ARMNN_AUTO_TEST_CASE(SimpleConvolution2dSquareNhwc, SimpleConvolution2d3x3NhwcTest, false)
 
+ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3BFloat16,
+                     Convolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3NhwcBFloat16,
+                     Convolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(Convolution2d3x3Dilation3x3,
                      Convolution2d3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
                      false,
@@ -95,6 +103,14 @@
                      false,
                      DataLayout::NHWC)
 
+ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3BFloat16,
+                     Convolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3NhwcBFloat16,
+                     Convolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(Convolution2d2x3x3Dilation3x3,
                      Convolution2d2x3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
                      false,
@@ -120,6 +136,14 @@
                      false,
                      DataLayout::NHWC)
 
+ARMNN_AUTO_TEST_CASE(Convolution2d2x2Dilation2x2Padding2x2Stride3x3BFloat16,
+                     Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(Convolution2d2x2Dilation2x2Padding2x2Stride3x3NhwcBFloat16,
+                     Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(Convolution2d2x2Dilation2x2Padding2x2Stride3x3,
                      Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<DataType::Float32, DataType::Float32>,
                      false,
@@ -179,6 +203,14 @@
                      DepthwiseConvolution2d3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
                      false,
                      DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3BFloat16,
+                     DepthwiseConvolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3NhwcBFloat16,
+                     DepthwiseConvolution2d3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d3x3Dilation3x3Uint8,
                      DepthwiseConvolution2d3x3Dilation3x3Test<DataType::QAsymmU8, DataType::Signed32>,
                      false,
@@ -204,6 +236,14 @@
                      DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::Float32, DataType::Float32>,
                      false,
                      DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3BFloat16,
+                     DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3NhwcBFloat16,
+                     DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::BFloat16, DataType::BFloat16>,
+                     false,
+                     DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d2x3x3Dilation3x3Uint8,
                      DepthwiseConvolution2d2x3x3Dilation3x3Test<DataType::QAsymmU8, DataType::Signed32>,
                      false,
@@ -228,6 +268,14 @@
                      DepthwiseConvolution2dMult2Test<armnn::DataType::Float32, armnn::DataType::Float32>,
                      false,
                      armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dMult4BFloat16,
+                     DepthwiseConvolution2dMult4Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>,
+                     false,
+                     armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dMult2BFloat16,
+                     DepthwiseConvolution2dMult2Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>,
+                     false,
+                     armnn::DataLayout::NCHW)
 
 ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dDepthMul1,
                      DepthwiseConvolution2dDepthMul1Test, true, DataLayout::NCHW)
@@ -496,6 +544,7 @@
 
 // Concat
 ARMNN_AUTO_TEST_CASE(SimpleConcat, ConcatTest)
+ARMNN_AUTO_TEST_CASE(ConcatBFloat16, ConcatBFloat16Test)
 ARMNN_AUTO_TEST_CASE(ConcatFloat16, ConcatFloat16Test)
 ARMNN_AUTO_TEST_CASE(ConcatUint8, ConcatUint8Test)
 ARMNN_AUTO_TEST_CASE(ConcatUint8DifferentQParams, ConcatUint8DifferentQParamsTest)
@@ -950,6 +999,11 @@
 ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_4, LogSoftmaxTest4<DataType::Float16>)
 
 // Pad
+ARMNN_AUTO_TEST_CASE(PadBFloat162d, PadBFloat162dTest)
+ARMNN_AUTO_TEST_CASE(PadBFloat162dCustomPadding, PadBFloat162dCustomPaddingTest)
+ARMNN_AUTO_TEST_CASE(PadBFloat163d, PadBFloat163dTest)
+ARMNN_AUTO_TEST_CASE(PadBFloat164d, PadBFloat164dTest)
+
 ARMNN_AUTO_TEST_CASE(PadFloat322d, PadFloat322dTest)
 ARMNN_AUTO_TEST_CASE(PadFloat322dCustomPadding, PadFloat322dCustomPaddingTest)
 ARMNN_AUTO_TEST_CASE(PadFloat323d, PadFloat323dTest)
@@ -1040,6 +1094,10 @@
 ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedSymm16, Rsqrt3dTest<DataType::QSymmS16>)
 
 // Permute
+ARMNN_AUTO_TEST_CASE(SimplePermuteBFloat16, SimplePermuteTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(PermuteBFloat16ValueSet1Test, PermuteValueSet1Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(PermuteBFloat16ValueSet2Test, PermuteValueSet2Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(PermuteBFloat16ValueSet3Test, PermuteValueSet3Test<DataType::BFloat16>)
 ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteTest<DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(PermuteFloat32ValueSet1Test, PermuteValueSet1Test<DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(PermuteFloat32ValueSet2Test, PermuteValueSet2Test<DataType::Float32>)
@@ -1465,6 +1523,10 @@
 ARMNN_AUTO_TEST_CASE(Slice1dInt16, Slice1dInt16Test)
 
 // Transpose
+ARMNN_AUTO_TEST_CASE(SimpleTransposeBFloat16, SimpleTransposeTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(TransposeBFloat16ValueSet1Test, TransposeValueSet1Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(TransposeBFloat16ValueSet2Test, TransposeValueSet2Test<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE(TransposeBFloat16ValueSet3Test, TransposeValueSet3Test<DataType::BFloat16>)
 ARMNN_AUTO_TEST_CASE(SimpleTransposeFloat32, SimpleTransposeTest<DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet1Test, TransposeValueSet1Test<DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(TransposeFloat32ValueSet2Test, TransposeValueSet2Test<DataType::Float32>)
diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp
index 9fedb44..ffdd469 100644
--- a/src/backends/reference/workloads/Pad.cpp
+++ b/src/backends/reference/workloads/Pad.cpp
@@ -152,6 +152,13 @@
     }
 }
 
+template void Pad<BFloat16>(const TensorInfo& inputInfo,
+                            const TensorInfo& outputInfo,
+                            std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
+                            const BFloat16* inputData,
+                            BFloat16* outData,
+                            const float padValue);
+
 template void Pad<float>(const TensorInfo& inputInfo,
                          const TensorInfo& outputInfo,
                          std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
diff --git a/src/backends/reference/workloads/RefPadWorkload.cpp b/src/backends/reference/workloads/RefPadWorkload.cpp
index 356f6b1..777682d 100644
--- a/src/backends/reference/workloads/RefPadWorkload.cpp
+++ b/src/backends/reference/workloads/RefPadWorkload.cpp
@@ -33,6 +33,7 @@
     Pad(inputInfo, outputInfo, m_Data.m_Parameters.m_PadList, inputData, outputData, m_Data.m_Parameters.m_PadValue);
 }
 
+template class RefPadWorkload<DataType::BFloat16>;
 template class RefPadWorkload<DataType::Float32>;
 template class RefPadWorkload<DataType::Float16>;
 template class RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp
index 28fb553..5134ac8 100644
--- a/src/backends/reference/workloads/RefPadWorkload.hpp
+++ b/src/backends/reference/workloads/RefPadWorkload.hpp
@@ -30,6 +30,7 @@
     void Execute() const override;
 };
 
+using RefPadBFloat16Workload = RefPadWorkload<DataType::BFloat16>;
 using RefPadFloat32Workload = RefPadWorkload<DataType::Float32>;
 using RefPadFloat16Workload = RefPadWorkload<DataType::Float16>;
 using RefPadQAsymm8Workload = RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.cpp b/src/backends/reference/workloads/RefPermuteWorkload.cpp
index d0e1431..5751ed8 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.cpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.cpp
@@ -28,6 +28,7 @@
                         src->Map(), dst->Map(), sizeof(T));
 }
 
+template class RefPermuteWorkload<DataType::BFloat16>;
 template class RefPermuteWorkload<DataType::Float16>;
 template class RefPermuteWorkload<DataType::Float32>;
 template class RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.hpp b/src/backends/reference/workloads/RefPermuteWorkload.hpp
index 00a3385..a8d308e 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.hpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.hpp
@@ -27,6 +27,7 @@
     void Execute() const override;
 };
 
+using RefPermuteBFloat16Workload = RefPermuteWorkload<DataType::BFloat16>;
 using RefPermuteFloat16Workload = RefPermuteWorkload<DataType::Float16>;
 using RefPermuteFloat32Workload = RefPermuteWorkload<DataType::Float32>;
 using RefPermuteQAsymm8Workload = RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp
index 6bdfb21..242668b 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.cpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp
@@ -27,6 +27,7 @@
     armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T));
 }
 
+template class RefTransposeWorkload<DataType::BFloat16>;
 template class RefTransposeWorkload<DataType::Float16>;
 template class RefTransposeWorkload<DataType::Float32>;
 template class RefTransposeWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
index 4b1c3d3..dcfe618 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.hpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -27,6 +27,7 @@
     void Execute() const override;
 };
 
+using RefTransposeBFloat16Workload = RefTransposeWorkload<DataType::BFloat16>;
 using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>;
 using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>;
 using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>;