Add BF16 support to reference model

* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work-
  arounds for reduce.any() and reduce.all() bugs (introduced
  between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods

Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 554a7a2..33bdeed 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -31,13 +31,18 @@
 #include <math.h>
 #define __STDC_LIMIT_MACROS    //enable min/max of plain data type
 #include "func_debug.h"
+#include "func_config.h"
 #include "inttypes.h"
+#include "tosa_generated.h"
 #include <cassert>
 #include <iostream>
 #include <limits>
 #include <stdint.h>
 #include <typeinfo>
+#include <Eigen/Core>
+#include <bitset>
 
+using namespace tosa;
 using namespace std;
 
 inline size_t _count_one(uint64_t val)
@@ -191,4 +196,88 @@
     // clang-format on
 }
 
+inline void float_trunc_bytes(float* src)
+{
+    /* Zero the two least significant bytes of *src in place (byte indices assume a 4-byte float), using g_func_config.float_is_big_endian to locate them. */
+    char src_as_bytes[sizeof(float)];
+    memcpy(src_as_bytes, src, sizeof(float));    // work on a byte copy of the value
+
+    if (g_func_config.float_is_big_endian)
+    {
+        src_as_bytes[2] = '\000';    // big-endian: the low-order bytes are stored last
+        src_as_bytes[3] = '\000';
+    }
+    else
+    {
+        src_as_bytes[0] = '\000';    // otherwise: the low-order bytes are stored first
+        src_as_bytes[1] = '\000';
+    }
+
+    memcpy(src, &src_as_bytes, sizeof(float));    // write the modified bytes back to *src
+}
+
+inline void truncateFloatToBFloat(float* src, int64_t size) {
+    /* Zero the two least significant bytes of each of the `size` floats in
+    src, in place. src must be non-null and size must be positive. */
+    ASSERT_MEM(src);
+    ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
+    for (int64_t i = 0; i < size; i++)    // bounded by `<` so a negative size cannot walk past the buffer if the assertion does not abort
+    {
+        float_trunc_bytes(src + i);
+    }
+}
+
+inline bool checkValidBFloat(float src)
+{
+    /* Returns true if the two least significant bytes of src are zero (byte order taken from g_func_config.float_is_big_endian). */
+    // src is passed by value, so there is no pointer to validate; a null-style check on it would wrongly reject 0.0f.
+    char src_as_bytes[sizeof(float)];
+    memcpy(src_as_bytes, &src, sizeof(float));
+
+    if (g_func_config.float_is_big_endian)
+    {
+        return (src_as_bytes[2] == '\000' && src_as_bytes[3] == '\000');    // low-order bytes are stored last
+    }
+    else
+    {
+        return (src_as_bytes[0] == '\000' && src_as_bytes[1] == '\000');    // low-order bytes are stored first
+    }
+}
+
+inline bool float_is_big_endian()
+{
+    /* Detects the byte order of float by comparing the stored bytes of
+    1.0 and -1.0. Returns true if the first stored byte differs between
+    the two, which (assuming negation only changes the most significant
+    byte, as with a sign bit) indicates a big-endian representation. */
+    float f = 1.0;
+    char f_as_bytes[sizeof(float)];
+    memcpy(f_as_bytes, &f, sizeof(float));    // bytes of +1.0
+    f = -f;
+    char f_neg_as_bytes[sizeof(float)];
+    memcpy(f_neg_as_bytes, &f, sizeof(float));    // bytes of -1.0
+    return f_as_bytes[0] != f_neg_as_bytes[0];
+}
+
+template <DType Dtype>
+float fpTrunc(float f_in)
+{
+    /* Returns f_in adjusted for Dtype: low two bytes zeroed for BF16, unchanged for FP16 (placeholder) and FP32; other DTypes assert. */
+    switch (Dtype)
+    {
+        case DType_BF16:
+            truncateFloatToBFloat(&f_in, 1);    // f_in is a by-value copy, so the caller's value is untouched
+            break;
+        case DType_FP16:
+            // TODO(jw): implement FP16 truncate function (no-op placeholder for now)
+            break;
+        case DType_FP32:
+            // No-op for fp32
+            break;
+        default:
+            ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype));    // unsupported DType for truncation
+    }
+    return f_in;
+}
+
 #endif /* _ARITH_UTIL_H */
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 776fbf3..5c2735d 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -20,6 +20,7 @@
 #include "ops/op_factory.h"
 #include "subgraph_traverser.h"
 #include "tosa_serialization_handler.h"
+#include "arith_util.h"
 
 #include <fstream>
 #include <iostream>
@@ -67,6 +68,8 @@
             return TOSA_VERSION_MISMATCH;
     }
 
+    g_func_config.float_is_big_endian = float_is_big_endian();
+
     json test_desc;
 
     // Initialize test descriptor
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 61f7df6..46234e2 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -16,6 +16,7 @@
 #include "activation_funcs.h"
 #include "quant_util.h"
 #include "template_types.h"
+#include "arith_util.h"
 #include <cmath>
 
 using namespace TosaReference;
@@ -28,13 +29,14 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         {
             InEigenType min = (InEigenType)attribute->min_fp();
             InEigenType max = (InEigenType)attribute->max_fp();
             ERROR_IF(max < min, "OpClamp: max smaller than min");
 
-            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); };
         }
         break;
         case DType_INT8:
@@ -59,8 +61,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / (1.0 + (expf(-1.0 * a)))); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -75,8 +78,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -87,12 +91,15 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index f240aa5..5b78a4f 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -28,6 +28,7 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         case DType_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
@@ -45,6 +46,7 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         case DType_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
@@ -62,6 +64,7 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         case DType_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
@@ -75,13 +78,16 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 69b6a65..bffd659 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -639,6 +639,7 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16)
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
@@ -646,6 +647,7 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
@@ -653,6 +655,7 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
 
 DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
+DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
 DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
 DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
 DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
@@ -660,6 +663,7 @@
 DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -667,6 +671,7 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -674,6 +679,7 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -681,6 +687,7 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 5709a92..f5304a5 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -90,6 +90,7 @@
 // note OpConst is not templated
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 098b0ea..e4c0ee0 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -143,8 +143,9 @@
             };
             break;
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
@@ -371,6 +372,7 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         case DType_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
@@ -388,6 +390,7 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         case DType_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
@@ -407,8 +410,9 @@
     switch (InDtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+            this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
             break;
         case DType_INT32:
             this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
@@ -457,8 +461,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -482,8 +487,9 @@
             };
             break;
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
@@ -581,6 +587,7 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
 
@@ -617,23 +624,28 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
 
@@ -643,5 +655,6 @@
 // Instantiation of nodes for comparison operators opEqual, opGreater
 // and opGreaterEqual
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index d85da1a..677a4e2 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -108,6 +108,7 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 00897cc..5347b8c 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -78,11 +78,14 @@
 {
     switch (Dtype)
     {
-        case DType_FP32:
-        case DType_FP16:
+        case DType_FP32: // No fpTrunc for FP32 as it is a no-op
         case DType_INT32:
             this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
             break;
+        case DType_FP16:
+        case DType_BF16:
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a > (InEigenType)0 ? a : (-a)); };
+            break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
     }
@@ -113,8 +116,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(ceilf(a)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -161,8 +165,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(expf(a)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -177,8 +182,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(floorf(a)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -193,8 +199,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(logf(a)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -245,10 +252,11 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
             this->fcn = [](InEigenType a) -> OutEigenType {
                 InEigenType result = -(a);
-                return result;
+                return fpTrunc<Dtype>(result);
             };
             break;
         case DType_INT16:
@@ -297,8 +305,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / a); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -313,8 +322,9 @@
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
-            this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
+            this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(a)); };
             break;
         default:
             ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -325,6 +335,7 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
 
@@ -333,29 +344,36 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index cf1d9f7..66efee0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -63,7 +63,7 @@
 
     if (this->mode == ResizeMode_BILINEAR)
     {
-        if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+        if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
         {
             printNodeValidationError("OpResize: invalid data type for BILINEAR");
             return 1;
@@ -71,7 +71,7 @@
     }
     else
     {
-        if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+        if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
         {
             printNodeValidationError("OpResize: invalid data type for NEAREST");
             return 1;
@@ -159,15 +159,15 @@
 
                     resize_t dy;
                     resize_t dx;
-                    if (std::is_floating_point<resize_t>::value)
+                    if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
                     {
-                        dy = fy - iy;
-                        dx = fx - ix;
+                        dy = (resize_t)(fy - iy);
+                        dx = (resize_t)(fx - ix);
                     }
                     else
                     {
-                        dy = y - (iy * scale_y_n);
-                        dx = x - (ix * scale_x_n);
+                        dy = (resize_t)(y - (iy * scale_y_n));
+                        dx = (resize_t)(x - (ix * scale_x_n));
                     }
 
                     int32_t iy0 = MAX(iy, 0);
@@ -190,6 +190,15 @@
                             acc += (OutEigenType)v10 * dy * (1.0 - dx);
                             acc += (OutEigenType)v11 * dy * dx;
                         }
+                        else if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
+                        {
+                            Eigen::bfloat16 bf16_acc;
+                            bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx);
+                            bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx;
+                            bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx);
+                            bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx;
+                            acc = (float)bf16_acc;
+                        }
                         else
                         {
                             acc = (OutEigenType)v00 * (scale_y_n - dy) * (scale_x_n - dx);
@@ -201,7 +210,7 @@
                     else
                     {
                         ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode");
-                        if (std::is_floating_point<resize_t>::value)
+                        if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
                         {
                             iy = (dy >= 0.5) ? iy1 : iy0;
                             ix = (dx >= 0.5) ? ix1 : ix0;
@@ -213,6 +222,9 @@
                         }
                         acc = in->getTensor()(b, iy, ix, c);
                     }
+                    if ((typeid(resize_t) == typeid(Eigen::bfloat16))) {
+                        ASSERT_MSG(checkValidBFloat(acc), "Resize accumulator float value is not a valid bfloat16 value.");
+                    }
                     out->getTensor()(b, oy, ox, c) = acc;
                 }
 
@@ -225,4 +237,5 @@
 DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
 DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
 DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
 DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 1ff8229..0121ccf 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -49,6 +49,7 @@
         // tensor_ops
         case Op_ARGMAX:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
@@ -56,6 +57,7 @@
         case Op_AVG_POOL2D:
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32);
+            DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, BF16, FP32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
@@ -63,6 +65,7 @@
         case Op_CONV2D:
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32);
+            DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
@@ -71,6 +74,7 @@
         case Op_CONV3D:
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32);
+            DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
@@ -79,6 +83,7 @@
         case Op_DEPTHWISE_CONV2D:
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32);
+            DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
@@ -87,6 +92,7 @@
         case Op_FULLY_CONNECTED:
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32);
+            DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
@@ -95,12 +101,14 @@
         case Op_MATMUL:
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32);
+            DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
             break;
         case Op_MAX_POOL2D:
             DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
+            DEF_FACTORY_ONE_TYPE(OpMaxPool2d, BF16);
             DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
             DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
             DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
@@ -108,6 +116,7 @@
         case Op_TRANSPOSE_CONV2D:
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
+            DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
             DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
@@ -117,22 +126,26 @@
         // activation_funcs
         case Op_CLAMP:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
             break;
         case Op_SIGMOID:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
             break;
         case Op_TANH:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
             break;
 
         // ewise_binary
         case Op_ADD:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
             break;
@@ -180,16 +193,19 @@
             break;
         case Op_MAXIMUM:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
             break;
         case Op_MINIMUM:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
             break;
         case Op_MUL:
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
@@ -197,10 +213,12 @@
             break;
         case Op_POW:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
             break;
         case Op_SUB:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
             break;
@@ -212,6 +230,7 @@
         // ewise_unary
         case Op_ABS:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
             break;
@@ -222,6 +241,7 @@
             break;
         case Op_CEIL:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
             break;
         case Op_CLZ:
@@ -229,14 +249,17 @@
             break;
         case Op_EXP:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
             break;
         case Op_FLOOR:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
             break;
         case Op_LOG:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
             break;
         case Op_LOGICAL_NOT:
@@ -244,6 +267,7 @@
             break;
         case Op_NEGATE:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
@@ -251,16 +275,19 @@
             break;
         case Op_RECIPROCAL:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
             break;
         case Op_RSQRT:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
             break;
 
         // ewise_ternary
         case Op_SELECT:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
@@ -271,16 +298,19 @@
         // comparison
         case Op_EQUAL:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
             break;
         case Op_GREATER:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
             break;
         case Op_GREATER_EQUAL:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
             break;
@@ -294,6 +324,7 @@
             break;
         case Op_REDUCE_MAX:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
@@ -301,6 +332,7 @@
             break;
         case Op_REDUCE_MIN:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
@@ -308,10 +340,12 @@
             break;
         case Op_REDUCE_PRODUCT:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
             break;
         case Op_REDUCE_SUM:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
             break;
@@ -319,6 +353,7 @@
         // data layout
         case Op_CONCAT:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
@@ -327,6 +362,7 @@
             break;
         case Op_PAD:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
@@ -335,6 +371,7 @@
             break;
         case Op_RESHAPE:
             DEF_FACTORY_RESHAPE(OpReshape, FP16);
+            DEF_FACTORY_RESHAPE(OpReshape, BF16);
             DEF_FACTORY_RESHAPE(OpReshape, FP32);
             DEF_FACTORY_RESHAPE(OpReshape, INT8);
             DEF_FACTORY_RESHAPE(OpReshape, INT16);
@@ -343,6 +380,7 @@
             break;
         case Op_REVERSE:
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
+            DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
             DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -351,6 +389,7 @@
             break;
         case Op_SLICE:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -359,6 +398,7 @@
             break;
         case Op_TILE:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -368,6 +408,7 @@
         case Op_TRANSPOSE:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
@@ -380,6 +421,7 @@
             DEF_FACTORY_ONE_TYPE(OpGather, INT16);
             DEF_FACTORY_ONE_TYPE(OpGather, INT32);
             DEF_FACTORY_ONE_TYPE(OpGather, FP16);
+            DEF_FACTORY_ONE_TYPE(OpGather, BF16);
             DEF_FACTORY_ONE_TYPE(OpGather, FP32);
             break;
         case Op_SCATTER:
@@ -387,6 +429,7 @@
             DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
             DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
             DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
+            DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
             DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
             break;
 
@@ -397,6 +440,7 @@
             DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48);
             DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16);
             DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
+            DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16);
             DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
             break;
 
@@ -405,6 +449,7 @@
             return new OpConst(sgt, id);
         case Op_IDENTITY:
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
@@ -435,6 +480,9 @@
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
+            DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
             DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index b525e69..f399bd1 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -111,6 +111,12 @@
         return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id);                                      \
     }
 
+#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2)                                                           \
+    if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2)                                                 \
+    { /* BF16 variant: third template argument of OP is Eigen::bfloat16 */                                             \
+        return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id);                            \
+    }
+
 #define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2)                                                          \
     if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2)                                                 \
     {                                                                                                                  \
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index eccba09..cd9d55f 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -80,10 +80,30 @@
     return 0;
 }
 
+// These two reducers work around a bug introduced in Eigen between 3.3.7 and 3.4.0.
+// The built-in .any() and .all() operations now fail on an assert in TensorMorphing.h:150,
+// which seems to be due to incorrect data being passed internally as m_impl.
+struct AllReducer {    // Logical-AND accumulator handed to Tensor::reduce() by OpReduceAll::eval()
+    static const bool PacketAccess = false;    // no packet-wise reduce is defined in this struct
+    void reduce(const bool element, bool* running) {
+        if (!element) { *running = false; }    // same outcome as *running = *running && element
+    }
+    bool initialize() const { return true; }    // starting value: true, the identity for AND
+    bool finalize(const bool running) const { return running; }
+};
+struct AnyReducer {    // Logical-OR accumulator handed to Tensor::reduce() by OpReduceAny::eval()
+    static const bool PacketAccess = false;    // no packet-wise reduce is defined in this struct
+    void reduce(const bool element, bool* running) {
+        if (element) { *running = true; }    // same outcome as *running = *running || element
+    }
+    bool initialize() const { return false; }    // starting value: false, the identity for OR
+    bool finalize(const bool running) const { return running; }
+};
+
 template <int Rank, DType Dtype>
 int OpReduceAll<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+    this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());    // AllReducer stands in for Eigen's .all(); see the workaround note above AllReducer
 
     return GraphNode::eval();
 }
@@ -91,7 +111,7 @@
 template <int Rank, DType Dtype>
 int OpReduceAny<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+    this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());    // AnyReducer stands in for Eigen's .any(); see the workaround note above AllReducer
 
     return GraphNode::eval();
 }
@@ -115,7 +135,16 @@
 template <int Rank, DType Dtype>
 int OpReduceProduct<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+    // FP16 and BF16 products are additionally passed, element by element, through fpTrunc<Dtype>.
+    if (Dtype == DType_FP16 || Dtype == DType_BF16)
+    {
+        this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float value) { return fpTrunc<Dtype>(value); });
+    }
+    else
+    {
+        // Every other Dtype stores the reduced product exactly as computed.
+        this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+    }
 
     return GraphNode::eval();
 }
@@ -123,7 +152,16 @@
 template <int Rank, DType Dtype>
 int OpReduceSum<Rank, Dtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+    // FP16 and BF16 sums are additionally passed, element by element, through fpTrunc<Dtype>.
+    if (Dtype == DType_FP16 || Dtype == DType_BF16)
+    {
+        this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float value) { return fpTrunc<Dtype>(value); });
+    }
+    else
+    {
+        // Every other Dtype stores the reduced sum exactly as computed.
+        this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+    }
 
     return GraphNode::eval();
 }
@@ -159,20 +197,24 @@
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index b6c4043..bcd8ce5 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -227,10 +227,12 @@
 DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
 DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
 DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
 DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
 
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 3de4899..647ca84 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -19,6 +19,8 @@
 #include "tosa_generated.h"
 #include <Eigen/CXX11/Tensor>
 #include "half.hpp"
+#include <Eigen/Core>
+#include "arith_util.h"
 
 using namespace tosa;
 
@@ -76,6 +78,12 @@
     using type = float;
 };
 template <>
+struct GetEigenType<DType_BF16>
+{
+    // NOTE: full precision used - DType_BF16 maps to a plain float element type, not a 16-bit type
+    using type = float;
+};
+template <>
 struct GetEigenType<DType_INT32>
 {
     using type = int32_t;
@@ -132,12 +140,6 @@
     using type = typename GetEigenType<Dtype>::type;
 };
 
-template <DType Dtype>
-struct GetHalfEigenType
-{
-    using type = half_float::half;
-};
-
 // Meta function to get number of bits
 template <DType T>
 struct GetNumBits
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7db5182..b9ac94a 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -507,12 +507,13 @@
     Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
     Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
 
+    ETensor2<int32_t> dm2_w = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
+    ETensor2<int32_t> dm2_h = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
     ETensor4<int32_t> div_map =
-        div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
-            .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
+        dm2_h.contract(dm2_w, contract_dims)
             .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
             .broadcast(bcast);
-    if (Dtype != DType_FP32 && Dtype != DType_FP16)
+    if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
     {
         try
         {
@@ -533,7 +534,7 @@
     }
     else
     {
-        // Case for float-type resizes
+        // Case for float-types
         this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
     }
 
@@ -1679,12 +1680,14 @@
 
 // template explicit instantiation
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
 
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
@@ -1692,6 +1695,7 @@
                                           // [in_t, weight_t, acc_t]
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
@@ -1699,6 +1703,7 @@
 
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
@@ -1706,6 +1711,7 @@
 
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
@@ -1713,6 +1719,7 @@
 
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
@@ -1722,15 +1729,18 @@
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32);
 DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32);
 
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
 
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index f51c38c..e30c7bd 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -353,6 +353,9 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index ae216d8..112e641 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -15,6 +15,7 @@
 
 #include "subgraph_traverser.h"
 #include "tosa_model_types.h"
+#include "arith_util.h"
 
 #ifndef SUBGRAPH_ERROR_IF
 #define SUBGRAPH_ERROR_IF(COND, fmt, ...)                                                                              \
@@ -403,6 +404,16 @@
                     tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
                 }
                 break;
+                case DType_BF16:
+                {
+                    std::vector<float> fp32_data;
+                    TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+                    // Ensure valid bfloat16 stored in each float
+                    for (auto f : fp32_data)
+                        ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
+                    tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+                }
+                break;
                 case DType_FP32:
                 {
                     std::vector<float> fp32_data;
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 8d192ca..4eaf21d 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -90,10 +90,12 @@
     int64_t* i64databuf = nullptr;
     bool* bdatabuf      = nullptr;
     NumpyUtilities::NPError nperror;
+    DType dtype = getDtype();
 
-    switch (getDtype())
+    switch (dtype)
     {
         case DType_FP32:
+        case DType_BF16:
             fdatabuf = (float*)calloc(sizeof(float), elements);
             ASSERT_MEM(fdatabuf);
 
@@ -154,19 +156,38 @@
             FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
     }
 
-    switch (getDtype())
+    switch (dtype)
     {
         case DType_FP16:
             // Convert from fp16 to fp32
+            //TODO(jw): remove this once we cast to fp16 in register_fcn/eval
             for (uint32_t i=0; i < elements; i++) {
                 fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
             }
-            // Fall through to DType_FP32 case
+            if (setTensorValueFloat(elements, fdatabuf))
+            {
+                free(f16databuf);
+                free(fdatabuf);
+                return 1;
+            }
+            break;
+        case DType_BF16:
+            for (uint32_t i=0; i < elements; i++)
+            {
+                ASSERT_MSG(
+                    checkValidBFloat(fdatabuf[i]),
+                    "Input float value not a valid bfloat16 value."
+                );
+            }
+            if (setTensorValueFloat(elements, fdatabuf))
+            {
+                free(fdatabuf);
+                return 1;
+            }
+            break;
         case DType_FP32:
             if (setTensorValueFloat(elements, fdatabuf))
             {
-                if (f16databuf)
-                    free(f16databuf);
                 free(fdatabuf);
                 return 1;
             }
@@ -226,10 +247,12 @@
     bool* bdatabuf      = nullptr;
     NumpyUtilities::NPError nperror;
     uint32_t elements = getElementCount();
+    DType dtype = getDtype();
 
-    switch (getDtype())
+    switch (dtype)
     {
         case DType_FP32:
+        case DType_BF16:
             fdatabuf = (float*)calloc(sizeof(float), elements);
             ASSERT_MEM(fdatabuf);
 
@@ -238,7 +261,6 @@
                 free(fdatabuf);
                 return 1;
             }
-
             nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
 
             free(fdatabuf);
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 4efbf84..a3ce4bb 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -646,6 +646,7 @@
         {
             case DType_FP32:
             case DType_FP16:
+            case DType_BF16:
                 switch (rank)
                 {
                     case 0: