Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work-
arounds for reduce.any() and reduce.all() bugs (introduced
between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods
Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index f240aa5..5b78a4f 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -28,6 +28,7 @@
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
@@ -45,6 +46,7 @@
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
@@ -62,6 +64,7 @@
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
@@ -75,13 +78,16 @@
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);