Add BF16 support to reference model

* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add
  workarounds for reduce.any() and reduce.all() bugs (introduced
  between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods

Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 8ae3218..b7a76b6 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 
 import numpy as np
+from generator.tosa_utils import float32_is_valid_bfloat16
 
 ##################################
 color_printing = True
@@ -63,7 +64,12 @@
 
 
 def test_check(
-    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
+    reference,
+    result,
+    test_name="test",
+    quantize_tolerance=0,
+    float_tolerance=1e-3,
+    misc_checks=[],
 ):
     """Check if the result is the same as the expected reference."""
     if not os.path.isfile(reference):
@@ -111,6 +117,20 @@
         )
         return (TestResult.MISMATCH, 0.0, msg)
 
+    # Perform miscellaneous checks requested by the caller
+    if "bf16" in misc_checks:
+        # Every value of both tensors must be exactly representable as bfloat16
+        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
+        ref_res_is_bf16 = all(
+            float32_is_valid_bfloat16(f) for f in reference_result.flat
+        )
+        if not (test_res_is_bf16 and ref_res_is_bf16):
+            msg = (
+                "All output values must be valid bfloat16. "
+                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
+            )
+            return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
     # for quantized test, allow +-(quantize_tolerance) error
     if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
 
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0203513..932ad55 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@
             ), "Op.MUL must have 2 placeholders, 0 consts"
 
             tens = []
-            if dtypeList[0] in (DType.FP16, DType.FP32):
+            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                 tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
             else:
                 placeholders = []
@@ -1130,6 +1130,8 @@
             accum_dtypes = [DType.INT48]
         elif dtype == DType.FP16:
             accum_dtypes = [DType.FP16, DType.FP32]
+        elif dtype == DType.BF16:
+            accum_dtypes = [DType.FP32]
         elif dtype == DType.FP32:
             accum_dtypes = [DType.FP32]
         elif error_name is None:
@@ -1304,7 +1306,7 @@
             accum_dtypes = [DType.INT32]
         elif dtype == DType.FP16:
             accum_dtypes = [DType.FP16, DType.FP32]
-        elif dtype == DType.FP32:
+        elif dtype == DType.BF16 or dtype == DType.FP32:
             accum_dtypes = [DType.FP32]
         elif error_name is None:
             assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
@@ -1417,6 +1419,8 @@
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FP16:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+        elif inDtype == DType.BF16:
+            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FP32:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif error_name == ErrorIf.WrongInputType:
@@ -1826,6 +1830,8 @@
                 outputDTypeList = [DType.INT48]
             elif dtype == DType.FP16:
                 outputDTypeList = [DType.FP16]
+            elif dtype == DType.BF16:
+                outputDTypeList = [DType.BF16]
             elif dtype == DType.FP32:
                 outputDTypeList = [DType.FP32]
             elif error_name == ErrorIf.WrongInputType:
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index abe1a97..a850699 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -158,6 +158,15 @@
                     DType.INT48,
                     DType.FP32,
                 )
+            elif dtype == DType.BF16:
+                incorrect_types = (
+                    DType.INT4,
+                    DType.INT8,
+                    DType.INT16,
+                    DType.INT32,
+                    DType.INT48,
+                    DType.FP32,
+                )
             elif dtype == DType.FP32:
                 incorrect_types = (
                     DType.INT4,
@@ -299,8 +308,8 @@
 
     @staticmethod
     def eiCastErrorIf(testGen, input_dtype):
-        if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
-            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
+        if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
+            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
         elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
             outputDType = [DType.INT48]
         else:
@@ -425,6 +434,7 @@
                         and output_dtype != DType.INT48
                     )
                     or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
+                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                     or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                 ):
                     error_result = True
@@ -442,25 +452,29 @@
                         input_dtype == DType.FP16
                         and output_dtype not in (DType.FP16, DType.FP32)
                     )
+                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                     or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                 ):
                     error_result = True
 
             elif op["op"] == Op.ARGMAX:
                 if (
-                    input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+                    input_dtype
+                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                     and output_dtype != DType.INT32
                 ):
                     error_result = True
 
             elif op["op"] == Op.MUL:
                 if (
-                    input_dtype not in (DType.FP16, DType.FP32)
+                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                     and output_dtype != DType.INT32
                 ):
                     error_result = True
                 elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                     error_result = True
+                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
+                    error_result = True
                 elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                     error_result = True
 
@@ -489,6 +503,7 @@
                             DType.INT32,
                             DType.FP32,
                             DType.FP16,
+                            DType.BF16,
                         ]
                     )
                     or (
@@ -500,6 +515,7 @@
                             DType.INT32,
                             DType.FP32,
                             DType.FP16,
+                            DType.BF16,
                         ]
                     )
                     or (
@@ -511,6 +527,7 @@
                             DType.INT16,
                             DType.FP32,
                             DType.FP16,
+                            DType.BF16,
                         ]
                     )
                     or (
@@ -518,6 +535,10 @@
                         and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                     )
                     or (
+                        input_dtype == DType.BF16
+                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+                    )
+                    or (
                         input_dtype == DType.FP32
                         and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                     )
@@ -537,6 +558,8 @@
                     and output_dtype != DType.INT48
                     or input_dtype == DType.FP16
                     and output_dtype not in (DType.FP16, DType.FP32)
+                    or input_dtype == DType.BF16
+                    and output_dtype != DType.FP32
                     or input_dtype == DType.FP32
                     and output_dtype != DType.FP32
                 ):
@@ -2316,12 +2339,14 @@
                 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
+                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
             )
         elif mode == ResizeMode.NEAREST:
             # Invalid output data type / Invalid input datatype
             return (input_dtype != output_dtype) or (
-                input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+                input_dtype
+                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
             )
         else:
             # Invalid resize mode
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 78d86cd..95e06ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -16,6 +16,7 @@
 from generator.tosa_utils import DTYPE_ATTRIBUTES
 from generator.tosa_utils import MAX_RESIZE_DIMENSION
 from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import vect_f32_to_bf16
 from tosa.DType import DType
 from tosa.Op import Op
 
@@ -84,6 +85,10 @@
             )
         elif dtype == DType.FP16:
             return np.float16(self.rng.random(size=shape))
+        elif dtype == DType.BF16:
+            f32_tensor = np.float32(self.rng.random(size=shape))
+            # Zero the low 16 bits of each f32 value (truncate to bf16 precision)
+            return np.float32(vect_f32_to_bf16(f32_tensor))
         elif dtype == DType.FP32:
             return np.float32(self.rng.random(size=shape))
         else:
@@ -134,6 +139,9 @@
         elif dtype == DType.FP16:
             rand_f32 = self.rng.random()
             return np.float16(rand_f32)
+        elif dtype == DType.BF16:
+            rand_f32 = self.rng.random()
+            return vect_f32_to_bf16(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         # TOSA specific INT4 weight range from -7 to 7
@@ -324,7 +332,7 @@
 
         # Special for multiply:
         # Force the result to INT32 for INT types
-        if a.dtype not in (DType.FP16, DType.FP32):
+        if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
             result_tens.setDtype(DType.INT32)
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1043,7 +1051,7 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        if a.dtype in (DType.FP16, DType.FP32):
+        if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
             attr.ClampAttribute(0, 0, min_val, max_val)
         else:
             attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1859,7 +1867,7 @@
             op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
         )
 
-        if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
+        if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
             then_op, else_op = Op.ADD, Op.SUB
         elif a.dtype in (DType.INT8, DType.INT16):
             then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2398,7 +2406,7 @@
     #    if not specified, defaults to (1, 4)
     #  'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
     #  'types': array of datatypes to be tested
-    TYPE_FP = [DType.FP32, DType.FP16]
+    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
 
     TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
     TYPE_INT_FP = [
@@ -2406,13 +2414,20 @@
         DType.INT16,
         DType.INT32,
         DType.FP16,
+        DType.BF16,
         DType.FP32,
     ]  # Excludes INT4
 
     TYPE_BOOL = [DType.BOOL]
-    TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32]  # floating-types and INT32
+    TYPE_FI32 = [
+        DType.FP32,
+        DType.FP16,
+        DType.BF16,
+        DType.INT32,
+    ]  # floating-types and INT32
     TYPE_FIB = [
         DType.FP16,
+        DType.BF16,
         DType.FP32,
         DType.INT8,
         DType.INT16,
@@ -2421,7 +2436,7 @@
     ]
     TYPE_FI16 = [DType.FP32, DType.INT16]
 
-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
 
     # List of [Input Type 1, Input Type 2, Accumulator Type]
     TYPE_CONV = [
@@ -2430,6 +2445,7 @@
         [DType.INT16, DType.INT8, DType.INT48],
         [DType.FP16, DType.FP16, DType.FP16],
         [DType.FP16, DType.FP16, DType.FP32],
+        [DType.BF16, DType.BF16, DType.FP32],
         [DType.FP32, DType.FP32, DType.FP32],
     ]
 
@@ -3448,7 +3464,7 @@
                 TosaTensorValuesGen.tvgReduceSum,
                 TosaArgGen.agAxis,
             ),
-            "types": (DType.FP16, DType.FP32, DType.INT32),
+            "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -3635,7 +3651,14 @@
                 TosaTensorValuesGen.tvgDefault,
                 None,
             ),
-            "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
+            "types": (
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP32,
+            ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -3676,7 +3699,7 @@
                 TosaTensorValuesGen.tvgDefault,
                 TosaArgGen.agResize,
             ),
-            "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
+            "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
             "invalid_test_validators": (
                 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
             ),
@@ -3712,6 +3735,7 @@
             ),
             "types": (
                 DType.FP16,
+                DType.BF16,
                 DType.FP32,
                 DType.INT8,
                 DType.INT16,
@@ -3842,6 +3866,8 @@
                 DType.INT16,
                 DType.INT32,
                 DType.INT48,
+                DType.FP16,
+                DType.BF16,
                 DType.FP32,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -3872,6 +3898,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -3900,6 +3928,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -3929,6 +3959,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             outputDType = rng.choice(wrong_dtypes)
         else:
@@ -3955,6 +3987,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -3987,6 +4021,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4189,6 +4225,7 @@
                 DType.INT48,
                 DType.FP32,
                 DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4226,6 +4263,8 @@
                     DType.INT16,
                     DType.INT48,
                     DType.FP32,
+                    DType.FP16,
+                    DType.BF16,
                 )
             elif a.dtype == DType.INT16:
                 incorrect_types = (
@@ -4234,8 +4273,12 @@
                     DType.INT16,
                     DType.INT32,
                     DType.FP32,
+                    DType.FP16,
+                    DType.BF16,
                 )
-            elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
+            elif (
+                a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
+            ):
                 incorrect_types = (
                     DType.INT4,
                     DType.INT8,
@@ -4278,6 +4321,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             }
             wrong_dtypes = list(all_dtypes - set([input1.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4306,6 +4351,7 @@
                 DType.INT48,
                 DType.FP32,
                 DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4329,6 +4375,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4347,6 +4395,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4383,6 +4433,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4411,6 +4463,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4435,6 +4489,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4462,6 +4518,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4483,6 +4541,8 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes.remove(output_dtype)
             output_dtype = rng.choice(wrong_dtypes)
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 104d9bb..d79ab3c 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -1,5 +1,9 @@
 # Copyright (c) 2021-2022, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import struct
+import sys
+
+import numpy as np
 from tosa.DType import DType
 
 # Maximum dimension size for output and inputs for RESIZE
@@ -15,6 +19,7 @@
     DType.INT32: {"str": "i32", "width": 32},
     DType.INT48: {"str": "i48", "width": 48},
     DType.FP16: {"str": "f16", "width": 16},
+    DType.BF16: {"str": "bf16", "width": 16},
     DType.FP32: {"str": "f32", "width": 32},
 }
 
@@ -125,7 +130,11 @@
                 DType.FP32,
                 DType.FP16,
             )
-        elif input_dtype == DType.FP32 or input_dtype == DType.FP16:
+        elif (
+            input_dtype == DType.FP32
+            or input_dtype == DType.FP16
+            or input_dtype == DType.BF16
+        ):
             incorrect_types = (
                 DType.INT4,
                 DType.INT8,
@@ -134,3 +143,37 @@
                 DType.INT48,
             )
     return rng.choice(a=incorrect_types)
+
+
+def float32_is_valid_bfloat16(f):
+    """Return True when the low 16 bits of f's float32 encoding are zero."""
+    low_half = get_float32_bitstring(f)[16:]
+    return int(low_half, 2) == 0
+
+
+def get_float32_bitstring(f):
+    """Give the 32 bits of f's float32 encoding as a string, MSB first."""
+    packed = struct.pack(">f", f)
+    return format(int.from_bytes(packed, byteorder="big"), "032b")
+
+
+def float32_to_bfloat16(f):
+    """Truncate an fp32 value to bfloat16 precision.
+
+    Keeps the 16 most significant bits of the float32 encoding and
+    clears the 16 least significant ones.
+    The bit manipulation is done on the big-endian bit string, so it
+    does not depend on the host byte order.
+    Returns the bfloat16-representable value as a Python float.
+    """
+    upper_half = get_float32_bitstring(f)[:16]
+    truncated = int(upper_half, 2) << 16
+
+    # Pack the integer and unpack the float with the same native byte order
+    raw = struct.pack("=L", truncated)
+    return struct.unpack("=f", raw)[0]
+
+
+vect_f32_to_bf16 = np.vectorize(
+    float32_to_bfloat16, otypes=(np.float32,)
+)  # Element-wise wrapper; np.vectorize is a convenience loop, not a speed-up
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 2fafacb..ab78b1a 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -5,6 +5,7 @@
 
 from generator.tosa_test_gen import TosaTestGen
 from serializer.tosa_serializer import dtype_str_to_val
+from serializer.tosa_serializer import DTypeNames
 
 
 # Used for parsing a comma-separated list of integers in a string
@@ -150,13 +151,14 @@
         help="Create tests with a particular input tensor rank",
     )
 
+    # Select the DType(s) to generate tests for; the option may be repeated
     parser.add_argument(
         "--target-dtype",
         dest="target_dtypes",
         action="append",
         default=None,
         type=lambda x: dtype_str_to_val(x),
-        help="Create test with a particular DType (may be repeated)",
+        help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
     )
 
     parser.add_argument(
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index b608fd8..50ff1ab 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -47,6 +47,7 @@
     "int32": "i32",
     "fp32": "f32",
     "fp16": "f16",
+    "bf16": "bf16",
 }
 
 
@@ -127,11 +128,13 @@
     ("abs", "int32", 1),
     ("abs", "fp32", 1),
     ("abs", "fp16", 1),
+    ("abs", "bf16", 1),
     ("negate", "int8", 1),
     ("negate", "int16", 1),
     ("negate", "int32", 1),
     ("negate", "fp32", 1),
     ("negate", "fp16", 1),
+    ("negate", "bf16", 1),
     # One test per axis (shape dimensions)
     ("concat", "bool", SHAPE_DIMS),
     ("concat", "int8", SHAPE_DIMS),
@@ -139,6 +142,7 @@
     ("concat", "int32", SHAPE_DIMS),
     ("concat", "fp32", SHAPE_DIMS),
     ("concat", "fp16", SHAPE_DIMS),
+    ("concat", "bf16", SHAPE_DIMS),
 ]
 
 
@@ -165,6 +169,9 @@
     # Generate TOSA test(s) (mostly should be single test)
     test_dirs = tosaTest.create_test()
 
+    # Indicate miscellaneous checks to run in tosa_check
+    misc_checks = []
+
     for test_dir in test_dirs:
         # Run ref model
         desc_file = test_dir / TEST_DESC_FILENAME
@@ -227,8 +234,15 @@
         np.save(str(result_file), result)
         assert result_file.is_file()
 
+        # Ensure valid bf16 (register the check once, not once per test dir)
+        if tosaTest.ref_model_type == "bf16" and "bf16" not in misc_checks:
+            misc_checks.append("bf16")
+
         # Check Numpy result versus refmodel
         check_result, tolerance, msg = tosa_check(
-            str(result_file), str(ofm_file), test_name=test_dir.name
+            str(result_file),
+            str(ofm_file),
+            test_name=test_dir.name,
+            misc_checks=misc_checks,
         )
         assert check_result == TosaResult.PASS