Add framework unit test generation scripts

And fixes in tosa_verif_run_tests:
* support for no-color printing
* stop double printing of error messages on verbose
* differentiate result code pass from results check

Change-Id: I26e957013a8d18f7d3d3691067dfb778008a1eea
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
diff --git a/README.md b/README.md
index d4aa5f9..97fef58 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,7 @@
 
 ## TOSA Unit Test Infrastructure
 
-The TOSA Unit Test infrastruture builds and runs self-contained tests
+The TOSA Unit Test infrastructure builds and runs self-contained tests
 for implementations of the *Tensor Operator Set Architecture (TOSA)
 Specification*.  These tools directly generate TOSA operators for
 verification of the TOSA reference model against existing frameworks
@@ -374,6 +374,45 @@
 `tosa_mock_sut_run.py` file.
 
 
+### TOSA Framework Unit Tests
+
+Included in the TOSA Unit Test infrastructure are scripts to enable the creation
+of TOSA unit tests for example frameworks. Included at the moment is support for
+TensorFlow and TensorFlow Lite.
+
+#### Setup
+
+Installation (via `pip install`) of the following python package is required to
+generate the tests:
+
+* `tensorflow`
+
+A built copy of the tensorflow framework from source is required to compile the
+tests to TOSA - see the online documentation <https://www.tensorflow.org/install/source>
+on how to do this.
+The following tools are used from this build:
+
+* `tensorflow/basel-bin/tensorflow/compiler/mlir/lite/flatbuffer_translate`
+* `tensorflow/basel-bin/tensorflow/compiler/mlir/tf-opt`
+
+#### Usage
+
+The command to generate the unit test framework models:
+
+```bash
+tosa_verif_framework_generator -o tests
+```
+
+Next to convert these models to TOSA and then run them on the reference model:
+
+```bash
+tosa_verif_framework_compiler_runner \
+  --tf-base-dir tensorflow           \
+  --tools-base-dir reference_model   \
+  --recursive                        \
+  --test tests
+```
+
 ## Other tools
 
 Included in this repository are some support utilities used by the test runner:
diff --git a/setup.cfg b/setup.cfg
index c1a0ccb..7a6026c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,6 +26,7 @@
     runner
     generator
     checker
+    frameworks
     xunit
     json2fbbin
     json2numpy
@@ -50,6 +51,8 @@
     json2fbbin = json2fbbin.json2fbbin:main
     tosa_verif_result_check = checker.tosa_result_checker:main
     convert2conformance = convert2conformance.convert2conformance:main
+    tosa_verif_framework_generator = frameworks.tosa_verif_framework_generator:main
+    tosa_verif_framework_compiler_runner = frameworks.tosa_verif_framework_compiler_runner:main
 
 [tool:pytest]
 testpaths=verif/tests
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 3a15de9..66864c2 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -11,7 +11,7 @@
 import numpy as np
 
 ##################################
-no_color_printing = False
+color_printing = True
 
 
 @unique
@@ -25,9 +25,16 @@
     BOLD_WHITE = "\u001b[1m"
 
 
+def set_print_in_color(enabled):
+    """Set color printing to enabled or disabled."""
+    global color_printing
+    color_printing = enabled
+
+
 def print_color(color, msg):
     """Print color status messages if enabled."""
-    if no_color_printing:
+    global color_printing
+    if not color_printing:
         print(msg)
     else:
         print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
diff --git a/verif/frameworks/__init__.py b/verif/frameworks/__init__.py
new file mode 100644
index 0000000..8792170
--- /dev/null
+++ b/verif/frameworks/__init__.py
@@ -0,0 +1,3 @@
+"""Namespace."""
+# Copyright (c) 2022 Arm Limited.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
new file mode 100644
index 0000000..8feb9b2
--- /dev/null
+++ b/verif/frameworks/arg_gen.py
@@ -0,0 +1,703 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+
+
+class ArgGen:
+    """Argument generator functions.  These functions take a shape and dtype to
+    create arguments for an operator. Methods are prefixed with 'ag' to make
+    search easy."""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def agNone(op, shapes, rng):
+        """A trivial argument generator for operators that only take tensor
+        operands"""
+        return [("", [])]
+
+    # Build the axis argument for operators where we want to iterate over N axes
+    # as an argument
+    @staticmethod
+    def agAxes(op, shapes, rng):
+        axes = []
+        for i in range(-len(shapes), len(shapes), 1):
+            if i >= 0:
+                axes.append(["_axis_{}".format(i), [i]])
+            else:
+                axes.append(["_axis_m{}".format(-i), [i]])
+        return axes
+
+    # Build the axis LIST argument for operators that take an axis list.
+    # This builds a list of each axis individually, plus one element
+    # that contains a list of all axes.  Note that we need to pack the list in
+    # an additional list so that it isn't exploded when being passed to the
+    # build_operator function.
+    # tensor_arg_count not used
+    def agAxesList(op, shapes, rng):
+        axes = ArgGen.agAxes(op, shapes, rng)
+        axes_list = []
+        for desc, a in axes:
+            axes_list.append([desc, [a]])
+
+        axes_list.append(["_axisall", [list(range(len(shapes)))]])
+        axes_list.append(["_axisall_none", [None]])
+        return axes_list
+
+    def agAxesListKeepdims(op, shapes, rng):
+        axes = ArgGen.agAxes(op, shapes, rng)
+        axes_list = []
+        for desc, a in axes:
+            axes_list.append([desc + "_keep0", [a, False]])
+            # avoid trying to reduce an axis of shape 1, as the TFL converter
+            # will optimize away the entire reduction
+            if (a[0] >= 0 and shapes[a[0]] != 1) or (
+                a[0] < 0 and shapes[len(shapes) + a[0]] != 1
+            ):
+                axes_list.append([desc + "_keep1", [a, True]])
+
+        axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]])
+        axes_list.append(["_axisall_keep0_none", [None, False]])
+        # another instance where the reduce gets optimized out.
+        if len(shapes) != 1:
+            axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]])
+            axes_list.append(["_axisall_keep1_none", [None, True]])
+        # no longer test axis empty, as TFL converter optimizes the reduce out
+        return axes_list
+
+    # conv2d argument generators build the TF constants
+    def agConv2d(op, shapes, rng):
+        arg_list = []
+
+        # Must be rank 4
+        if len(shapes) < 4:
+            return arg_list
+
+        filter_h, filter_w = op["filter"]
+
+        # strides, padding, dilations,
+        for stride_h in [1, 2]:
+            for stride_w in [1, 2]:
+                for padding in ["SAME", "VALID"]:
+                    for dilation_h in [1, 2]:
+                        for dilation_w in [1, 2]:
+
+                            # Disqualify argument combinations that would cause
+                            # an illegal convolution
+
+                            if (padding == "VALID") and (
+                                (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
+                                or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
+                            ):
+                                continue
+
+                            arg_list.append(
+                                [
+                                    "_st{}{}_pad{}_dilat{}{}".format(
+                                        stride_h,
+                                        stride_w,
+                                        padding,
+                                        dilation_h,
+                                        dilation_w,
+                                    ),
+                                    [
+                                        [stride_h, stride_w],
+                                        padding,
+                                        [dilation_h, dilation_w],
+                                    ],
+                                ]
+                            )
+        return arg_list
+
+    # conv2d argument generators build the TF constants
+    def agDepthwiseConv2d(op, shapes, rng):
+        arg_list = []
+
+        # Must be rank 4
+        if len(shapes) < 4:
+            return arg_list
+
+        filter_h, filter_w = op["filter"]
+
+        # strides, padding, dilations, Depthwise conv2d is the same as conv2d
+        # except that strides in h/w must be the same and the argument must be
+        # formatted as [1, stride_h, stride_w, 1] in TF.
+        for stride in [1, 2]:
+            for padding in ["SAME", "VALID"]:
+                for dilation_h in [1, 2]:
+                    for dilation_w in [1, 2]:
+
+                        # Disqualify argument combinations that would cause an illegal
+                        # convolution
+
+                        if (padding == "VALID") and (
+                            (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
+                            or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
+                        ):
+                            continue
+
+                        # When dilation is used, stride must be 1x1 (TF rules)
+                        if dilation_h > 1 or dilation_w > 1:
+                            if stride > 1:
+                                continue
+
+                        # Dilation must evenly divide the tensor.  Some of our inputs
+                        # intentionally use odd-sized tensors.
+                        if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
+                            continue
+
+                        arg_list.append(
+                            [
+                                "_st{}{}_pad{}_dilat{}{}".format(
+                                    stride, stride, padding, dilation_h, dilation_w
+                                ),
+                                [
+                                    [1, stride, stride, 1],
+                                    padding,
+                                    [dilation_h, dilation_w],
+                                ],
+                            ]
+                        )
+        return arg_list
+
+    # conv2d argument generators build the TF constants
+    def agTransposeConv2d(op, shapes, rng):
+        arg_list = []
+
+        # Must be rank 4
+        if len(shapes) < 4:
+            return arg_list
+
+        filter_h, filter_w = op["filter"]
+
+        # strides, padding, dilations,
+        for stride_h in [1, 2]:
+            for stride_w in [1, 2]:
+                for padding in ["SAME", "VALID"]:
+                    if padding == "SAME":
+                        out_height = (shapes[1]) * stride_h
+                        out_width = (shapes[2]) * stride_w
+                    else:  # padding == 'VALID'
+                        out_height = (shapes[1] - 1) * stride_h + filter_h
+                        out_width = (shapes[2] - 1) * stride_w + filter_w
+
+                    output_shape = [shapes[0], out_height, out_width, shapes[3] * 2]
+                    arg_list.append(
+                        [
+                            "_st{}{}_pad{}".format(stride_h, stride_w, padding),
+                            [output_shape, [stride_h, stride_w], padding],
+                        ]
+                    )
+        return arg_list
+
+    def agPooling(op, shapes, rng):
+        arg_list = []
+
+        # Must be rank 4
+        if len(shapes) < 4:
+            return arg_list
+
+        for stride_h in [1, 2]:
+            for stride_w in [1, 2]:
+                for kernel_h in [1, 2]:
+                    for kernel_w in [1, 2]:
+                        for padding in ["SAME", "VALID"]:
+
+                            if (padding == "VALID") and (
+                                (shapes[1] % (kernel_h * stride_h) > 0)
+                                or (shapes[2] % (kernel_w * stride_w) > 0)
+                                or (shapes[1] <= kernel_h)
+                                or (shapes[2] <= kernel_w)
+                            ):
+                                continue
+
+                            if (padding == "SAME") and (
+                                (shapes[1] < kernel_h) or (shapes[2] < kernel_w)
+                            ):
+                                continue
+
+                            arg_list.append(
+                                [
+                                    "_st{}{}_pad{}_kern{}{}".format(
+                                        stride_h, stride_w, padding, kernel_h, kernel_w
+                                    ),
+                                    [
+                                        [stride_h, stride_w],
+                                        [kernel_h, kernel_w],
+                                        padding,
+                                    ],
+                                ]
+                            )
+        return arg_list
+
+    def getFactors(val, start=1):
+        factors = []
+        for i in range(start, int(np.sqrt(val))):
+            if (val % i) == 0:
+                factors.append(i)
+
+        return factors
+
+    def agReshape(op, shapes, rng):
+        # This is slow code.  Fortunately, the numbers involved are small
+        arg_list = []
+
+        total_elements = 1
+        for s in shapes:
+            total_elements *= s
+
+        # Find integer factors of this shape
+        factors = ArgGen.getFactors(total_elements)
+
+        for rank in range(1, len(shapes) + 1):
+            if len(factors) < rank:
+                break
+
+            new_shape = []
+            remaining_elements = total_elements
+
+            # Randomly shuffle the factors and iteratively pick from the factors
+            # of the remaining elements
+            shuffled_factors = rng.permutation(factors)
+            for i in range(rank):
+                # Pick rank - 1 factors
+                new_shape.append(shuffled_factors[0])
+                remaining_elements = remaining_elements // shuffled_factors[0]
+                shuffled_factors = rng.permutation(
+                    ArgGen.getFactors(remaining_elements)
+                )
+            new_shape.append(remaining_elements)
+
+            # Don't do no-op reshapes because TFLite optimizes out the op
+            if new_shape == list(shapes):
+                continue
+
+            arg_list.append(["_rank{}".format(rank), [new_shape]])
+
+        return arg_list
+
+    def agTranspose(op, shapes, rng):
+        arg_list = []
+
+        # Must have at least two dimensions to transpose
+        if (len(shapes)) < 2:
+            return arg_list
+
+        # Pick a bunch of random permutations
+        range_arr = np.arange(len(shapes))
+        for i in range(len(shapes)):
+            perm = rng.permutation(range_arr).astype(np.int32)
+            # print('\n shape {} permute{} perm: {} arr: {}'.format(shapes, i,
+            # perm, range_arr))
+            if np.allclose(perm, range_arr):
+                print("skipped")
+                continue
+            arg_list.append(["_permute{}".format(i), [perm]])
+
+        return arg_list
+
+    def agSlice(op, shapes, rng):
+        arg_list = []
+
+        rank = len(shapes)
+
+        if rank == 1 and shapes[0] == 1:
+            return arg_list
+
+        for i in range(4):
+            # Pick a few random start points, axes, and strides
+            start = np.empty((rank), dtype=int)
+            size = np.empty((rank), dtype=int)
+            for j in range(rank):
+                if shapes[j] > 2:
+                    start[j] = rng.integers(0, shapes[j] - 2)
+                    # print('j = {}: {} - {} - 1: {}'.format(j, shapes[j],
+                    # start[j], shapes[j] - start[j] - 1))
+                    size[j] = rng.integers(1, shapes[j] - start[j] - 1)
+                else:
+                    start[j] = 0
+                    size[j] = shapes[j]
+
+                arg_list.append(["_perm{}".format(i), [start, size]])
+
+        return arg_list
+
+    def agStridedSlice(op, shapes, rng):
+        arg_list = []
+
+        rank = len(shapes)
+
+        # Reference model is limited to rank=6 internally right now
+        if rank > 3:
+            return arg_list
+
+        if rank == 1 and shapes[0] == 1:
+            return arg_list
+
+        for i in range(4):
+            # Pick a few random begin points, axes, and strides
+            begin = np.empty((rank), dtype=int)
+            end = np.empty((rank), dtype=int)
+            strides = np.empty((rank), dtype=int)
+
+            begin_mask = rng.integers(0, (1 << (rank - 1)))
+            end_mask = rng.integers(0, (1 << (rank - 1)))
+
+            for j in range(rank):
+
+                if begin_mask & (1 << j) or shapes[j] < 2:
+                    begin[j] = 0
+                else:
+                    begin[j] = rng.integers(0, shapes[j] - 1)
+
+                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
+                    end[j] = shapes[j]
+                else:
+                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)
+
+                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)
+
+                if not possible_stride:
+                    strides[j] = 1
+                else:
+                    strides[j] = rng.choice(possible_stride)
+
+            # Randomly set the masks, except ellipsis_mask and new_axis_mask
+            # which must be zero for now For begin/end mask this to work,
+            # strides must be adjusted to still be divsible...
+            ellipsis_mask = 0
+            new_axis_mask = 0
+
+            # if rng.choice([0, 1]) and rank > 1:
+            #    new_axis_mask = 1 << rng.integers(0, rank - 1)
+            # else:
+            #    new_axis_mask = 0
+
+            if rng.choice([0, 1]) and rank > 1:
+                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
+            else:
+                shrink_axis_mask = 0
+
+            # Only one of these bits may be set.  Prefer shrink_axis_mask
+            new_axis_mask = new_axis_mask & ~shrink_axis_mask
+
+            arg_list.append(
+                [
+                    "_perm{}".format(i),
+                    [
+                        begin,
+                        end,
+                        strides,
+                        begin_mask,
+                        end_mask,
+                        ellipsis_mask,
+                        new_axis_mask,
+                        shrink_axis_mask,
+                    ],
+                ]
+            )
+
+            # print('Shape: {} begin={} end={} strides={} begin_mask={:x}
+            # end_mask={:x} new_axis_mask={:x} shrink_mask={:x}'.format(shapes,
+            # begin, end, strides, begin_mask, end_mask, new_axis_mask,
+            # shrink_axis_mask))
+
+        return arg_list
+
+    # tf.stack axis can be [0, rank(input)]
+    def agStack(op, shapes, rng):
+        axes = []
+        for i in range(len(shapes) + 1):
+            axes.append(["_axis{}".format(i), [i]])
+        return axes
+
+    def agPad(op, shapes, rng):
+        arg_list = []
+
+        rank = len(shapes)
+        for left in range(3):
+            for right in range(3):
+                paddings = np.zeros((rank, 2), dtype=np.int32)
+                for d in range(rank):
+                    paddings[d, 0] = left
+                    paddings[d, 1] = right
+
+                    arg_list.append(["_pad{}{}".format(left, right), [paddings]])
+        return arg_list
+
+    def agFill(op, shapes, rng):
+        values = []
+        for i in range(4):
+            value = rng.integers(0, 10, dtype=np.int32)
+            values.append(["_value{}".format(value), [shapes, value]])
+        return values
+
+    def getValuesToSum(total, rng):
+        # Get a list of random integers that sum up to 'total'
+        vals = []
+
+        # np.random.randint() min and max to be different, so if the remainder
+        # is 1, give up
+        while total > 1:
+            vals.append(rng.integers(1, total))
+            total = total - vals[-1]
+
+        if total == 1:
+            vals.append(1)
+
+        return vals
+
+    def agSplit(op, shapes, rng):
+        arg_list = []
+
+        rank = len(shapes)
+
+        # Shuffle the random number generator a few more times to get
+        # a better range of axes across shapes
+        for i in range(rank):
+            for j in range(shapes[i]):
+                rng.integers(shapes[i])
+
+        for i in range(3):
+            # Need to generate tests for both the num_splits and size_vector versions.
+            axis = rng.choice(np.arange(0, rank))
+
+            # For num_splits, get a few divisors of the given axis
+            divs = ArgGen.getFactors(shapes[axis], 2)
+
+            if divs:
+                # Get no more than 2 samples
+                splits = list(rng.choice(divs, size=2))
+
+                for s in splits:
+                    arg_list.append(
+                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
+                    )
+
+            # For vector splits, get a list of integers that sum up to the axis size
+            vals = ArgGen.getValuesToSum(shapes[axis], rng)
+
+            if len(vals) > 1:
+                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])
+
+        return arg_list
+
+    def agTile(op, shapes, rng):
+        arg_list = []
+
+        rank = len(shapes)
+
+        # create 1D multiples list
+        multiples = list()
+        for i in range(rank):
+            multiples.append(rng.integers(1, 4))
+
+        multiples_str = "x".join(list(str(i) for i in multiples))
+
+        arg_list.append(["_tile_{}".format(multiples_str), [multiples]])
+
+        return arg_list
+
+    def agGather(op, shapes, rng):
+        args = []
+        for batch_dims in range(len(shapes) - 1):
+            for axis in range(batch_dims, len(shapes)):
+                # indices value must be within [0, shapes[i])
+
+                # Create an arbitrary shape for the indices
+                indices_rank = rng.integers(batch_dims + 1, 4)
+                indices_shape = rng.integers(1, 8, size=indices_rank)
+
+                # Copy in the batch dimensions because they must match
+                for b in range(batch_dims):
+                    indices_shape[b] = shapes[b]
+
+                # Calculate total element count
+                indices_size = 1
+                for j in range(indices_rank):
+                    indices_size = indices_shape[j] * indices_size
+
+                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
+                    indices_shape
+                )
+
+                args.append(
+                    [
+                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
+                        [indices, batch_dims, axis],
+                    ]
+                )
+        return args
+
+    def agGatherND(op, shapes, rng):
+        args = []
+
+        for N in range(1, len(shapes) - 1):
+            # Rank includes the N dimension
+            indices_rank = rng.integers(2, 4, size=1)[0]
+            indices_shape = []
+
+            indices_shape = rng.integers(1, 8, size=indices_rank)
+            indices_shape[-1] = N
+
+            indices_count = 1
+            for i in range(indices_rank - 1):
+                indices_count = indices_count * indices_shape[i]
+
+            indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32)
+
+            for i in range(indices_count):
+                for j in range(N):
+                    indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0]
+
+            indices = indices_list.reshape(indices_shape)
+
+            args.append(["_n{}".format(N), [indices]])
+
+        return args
+
+    def agScatterND(op, shapes, rng):
+        args = []
+
+        # ScatterND has to generate a constant shapes tensor, indices
+        # tensor, and a tensor of updates.  Unforunately, the updates
+        # need to be a size that's based on the N generated in this
+        # function and the dtype known only in the TensorGen function,
+        # but not in ArgGen.
+        #
+        # There are many bad ways to solve this and we'll choose the
+        # least of the evils which still gives reasonable coverage of
+        # the possible operand shapes.
+        for N in range(1, len(shapes)):
+            # Rank includes the N dimension
+            indices_rank = rng.integers(2, 4, size=1)[0]
+            indices_shape = []
+
+            indices_shape = rng.integers(1, 8, size=indices_rank)
+            indices_shape[-1] = N
+
+            # Store the Shapes, and the indicies value tensor as arguments.
+            args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]])
+
+        return args
+
+    def agSpaceToBatch(op, shapes, rng):
+        batch_rank = 1
+        channel_rank = 1
+        block_rank = len(shapes) - batch_rank - channel_rank
+
+        # must have at least rank 1 (M) block
+        if block_rank < 1:
+            return []
+
+        args = []
+        block_shape = []
+        padding_shape = []
+
+        for i in range(block_rank):
+            block_size = 2
+            padding_size = block_size - (shapes[i + 1] % block_size)
+            block_shape.append(block_size)
+            padding_shape.append([0, padding_size])
+
+        args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]])
+        return args
+
+    def agBatchToSpace(op, shapes, rng):
+        batch_rank = 1
+        channel_rank = 1
+        block_rank = len(shapes) - batch_rank - channel_rank
+
+        # must have at least rank 1 (M) block
+        if block_rank < 1:
+            return []
+
+        args = []
+        block_shape = []
+        padding_shape = []
+        block_prod = 1
+
+        for i in range(block_rank):
+            block_size = 2
+            block_prod = block_prod * block_size
+            crop_size = 0
+            block_shape.append(block_size)
+            padding_shape.append([0, crop_size])
+
+        # batch / prod(block_shape[i]) must be integer
+        # transpose to swap depth and batch. so shape[-1] would be batch dim
+        if shapes[-1] % block_prod == 0:
+            args.append(
+                ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]
+            )
+
+        return args
+
+    def agSpaceToDepth(op, shapes, rng):
+        # must be rank 4 input tensor
+        if len(shapes) != 4:
+            return []
+
+        block_size = 2
+
+        # spatial dimension must be divisible by block_size
+        if shapes[1] % block_size != 0 or shapes[2] % block_size != 0:
+            return []
+
+        args = []
+        args.append(["_blocksize_{}".format(block_size), [block_size]])
+
+        return args
+
+    def agDepthToSpace(op, shapes, rng):
+        # must be rank 4 input tensor
+        if len(shapes) != 4:
+            return []
+
+        block_size = 2
+        # depth dimension must be divisible by block_size * block_size
+        if shapes[3] % (block_size * block_size) != 0:
+            return []
+
+        args = []
+        args.append(["_blocksize_{}".format(block_size), [block_size]])
+
+        return args
+
+    def agFakequant(op, shapes, rng):
+        args = []
+        for num_bits in [8, 16]:
+            for narrow in [False, True]:
+                args.append(
+                    ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]]
+                )
+
+        return args
+
+    def agShift(op, shapes, rng):
+        args = []
+
+        for shift in rng.integers(0, 32, size=8):
+            args.append(["_shift{}".format(shift), [shift]])
+
+        return args
+
+    def agFloat(op, shapes, rng):
+        args = []
+
+        i = 0
+        for alpha in np.float32(rng.random(size=2)):
+            args.append(["_{}".format(i), [alpha]])
+
+        return args
+
+    # Similar to agAxes, but tf.OneHot only allow axis from [-1, rank(input)]
+    def agOneHot(op, shapes, rng):
+        axes = []
+        for i in range(-1, len(shapes) + 1, 1):
+            if i >= 0:
+                axes.append(["_axis_{}".format(i), [i]])
+            else:
+                axes.append(["_axis_m{}".format(-i), [i]])
+        return axes
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
new file mode 100644
index 0000000..e57175b
--- /dev/null
+++ b/verif/frameworks/tensor_gen.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+import tensorflow as tf
+
+# FIXME: replace hardcoded '* 2' with random integers, where possible
+
+# The scaling factor for random numbers generated in input tensors.  The
+# random numbers are calculated as:
+# (np.random.rand() - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+# FIXME: improve range here
+RAND_SCALE_FACTOR = 4.0
+# Amount to add to random numbers
+RAND_SHIFT_FACTOR = 0.5
+
+RAND_INT_MIN = -128
+RAND_INT_MAX = 128
+
+
+class TGen:
+    """A collection of functions to build tensor value arguments for an operator"""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def getRand(shape, dtype, rng):
+        if dtype == tf.float32:
+            return np.float32(
+                (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+            )
+        if dtype == tf.float16:
+            return np.float16(
+                (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+            )
+        if dtype == tf.int32:
+            return np.int32(
+                rng.integers(low=RAND_INT_MIN, high=RAND_INT_MAX, size=shape)
+            )
+        if dtype == tf.uint32:
+            return np.uint32(rng.integers(low=0, high=RAND_INT_MAX, size=shape))
+        if dtype == tf.bool:
+            return np.bool_(rng.choice(a=[False, True], size=shape))
+
+        raise Exception("Unsupported type: {}".format(dtype))
+
+    @staticmethod
+    def tgBasic(op, shape, dtype, rng):
+        # Build random tensor placeholder node args of a given shape
+        pl, const = op["operands"]
+
+        tf_placeholders = []
+        tf_consts = []
+
+        for i in range(pl):
+            tf_placeholders.append(
+                ("placeholder_{}".format(i), TGen.getRand(shape, dtype, rng))
+            )
+
+        for i in range(const):
+            tf_consts.append(("const_{}".format(i), TGen.getRand(shape, dtype, rng)))
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgBFuzz(op, shape, dtype, rng):
+        # Build random tensor placeholder node args of a given shape, optionally
+        # fuzzing the arguments with random 1's to force broadcasting
+
+        pl, const = op["operands"]
+
+        assert const == 0
+
+        fuzz_arg = rng.integers(0, pl + const)
+        fuzz_idx = rng.integers(0, len(shape))
+
+        tf_placeholders = []
+        tf_consts = []
+        for i in range(pl):
+            if i == fuzz_arg:
+                # Insert the broadcast in one dimension index
+                s_fuzz = list(shape)
+                s_fuzz[fuzz_idx] = 1
+                s_fuzz = tuple(s_fuzz)
+                i_shape = s_fuzz
+            else:
+                i_shape = shape
+            tf_placeholders.append(
+                ("placeholder_{}".format(i), TGen.getRand(i_shape, dtype, rng))
+            )
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgConv2d(op, ifm_shape, dtype, rng):
+
+        # Take the shape and generate an input and filter
+        tf_placeholders = []
+        tf_consts = []
+
+        # Require rank 4 shape
+        if len(ifm_shape) != 4:
+            return [], []
+
+        filter_h, filter_w = op["filter"]
+
+        # TODO: Hard-code the test by making the OFM depth 2x the IFM depth.
+        # Could randomize this in the future.
+        filter_shape = (filter_h, filter_w, ifm_shape[3], ifm_shape[3] * 2)
+
+        tf_placeholders.append(("placeholder_0", TGen.getRand(ifm_shape, dtype, rng)))
+        tf_consts.append(("const_0", TGen.getRand(filter_shape, dtype, rng)))
+
+        try:
+            bias = op["bias"]
+        except KeyError:
+            bias = False
+
+        if bias:
+            # bias is 1D and size == output channels
+            bias_shape = (ifm_shape[3] * 2,)
+            tf_consts.append(("const_1", TGen.getRand(bias_shape, dtype, rng)))
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgDepthwiseConv2d(op, ifm_shape, dtype, rng):
+
+        # Take the shape and generate an input and filter
+        tf_placeholders = []
+        tf_consts = []
+
+        # Require rank 4 shape
+        if len(ifm_shape) != 4:
+            return [], []
+
+        filter_h, filter_w = op["filter"]
+
+        # TODO: Hard-code the test by making the channel_multiplier=2. Could randomize
+        # this in the future.
+        filter_shape = (filter_h, filter_w, ifm_shape[3], 2)
+
+        tf_placeholders.append(("placeholder_0", TGen.getRand(ifm_shape, dtype, rng)))
+        tf_consts.append(("const_0", TGen.getRand(filter_shape, dtype, rng)))
+
+        try:
+            bias = op["bias"]
+        except KeyError:
+            bias = False
+
+        if bias:
+            # bias is 1D and size == output channels
+            bias_shape = (ifm_shape[3] * 2,)
+            tf_consts.append(("const_1", TGen.getRand(bias_shape, dtype, rng)))
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgTransposeConv2d(op, ifm_shape, dtype, rng):
+
+        # Take the shape and generate an input and filter
+        tf_placeholders = []
+        tf_consts = []
+
+        # Require rank 4 shape
+        if len(ifm_shape) != 4:
+            return [], []
+
+        filter_h, filter_w = op["filter"]
+
+        # TODO: Hard-code the test by making the IFM depth 2x the OFM depth.
+        # Could randomize this in the future.
+        filter_shape = (filter_h, filter_w, ifm_shape[3] * 2, ifm_shape[3])
+
+        tf_placeholders.append(("placeholder_0", TGen.getRand(ifm_shape, dtype, rng)))
+        tf_consts.append(("const_0", TGen.getRand(filter_shape, dtype, rng)))
+
+        try:
+            bias = op["bias"]
+        except KeyError:
+            bias = False
+
+        if bias:
+            # bias is 1D and size == output channels
+            bias_shape = ifm_shape[3] * 2
+            tf_consts.append(("const_1", TGen.getRand(bias_shape, dtype, rng)))
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgPooling(op, shapes, dtype, rng):
+        # Pooling does nothing special except filter out non-rank-4 tensors
+        if len(shapes) != 4:
+            return [], []
+
+        return TGen.tgBasic(op, shapes, dtype, rng)
+
+    @staticmethod
+    def tgMatmul(op, ifm_shape, dtype, rng):
+        # Take the shape and generate an input and filter
+        tf_placeholders = []
+        tf_consts = []
+
+        if len(ifm_shape) < 2:
+            return [], []
+
+        # For ifm_shape = [..., N, K]
+        # Generate rhs tensor with shape [..., K x (2 * N)]
+        tf_placeholders.append(("placeholder_0", TGen.getRand(ifm_shape, dtype, rng)))
+
+        shape_rhs = list(ifm_shape)
+        shape_rhs[-2] = ifm_shape[-1]
+        shape_rhs[-1] = ifm_shape[-2] * 2
+        tf_placeholders.append(
+            (
+                "placeholder_1",
+                TGen.getRand(shape_rhs, dtype, rng),
+            )
+        )
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgOneHot(op, shape, dtype, rng):
+        # Build random tensor placeholder node args of a given shape
+        pl, const = op["operands"]
+
+        assert pl == 3 and const == 1
+
+        tf_placeholders = []
+        tf_consts = []
+
+        # depth
+        depth = np.int32(rng.integers(low=1, high=32, size=None))
+        tf_consts.append(("const_0", depth))
+
+        # indices
+        indices = np.int32(rng.integers(low=0, high=depth, size=shape))
+        tf_placeholders.append(("placeholder_0", indices))
+
+        # on_value
+        tf_placeholders.append(("placeholder_1", TGen.getRand(None, dtype, rng)))
+
+        # off_value
+        tf_placeholders.append(("placeholder_2", TGen.getRand(None, dtype, rng)))
+
+        return tf_placeholders, tf_consts
+
+    @staticmethod
+    def tgSelect(op, shape, dtype, rng):
+        # Build random tensor placeholder node args of a given shape
+        pl, const = op["operands"]
+        assert pl == 3 and const == 0
+
+        tf_placeholders = []
+        tf_consts = []
+
+        # selector
+        tf_placeholders.append(("placeholder_0", TGen.getRand(None, tf.bool, rng)))
+        # inputs
+        tf_placeholders.append(("placeholder_1", TGen.getRand(shape, dtype, rng)))
+        tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
+
+        return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
new file mode 100644
index 0000000..a47cf5c
--- /dev/null
+++ b/verif/frameworks/test_builder.py
@@ -0,0 +1,1028 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+import tensorflow as tf
+from frameworks.tensor_gen import TGen
+
+
+class TBuilder:
+    """The member functions build the tensorflow operators into small networks
+    for our tests"""
+
+    def __init__(self):
+        pass
+
+    def fake_quant(tensor, tensor_scale, name):
+        """Helper function for quantizing with a scaling parameters structure."""
+        return tf.quantization.fake_quant_with_min_max_args(
+            tensor,
+            min=tensor_scale.min,
+            max=tensor_scale.max,
+            num_bits=tensor_scale.num_bits,
+            narrow_range=tensor_scale.narrow_range,
+            name=name,
+        )
+
+    def fake_quant_params(tensor, min, max, scaling, name):
+        """Helper function for quantizing with individual scaling parameters."""
+        return tf.quantization.fake_quant_with_min_max_args(
+            tensor,
+            min=min,
+            max=max,
+            num_bits=scaling.num_bits,
+            narrow_range=scaling.narrow_range,
+            name=name,
+        )
+
+    class Add:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.add(a, b, name=self.result_name)
+
+    class Sub:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.subtract(a, b, name=self.result_name)
+
+    class Mul:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.multiply(a, b, name=self.result_name)
+
+    class Exp:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.exp(a, name=self.result_name)
+
+    class Rcp:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reciprocal(a, name=self.result_name)
+
+    class Relu:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.relu(a, name=self.result_name)
+
+    class Relu6:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.relu6(a, name=self.result_name)
+
+    class LeakyRelu:
+        def __init__(self, alpha, name):
+            self.alpha = alpha
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.leaky_relu(a, alpha=self.alpha, name=self.result_name)
+
+    class Concat:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.concat([a, b], self.axis, name=self.result_name)
+
+    class BitwiseAnd:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.bitwise.bitwise_and(a, b, name=self.result_name)
+
+    class BitwiseOr:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.bitwise.bitwise_or(a, b, name=self.result_name)
+
+    class BitwiseNot:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.bitwise.invert(a, name=self.result_name)
+
+    class BitwiseXor:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.bitwise.bitwise_xor(a, b, name=self.result_name)
+
+    class LogicalAnd:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.logical_and(a, b, name=self.result_name)
+
+    class LogicalOr:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.logical_or(a, b, name=self.result_name)
+
+    class LogicalNot:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.logical_not(a, name=self.result_name)
+
+    class ReduceAny:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_any(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class ReduceAll:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_all(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class ReduceMin:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_min(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class ReduceMax:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_max(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class ReduceSum:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_sum(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class ReduceMean:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_mean(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class ReduceProduct:
+        def __init__(self, axis_list, keepdims, name):
+            self.axis_list = axis_list
+            self.keepdims = keepdims
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.reduce_prod(
+                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
+            )
+
+    class Min:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.minimum(a, b, name=self.result_name)
+
+    class Max:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.maximum(a, b, name=self.result_name)
+
+    class Pow:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.pow(a, b, name=self.result_name)
+
+    class Abs:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.abs(a, name=self.result_name)
+
+    class Ceil:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.ceil(a, name=self.result_name)
+
+    class Floor:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.floor(a, name=self.result_name)
+
+    class Log:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.log(a, name=self.result_name)
+
+    class Negate:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.negative(a, name=self.result_name)
+
+    class Rsqrt:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.rsqrt(a, name=self.result_name)
+
+    class Sigmoid:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.sigmoid(a, name=self.result_name)
+
+    class Tanh:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.tanh(a, name=self.result_name)
+
+    class Square:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.math.square(a, name=self.result_name)
+
+    class SquaredDifference:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.squared_difference(a, b, name=self.result_name)
+
+    class Equal:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.equal(a, b, name=self.result_name)
+
+    class GreaterEqual:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.greater_equal(a, b, name=self.result_name)
+
+    class Greater:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.greater(a, b, name=self.result_name)
+
+    class Less:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.less(a, b, name=self.result_name)
+
+    class LessEqual:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.math.less_equal(a, b, name=self.result_name)
+
+    class Conv2d:
+        def __init__(self, weight, strides, padding, dilations, name):
+            self.weight = weight
+            self.strides = strides
+            self.padding = padding
+            self.dilations = dilations
+            self.result_name = name
+
+        def eval(self, input):
+            return tf.nn.conv2d(
+                input,
+                self.weight,
+                self.strides,
+                self.padding,
+                data_format="NHWC",
+                dilations=self.dilations,
+                name=self.result_name,
+            )
+
+    class Conv2dRelu:
+        def __init__(self, weight, name):
+            self.weight = weight
+            self.result_name = name
+
+        def eval(self, input):
+            conv2d = tf.nn.conv2d(
+                input,
+                self.weight,
+                [1, 1, 1, 1],
+                "SAME",
+                data_format="NHWC",
+                dilations=[1, 1, 1, 1],
+                name="conv2d",
+            )
+            return tf.nn.relu(conv2d, name=self.result_name)
+
+    class Conv2dRelu6:
+        def __init__(self, weight, name):
+            self.weight = weight
+            self.result_name = name
+
+        def eval(self, input):
+            conv2d = tf.nn.conv2d(
+                input,
+                self.weight,
+                [1, 1, 1, 1],
+                "SAME",
+                data_format="NHWC",
+                dilations=[1, 1, 1, 1],
+                name="conv2d",
+            )
+            return tf.nn.relu6(conv2d, name=self.result_name)
+
+    class Conv2dReluN1To1:
+        def __init__(self, weight, name):
+            self.weight = weight
+            self.result_name = name
+
+        def eval(self, input):
+            conv2d = tf.nn.conv2d(
+                input,
+                self.weight,
+                [1, 1, 1, 1],
+                "SAME",
+                data_format="NHWC",
+                dilations=[1, 1, 1, 1],
+                name="conv2d",
+            )
+            return tf.clip_by_value(conv2d, -1.0, 1.0, name=self.result_name)
+
+    class Conv2dTanh:
+        def __init__(self, weight, name):
+            self.weight = weight
+            self.result_name = name
+
+        def eval(self, input):
+            conv2d = tf.nn.conv2d(
+                input,
+                self.weight,
+                [1, 1, 1, 1],
+                "SAME",
+                data_format="NHWC",
+                dilations=[1, 1, 1, 1],
+                name="conv2d",
+            )
+            return tf.math.tanh(conv2d, name=self.result_name)
+
+    class Conv2dWithBias:
+        def __init__(self, weight, bias, strides, padding, dilations, name):
+            self.weight = weight
+            self.bias = bias
+            self.strides = strides
+            self.padding = padding
+            self.dilations = dilations
+            self.result_name = name
+
+        def eval(self, input):
+            conv2d_op = tf.nn.conv2d(
+                input,
+                self.weight,
+                self.strides,
+                self.padding,
+                data_format="NHWC",
+                dilations=self.dilations,
+                name="conv2d",
+            )
+            bias_add_op = tf.nn.bias_add(
+                conv2d_op, self.bias, data_format="NHWC", name=self.result_name
+            )
+            return bias_add_op
+
+    class DepthwiseConv2d:
+        def __init__(self, weight, strides, padding, dilations, name):
+            self.weight = weight
+            self.strides = strides
+            self.padding = padding
+            self.dilations = dilations
+            self.result_name = name
+
+        def eval(self, input):
+            dws_conv2d = tf.nn.depthwise_conv2d(
+                input,
+                self.weight,
+                self.strides,
+                self.padding,
+                data_format="NHWC",
+                dilations=self.dilations,
+                name="dws_conv2d",
+            )
+            return tf.identity(dws_conv2d, name=self.result_name)
+
+    class DepthwiseConv2dWithBias:
+        def __init__(self, weight, bias, strides, padding, dilations, name):
+            self.weight = weight
+            self.bias = bias
+            self.strides = strides
+            self.padding = padding
+            self.dilations = dilations
+            self.result_name = name
+
+        def eval(self, input):
+            dws_conv2d = tf.nn.depthwise_conv2d(
+                input,
+                self.weight,
+                self.strides,
+                self.padding,
+                data_format="NHWC",
+                dilations=self.dilations,
+                name="dws_conv2d",
+            )
+            bias_add_op = tf.nn.bias_add(
+                dws_conv2d, self.bias, data_format="NHWC", name=self.result_name
+            )
+            return bias_add_op
+
+    class TransposeConv2d:
+        def __init__(self, weight, output_shape, strides, padding, name):
+            self.weight = weight
+            self.output_shape = output_shape
+            self.strides = strides
+            self.padding = padding
+            self.result_name = name
+
+        def eval(self, input):
+            return tf.nn.conv2d_transpose(
+                input,
+                self.weight,
+                self.output_shape,
+                self.strides,
+                self.padding,
+                data_format="NHWC",
+                name=self.result_name,
+            )
+
+    class Argmax:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.argmax(a, self.axis, output_type=tf.int32, name=self.result_name)
+
+    class AvgPool2d:
+        def __init__(self, strides, kernel_size, padding, name):
+            self.strides = strides
+            self.kernel_size = kernel_size
+            self.padding = padding
+            self.result_name = name
+
+        def eval(self, input):
+            return tf.nn.avg_pool2d(
+                input,
+                strides=self.strides,
+                ksize=self.kernel_size,
+                padding=self.padding,
+                data_format="NHWC",
+                name=self.result_name,
+            )
+
+    class MaxPool2d:
+        def __init__(self, strides, kernel_size, padding, name):
+            self.strides = strides
+            self.kernel_size = kernel_size
+            self.padding = padding
+            self.result_name = name
+
+        def eval(self, input):
+            return tf.nn.max_pool2d(
+                input,
+                strides=self.strides,
+                ksize=self.kernel_size,
+                padding=self.padding,
+                data_format="NHWC",
+                name=self.result_name,
+            )
+
+    class Reshape:
+        def __init__(self, shape, name):
+            self.shape = shape
+            self.result_name = name
+
+        def eval(self, a):
+            reshape_op = tf.reshape(a, self.shape)
+            return tf.identity(reshape_op, name=self.result_name)
+
+    class Transpose:
+        def __init__(self, perm, name):
+            self.perm = perm
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.transpose(a, self.perm, name=self.result_name)
+
+    class Slice:
+        def __init__(self, begin, size, name):
+            self.begin = begin
+            self.size = size
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.slice(a, begin=self.begin, size=self.size, name=self.result_name)
+
+    class StridedSlice:
+        def __init__(
+            self,
+            begin,
+            end,
+            strides,
+            begin_mask,
+            end_mask,
+            ellipsis_mask,
+            new_axis_mask,
+            shrink_axis_mask,
+            name,
+        ):
+            self.begin = begin
+            self.end = end
+            self.strides = strides
+            self.begin_mask = begin_mask
+            self.end_mask = end_mask
+            self.ellipsis_mask = ellipsis_mask
+            self.new_axis_mask = new_axis_mask
+            self.shrink_axis_mask = shrink_axis_mask
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.strided_slice(
+                a,
+                begin=self.begin,
+                end=self.end,
+                strides=self.strides,
+                begin_mask=self.begin_mask,
+                end_mask=self.end_mask,
+                ellipsis_mask=self.ellipsis_mask,
+                new_axis_mask=self.new_axis_mask,
+                shrink_axis_mask=self.shrink_axis_mask,
+                name=self.result_name,
+            )
+
+    class Select:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, selector, a, b):
+            return tf.where(condition=selector, x=a, y=b, name=self.result_name)
+
+    class Addn:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b, c, d):
+            return tf.add_n([a, b, c, d], name=self.result_name)
+
+    class Concatv2:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a, b, c, d):
+            return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+
+    class Stack:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a, b, c, d):
+            return tf.stack([a, b, c, d], axis=self.axis, name=self.result_name)
+
+    class Unstack:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a):
+            unstack_op = tf.unstack(a, axis=self.axis, name="unstack_op")
+            result_count = a.shape[self.axis]
+
+            if result_count == 1:
+                return tf.identity(unstack_op[0], name=self.result_name)
+
+            sums = []
+            for i in range(result_count):
+                sums.append(
+                    tf.math.reduce_sum(unstack_op[i], name="reduce_{}".format(i))
+                )
+            return tf.stack(sums, 0, name=self.result_name)
+
+    class Pad:
+        def __init__(self, padding, name):
+            self.padding = padding
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.pad(
+                a,
+                self.padding,
+                mode="CONSTANT",
+                constant_values=0,
+                name=self.result_name,
+            )
+
+    class ExpandDims:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.expand_dims(a, self.axis, name=self.result_name)
+
+    class Shape:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.shape(a, name=self.result_name)
+
+    class Rank:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.rank(a, name=self.result_name)
+
+    class Fill:
+        def __init__(self, shape, value, name):
+            self.shape = shape
+            self.value = value
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.fill(self.shape, self.value, name=self.result_name)
+
+    class Elu:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.elu(a, name=self.result_name)
+
+    class Softmax:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.softmax(a, name=self.result_name)
+
+    class LogSoftmax:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.log_softmax(a, name=self.result_name)
+
+    class MatMul:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            return tf.linalg.matmul(a, b, name=self.result_name)
+
+    class AddScalar:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.add(a, 1, name=self.result_name)
+
+    class Add1d:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a, b):
+            if len(b.shape) > 1:
+                b_1d = tf.reduce_sum(b, axis=list(range(0, len(b.shape) - 1, 1)))
+            else:
+                b_1d = b
+            return tf.add(a, b_1d, name=self.result_name)
+
+    class Split:
+        def __init__(self, num_splits, axis, name):
+            self.num_splits = num_splits
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a):
+            # The split op generates a list of outputs.  Since we have difficulty
+            # serializing a list or array of Numpy arrays, we will reduce each of
+            # the results
+
+            if not isinstance(self.num_splits, list):
+                split_op = tf.split(
+                    a, num_or_size_splits=self.num_splits, axis=self.axis, name="split"
+                )
+                result_count = self.num_splits
+            else:
+                num_split = np.asarray(self.num_splits, dtype=np.int32)
+                split_vec_op = tf.compat.v1.constant(
+                    num_split,
+                    shape=num_split.shape,
+                    dtype=tf.int32,
+                    name="const_split_vec",
+                )
+                split_op = tf.split(
+                    a, num_or_size_splits=split_vec_op, axis=self.axis, name="split"
+                )
+                result_count = num_split.shape[0]
+
+            sums = []
+            for i in range(result_count):
+                sums.append(tf.math.reduce_sum(split_op[i], name="reduce_{}".format(i)))
+            return tf.stack(sums, 0, name=self.result_name)
+
+    class Tile:
+        def __init__(self, multiples, name):
+            self.multiples = multiples
+            self.result_name = name
+
+        def eval(self, a):
+            t = tf.tile(a, self.multiples, name="tile")
+            return tf.identity(t, name=self.result_name)
+
+    class Reverse:
+        def __init__(self, axis, name):
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.reverse(a, [self.axis], name=self.result_name)
+
+    class Gather:
+        def __init__(self, indices, batch_dims, axis, name):
+            self.indices = indices
+            self.batch_dims = batch_dims
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.gather(
+                a,
+                self.indices,
+                batch_dims=self.batch_dims,
+                axis=self.axis,
+                name=self.result_name,
+            )
+
+    class GatherNd:
+        def __init__(self, indices, name):
+            self.indices = indices
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.gather_nd(a, self.indices, name=self.result_name)
+
+    class ScatterNd:
+        def __init__(self, shape, indices_shape, N, rng, name):
+            self.shape = shape
+            self.indices_shape = indices_shape
+            self.N = N
+            self.rng = rng
+            self.result_name = name
+
+        def eval(self, a):
+
+            # This operator is special.  The indices and updates tensors really need
+            # to be created together, but in the current structure of this tool there
+            # is no way to do that before now.  The number of updates is determined by
+            # the indices, so we can really only create that after indices; but we
+            # don't know the type at that time.
+            #
+            # Shapes are guaranteed deterministic, but we'll use our rng
+            # copied from the arggen stage.  It's possible that index and
+            # update *values* will be non-deterministic.
+            #
+            # We take the tensor_tensor simply to get the dtype.
+
+            shape_const = tf.constant(self.shape, tf.int32)
+
+            updates_shape = list(self.indices_shape[:-1])
+            updates_shape.extend(self.shape[self.indices_shape[-1] :])
+
+            updates_const = tf.constant(TGen.getRand(updates_shape, a.dtype, self.rng))
+
+            indices = np.zeros(self.indices_shape, dtype=np.int32)
+
+            # We need to generate the random indices tensor based on the
+            # limits of 'shape' for each dimension.  Surely, there is a faster
+            # vectorized way to do this, but the tensors are fairly small so we
+            # will do this one element at a time.  Each element needs to be sized based
+            # on the size of the last dimension.
+            for idx in np.ndindex(indices.shape):
+                indices[idx] = self.rng.integers(0, self.shape[idx[-1]], size=1)[0]
+                # print('{} {}'.format(idx, indices[idx]))
+
+            indices_const = tf.constant(indices, dtype=tf.int32)
+
+            return tf.scatter_nd(
+                indices=indices_const,
+                updates=updates_const,
+                shape=shape_const,
+                name=self.result_name,
+            )
+
+    class SpaceToBatch:
+        def __init__(self, block_shape, padding, name):
+            self.block_shape = block_shape
+            self.padding = padding
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.space_to_batch(
+                a, self.block_shape, self.padding, name=self.result_name
+            )
+
+    class BatchToSpace:
+        def __init__(self, block_shape, cropping, name):
+            self.block_shape = block_shape
+            self.cropping = cropping
+            self.result_name = name
+
+        def eval(self, a):
+            # transpose to swap depth and batch first. this could avoid adding new shape
+            block_rank = len(self.block_shape)
+            perm = [len(a.shape) - 1]
+            for i in range(block_rank):
+                perm.append(i + 1)
+            perm.append(0)
+            transpose_op = tf.transpose(a, perm)
+            return tf.batch_to_space(
+                transpose_op, self.block_shape, self.cropping, name=self.result_name
+            )
+
+    class SpaceToDepth:
+        def __init__(self, block_shape, name):
+            self.block_shape = block_shape
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.space_to_depth(a, self.block_shape, name=self.result_name)
+
+    class DepthToSpace:
+        def __init__(self, block_shape, name):
+            self.block_shape = block_shape
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.nn.depth_to_space(a, self.block_shape, name=self.result_name)
+
+    class OneHot:
+        def __init__(self, depth, axis, name):
+            self.depth = depth
+            self.axis = axis
+            self.result_name = name
+
+        def eval(self, indices, on_value, off_value):
+            return tf.one_hot(
+                indices,
+                self.depth,
+                on_value,
+                off_value,
+                self.axis,
+                on_value.dtype,
+                self.result_name,
+            )
+
+    class Fakequant:
+        def __init__(self, num_bits, narrow_range, name):
+            self.num_bits = num_bits
+            self.narrow_range = narrow_range
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.quantization.fake_quant_with_min_max_args(
+                a,
+                min=-2.0,
+                max=2.0,
+                num_bits=self.num_bits,
+                narrow_range=self.narrow_range,
+                name=self.result_name,
+            )
+
+    class ResizeNearest:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            out_shape = []
+            out_shape.append(a.shape[1] * 2)
+            out_shape.append(a.shape[2] * 2)
+
+            # tf.image.resize() will overwrite the node name with result_name +
+            # '/BILINEAR' need to add extra identity to force output tensor name to
+            # result_name return tf.image.resize(a, out_shape,
+            # method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, name=result_name)
+            resize = tf.image.resize(
+                a,
+                out_shape,
+                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
+                name="resize",
+            )
+            return tf.identity(resize, name=self.result_name)
+
+    class ResizeBilinear:
+        def __init__(self, name):
+            self.result_name = name
+
+        def eval(self, a):
+            out_shape = []
+            out_shape.append(a.shape[1] * 2)
+            out_shape.append(a.shape[2] * 2)
+
+            # tf.image.resize() will overwrite the node name with result_name +
+            # '/BILINEAR' need to add extra identity to force output tensor name to
+            # result_name return tf.image.resize(a, out_shape,
+            # method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, name=result_name)
+            resize = tf.image.resize(
+                a, out_shape, method=tf.image.ResizeMethod.BILINEAR, name="resize"
+            )
+            return tf.identity(resize, name=self.result_name)
+
+    class LeftShift:
+        def __init__(self, shift, name):
+            self.shift = shift
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.bitwise.left_shift(a, self.shift, name=self.result_name)
+
+    class RightShift:
+        def __init__(self, shift, name):
+            self.shift = shift
+            self.result_name = name
+
+        def eval(self, a):
+            return tf.bitwise.right_shift(a, self.shift, name=self.result_name)
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
new file mode 100644
index 0000000..2d8e5d6
--- /dev/null
+++ b/verif/frameworks/test_gen_utils.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from enum import IntEnum
+from enum import unique
+
+import tensorflow as tf
+
+
+# Get a string name for a given shape
+def get_shape_str(shape, dtype):
+    shape_name = None
+    for dim in shape:
+        shape_name = (shape_name + "x" + str(dim)) if shape_name else str(dim)
+
+    if dtype == tf.float32:
+        shape_name = shape_name + "_f32"
+    elif dtype == tf.float16:
+        shape_name = shape_name + "_f16"
+    elif dtype == tf.int32:
+        shape_name = shape_name + "_i32"
+    elif dtype == tf.uint32:
+        shape_name = shape_name + "_u32"
+    elif dtype == tf.bool:
+        shape_name = shape_name + "_bool"
+    elif dtype == tf.quint8:
+        shape_name = shape_name + "_qu8"
+    elif dtype == tf.qint8:
+        shape_name = shape_name + "_qi8"
+    elif dtype == tf.qint16:
+        shape_name = shape_name + "_qi16"
+    elif dtype == tf.quint16:
+        shape_name = shape_name + "_qu16"
+    else:
+        raise Exception("Unsupported type: {}".format(dtype))
+
+    return shape_name
+
+
+@unique
+class QuantType(IntEnum):
+    UNKNOWN = 0
+    ALL_I8 = 1
+    ALL_U8 = 2
+    ALL_I16 = 3
+    # TODO: support QUINT16
+    CONV_U8_U8 = 4
+    CONV_I8_I8 = 5
+    CONV_I8_I4 = 6
+    CONV_I16_I8 = 7
+
+
+def get_tf_dtype(quantized_inference_dtype):
+    if quantized_inference_dtype == QuantType.ALL_I8:
+        return tf.qint8
+    elif quantized_inference_dtype == QuantType.ALL_U8:
+        return tf.quint8
+    elif quantized_inference_dtype == QuantType.ALL_I16:
+        return tf.qint16
+    elif quantized_inference_dtype == QuantType.CONV_U8_U8:
+        return tf.quint8
+    elif quantized_inference_dtype == QuantType.CONV_I8_I8:
+        return tf.qint8
+    elif quantized_inference_dtype == QuantType.CONV_I8_I4:
+        return tf.qint8
+    elif quantized_inference_dtype == QuantType.CONV_I16_I8:
+        return tf.qint16
+    else:
+        return None
+
+
+class TensorScale:
+    def __init__(self, _min, _max, _num_bits, _narrow_range):
+        self.min = _min
+        self.max = _max
+        self.num_bits = _num_bits
+        self.narrow_range = _narrow_range
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
new file mode 100755
index 0000000..337c8a4
--- /dev/null
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -0,0 +1,804 @@
+#!/usr/bin/env python3
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import glob
+import json
+import math
+import os
+import queue
+import re
+import sys
+import threading
+import traceback
+from datetime import datetime
+from enum import IntEnum
+from enum import unique
+
+import numpy as np
+from checker.tosa_result_checker import LogColors
+from checker.tosa_result_checker import print_color
+from checker.tosa_result_checker import set_print_in_color
+from runner.run_command import run_sh_command
+from xunit.xunit import xunit_results
+from xunit.xunit import xunit_test
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-t", "--test", dest="test", type=str, nargs="+", help="Test(s) to run"
+    )
+    parser.add_argument(
+        "-r",
+        "--recursive",
+        dest="recursive_tests",
+        action="store_true",
+        help="Recursively search for tests",
+    )
+    parser.add_argument(
+        "--tf-base-dir",
+        dest="tf_base_dir",
+        type=str,
+        required=True,
+        help="Tensorflow/MLIR base directory",
+    )
+    parser.add_argument(
+        "--tools-base-dir",
+        dest="tools_base_dir",
+        type=str,
+        required=True,
+        help="Reference model base directory",
+    )
+    parser.add_argument(
+        "-v", "--verbose", dest="verbose", action="count", help="Verbose run"
+    )
+    parser.add_argument(
+        "-dref",
+        "--debug-ref-model",
+        dest="debug_ref_model",
+        action="store_true",
+        help="Enable TOSA Reference model debugging",
+    )
+    parser.add_argument(
+        "--tolerance",
+        dest="tolerance",
+        default=1e-3,
+        type=float,
+        help="Comparison tolerance b value",
+    )
+    parser.add_argument(
+        "--no-compiler",
+        dest="no_compiler",
+        action="store_true",
+        help="Do not run TF MLIR/tfopt/TOSA compiler.  Just run TOSA Reference model",
+    )
+    parser.add_argument(
+        "--no-ref-model",
+        dest="no_ref",
+        action="store_true",
+        help="Do not run TOSA reference model, just run TF MLIR/tfopt/TOSA compiler.",
+    )
+    parser.add_argument(
+        "--valgrind",
+        dest="valgrind",
+        action="store_true",
+        help="Enable valgrind on TOSA Reference Model",
+    )
+    parser.add_argument(
+        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
+    )
+    parser.add_argument(
+        "--no-color",
+        "--no-colour",
+        dest="no_color",
+        action="store_true",
+        help="Disable color output",
+    )
+    parser.add_argument(
+        "-f",
+        "--framework",
+        dest="framework",
+        default=[],
+        action="append",
+        help="Frameworks to test (tf, tflite)",
+    )
+    parser.add_argument(
+        "--override-exclusions",
+        dest="override_exclusions",
+        default=False,
+        action="store_true",
+        help="Ignore the framework exclusions listed in the test JSON",
+    )
+    parser.add_argument(
+        "--xunit-file",
+        dest="xunit_file",
+        type=str,
+        default="result.xml",
+        help="XUnit result output file",
+    )
+    parser.add_argument(
+        "--xunit-classname-prefix",
+        dest="xunit_classname_prefix",
+        default="TFUnitTests",
+        help="Prefix for xunit classname",
+    )
+    parser.add_argument(
+        "--hex-bool-hack",
+        dest="hex_bool_hack",
+        default=1,
+        type=int,
+        help=(
+            "Hack around bug in MLIR hex parsing for boolean types"
+            " by disabling hex encoding"
+        ),
+    )
+    parser.add_argument(
+        "--regression-mode",
+        dest="regression_mode",
+        default=False,
+        action="store_true",
+        help="Options to make the script more friendly for jenkins regressions",
+    )
+    parser.add_argument(
+        "--quantize-tolerance",
+        dest="quantize_tolerance",
+        default=0,
+        type=int,
+        help=(
+            "Tolerance when comparing TOSA reference model result"
+            " to TensorFlow Lite reference"
+        ),
+    )
+    parser.add_argument(
+        "--test-dir",
+        dest="test_dir",
+        default="",
+        help="Path to prepend to paths in test.json",
+    )
+
+    parser.add_argument(
+        "-o", "--output", dest="output_file", help="Redirect script output to a file"
+    )
+
+    args = parser.parse_args()
+
+    # No easy way to both do array append and override a default value
+    if not args.framework:
+        args.framework = ["tf", "tflite"]
+
+    # Autodetect CPU count
+    if args.jobs <= 0:
+        args.jobs = os.cpu_count()
+
+    return args
+
+
+@unique
+class TestResult(IntEnum):
+    PASS = 0
+    COMPILER_ERROR = 1
+    REF_MODEL_ERROR = 2
+    REF_MODEL_UNPREDICTABLE = 3
+    REF_MODEL_RUNTIME_ERROR = 4
+    MISMATCH = 5
+    NOT_LOWERED = 6
+    INVALID_MLIR = 7
+    INTERNAL_ERROR = 8
+    SKIPPED = 9
+
+
+TestResultErrorStr = [
+    "",
+    "Compiler error",
+    "Reference model error",
+    "Reference model unpredictable",
+    "Reference model runtime error",
+    "Mismatch",
+    "Not lowered",
+    "Invalid MLIR",
+    "Internal error",
+    "",
+]
+
+
+def parse_compiler_output(compiler_stdout, compiler_stderr):
+    # Look for "has not been lowered yet, skipped" strings in stdout
+    expr = re.compile(".* has not been lowered yet, skipped.*")
+
+    for line in compiler_stdout.splitlines():
+        if expr.match(line):
+            return TestResult.NOT_LOWERED
+
+    return TestResult.PASS
+
+
+def parse_reference_model_output(ref_model_stdout, ref_model_stderr):
+    # Look for "has not been lowered yet, skipped" strings in stdout
+    unpredictable_expr = re.compile(r".*UNPREDICTABLE.*")
+    error_expr = re.compile(".* Graph result: ERROR.*")
+    unknown_expr = re.compile(".* Unknown graph status code.*")
+
+    for line in ref_model_stderr.splitlines():
+        if unpredictable_expr.match(line):
+            return TestResult.REF_MODEL_UNPREDICTABLE
+        elif error_expr.match(line):
+            return TestResult.REF_MODEL_ERROR
+        elif unknown_expr.match(line):
+            return TestResult.REF_MODEL_RUNTIME_ERROR
+
+    return TestResult.PASS
+
+
+# write a self-contained test descriptor in json format
+def write_reference_runner_json(
+    filename,
+    tosa_filename,
+    ifm_name,
+    ifm_file,
+    ofm_name,
+    ofm_file,
+    expected_failure=False,
+):
+    """Write a json test file so that it is fairly easy to pick up the test
+    and generate commands for third party tool"""
+    test_desc = dict()
+
+    test_desc["tosa_file"] = tosa_filename
+    test_desc["ifm_name"] = ifm_name
+    test_desc["ifm_file"] = ifm_file
+    test_desc["ofm_name"] = ofm_name
+    test_desc["ofm_file"] = ofm_file
+    test_desc["expected_failure"] = expected_failure
+
+    with open(filename, "w") as f:
+        json.dump(test_desc, f, indent="  ")
+
+
+def run_test(args, test, framework):
+
+    # parse test_name from test directory path
+    test_path = test.split("/")
+    test_name = None
+    for t in test_path[::-1]:
+        if len(t) != 0:
+            test_name = t
+            break
+    if not test_name:
+        raise Exception("Could not parse test_name from {}".format(test))
+
+    print_color(LogColors.GREEN, "## Running {} test {}".format(framework, test_name))
+
+    msg = ""
+
+    try:
+        with open(os.path.join(test, "test.json"), "r") as f:
+            test_desc = json.load(f)
+    except Exception:
+        raise Exception(
+            "Could not load or parse test from {}".format(
+                os.path.join(test, "test.json")
+            )
+        )
+
+    try:
+        if not args.override_exclusions:
+            for excl in test_desc["framework_exclusions"]:
+                if excl == framework:
+                    print_color(LogColors.GREEN, "Results SKIPPED")
+                    return (TestResult.SKIPPED, 0.0, "")
+    except KeyError:
+        pass
+
+    tf_tools_dir = os.path.abspath(
+        "{}/bazel-bin/tensorflow/compiler/mlir".format(args.tf_base_dir)
+    )
+
+    pre_opt_filename = os.path.join(test, "test_{}.preopt.mlir".format(framework))
+    post_opt_filename = os.path.join(test, "test_{}.postopt.mlir".format(framework))
+    if args.test_dir:
+        test_path_prepend = args.test_dir
+    else:
+        test_path_prepend = test
+
+    # 1. Framework to MLIR translator command
+    if framework == "tf":
+        if test_desc["tf_model_filename"].endswith(".mlir"):
+            pre_opt_filename = test_desc["tf_model_filename"]
+            translate_mlir_cmd = []
+        else:
+            translate_mlir_cmd = [
+                os.path.join(tf_tools_dir, "tf-mlir-translate"),
+                "--graphdef-to-mlir",
+                "--tf-enable-shape-inference-on-import",
+                "--tf-output-arrays={}".format(test_desc["tf_result_name"]),
+                os.path.join(test_path_prepend, test_desc["tf_model_filename"]),
+                "-o",
+                pre_opt_filename,
+            ]
+    elif framework == "tflite":
+        if test_desc["tflite_model_filename"].endswith(".mlir"):
+            pre_opt_filename = test_desc["tflite_model_filename"]
+            translate_mlir_cmd = []
+        else:
+            translate_mlir_cmd = [
+                os.path.join(tf_tools_dir, "lite", "flatbuffer_translate"),
+                "--tflite-flatbuffer-to-mlir",
+                os.path.join(test_path_prepend, test_desc["tflite_model_filename"]),
+                "--output-arrays={}".format(test_desc["tflite_result_name"]),
+                "-o",
+                pre_opt_filename,
+            ]
+    else:
+        raise Exception("Unknown framwork: {}".format(framework))
+
+    # Any additional inputs to the translator?
+    input_tensor_prefix = "TosaInput_"
+    flatbuffer_dir = "flatbuffer-{}".format(framework)
+    mlir_opts = []
+
+    # Temporary hack: MLIR's new hex encoding of large tensors does not work for
+    # boolean types
+    # for TF hash 8e8041d594a888eb67eafa5cc62627d7e9ca8082
+    if test.endswith("_bool") and args.hex_bool_hack:
+        mlir_opts.append("--mlir-print-elementsattrs-with-hex-if-larger=-1")
+
+    try:
+        # specify input tensors if test is generated from .pb
+        if framework == "tf":
+            # Convert the shape to a mlir-friendly string
+            shapes = []
+            for curr_shape in test_desc["ifm_shape"]:
+                shape_str = ""
+                for dim in curr_shape:
+                    shape_str = shape_str + str(dim) + ","
+                shapes.append(shape_str)
+
+            translate_mlir_cmd.extend(
+                ["--tf-input-arrays", ",".join(test_desc["ifm_name"])]
+            )
+            translate_mlir_cmd.extend(["--tf-input-shapes", ":".join(shapes)])
+
+        # Write the hard-coded placeholder input (reshaped as necesary) to
+        # the file that compiler specified.
+        reference_runner_ifm_name = []
+        for i in range(len(test_desc["ifm_file"])):
+
+            ifm_tensor_name = "{}{}".format(input_tensor_prefix, i)
+
+            assert test_desc["ifm_file"][i].endswith(".npy")
+            ifm_np = np.load(os.path.join(test, test_desc["ifm_file"][i]))
+            # Make sure input numpy and input shape from descriptor match
+            assert list(ifm_np.shape) == test_desc["ifm_shape"][i]
+
+            reference_runner_ifm_name.append(ifm_tensor_name)
+
+    except KeyError:
+        # No additional inputs.  Ignore.
+        pass
+
+    tf_opt_cmd = [
+        os.path.join(tf_tools_dir, "tf-opt"),
+        "--tf-executor-to-functional-conversion",
+        "--verify-each",
+        pre_opt_filename,
+        "-o",
+        post_opt_filename,
+    ]
+
+    translate_mlir_cmd.extend(mlir_opts)
+    tf_opt_cmd.extend(mlir_opts)
+
+    compiler_cmd = [os.path.join(tf_tools_dir, "tf-opt")]
+
+    if framework == "tf":
+        compiler_cmd.append("--tf-to-tosa-pipeline")
+    elif framework == "tflite":
+        compiler_cmd.append("--tfl-to-tosa-pipeline")
+        compiler_cmd.append("--tosa-strip-quant-types")
+
+    tosa_mlir_filename = os.path.join(test, "output_{}.tosa.mlir".format(framework))
+
+    flatbuffer_dir_fullpath = os.path.join(test, flatbuffer_dir)
+
+    os.makedirs(flatbuffer_dir_fullpath, exist_ok=True)
+
+    compiler_cmd.extend(
+        [
+            "--verify-each",
+            post_opt_filename,
+            "-o",
+            tosa_mlir_filename,
+            "--tosa-serialize",
+            "--tosa-flatbuffer-filename={}".format(
+                os.path.join(flatbuffer_dir_fullpath, "{}.tosa".format(test_name))
+            ),
+        ]
+    )
+
+    if not args.no_compiler:
+        try:
+            if translate_mlir_cmd:
+                run_sh_command(translate_mlir_cmd, args.verbose, True)
+            if tf_opt_cmd:
+                run_sh_command(tf_opt_cmd, args.verbose, True)
+        except Exception as e:
+            print_color(
+                LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
+            )
+            return (TestResult.INVALID_MLIR, 0.0, e)
+
+        try:
+
+            compiler_stdout, compiler_stderr = run_sh_command(
+                compiler_cmd, args.verbose, True
+            )
+            compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
+            if compiler_rc == TestResult.NOT_LOWERED:
+                print_color(
+                    LogColors.RED,
+                    "Results NOT_LOWERED {}, framework {}".format(test_name, framework),
+                )
+                return (TestResult.NOT_LOWERED, 0.0, "")
+
+            pass
+
+        except Exception as e:
+            if "same scale constraint" in str(e):
+                print_color(
+                    LogColors.RED, "Results INVALID_MLIR {}: {}".format(test_name, e)
+                )
+                return (TestResult.INVALID_MLIR, 0.0, e)
+            else:
+                print_color(
+                    LogColors.RED, "Results COMPILER_ERROR {}: {}".format(test_name, e)
+                )
+                return (TestResult.COMPILER_ERROR, 0.0, e)
+
+    if framework == "tf":
+        try:
+            tf_result = np.load(os.path.join(test, test_desc["tf_result_npy_filename"]))
+        except KeyError:
+            assert 0, "fail to load tf result numpy"
+    elif framework == "tflite":
+        try:
+            tf_result = np.load(
+                os.path.join(test, test_desc["tflite_result_npy_filename"])
+            )
+        except KeyError:
+            assert 0, "fail to load tflite result numpy"
+
+    # Generate test descriptor per flatbuffer generation
+    # Input .npy will be shared across different frameworks
+    # Output .npy will be generated in its corresponding flatbuffer
+    reference_runner_ifm_file = [
+        os.path.join("..", ifm_file) for ifm_file in test_desc["ifm_file"]
+    ]
+
+    # Check if there's any operator in output graph.
+    empty_graph = True
+    with open(tosa_mlir_filename, "r") as f:
+        for line in f:
+            if re.search('"tosa.*"', line):
+                empty_graph = False
+
+                break
+
+    # Fast-forward input tensor to output tensor if TOSA graph is empty.
+    if empty_graph:
+        reference_runner_ofm_name = reference_runner_ifm_name
+    else:
+        reference_runner_ofm_name = ["TosaOutput_0"]
+
+    write_reference_runner_json(
+        filename=os.path.join(test, flatbuffer_dir, "desc.json"),
+        tosa_filename="{}.tosa".format(test_name),
+        ifm_name=reference_runner_ifm_name,
+        ifm_file=reference_runner_ifm_file,
+        ofm_name=reference_runner_ofm_name,
+        ofm_file=["ref_model_output_0.npy"],
+    )
+
+    ref_model_cmd = [
+        os.path.join(
+            args.tools_base_dir, "build", "reference_model", "tosa_reference_model"
+        ),
+        "-Ctest_desc={}".format(os.path.join(test, flatbuffer_dir, "desc.json")),
+    ]
+
+    if args.debug_ref_model:
+        ref_model_cmd.extend(["-DALL", "-lhigh"])
+
+    if args.valgrind:
+        ref_model_cmd = [
+            "valgrind",
+            "--show-leak-kinds=all",
+            "--log-fd=1",
+            "-q",
+        ] + ref_model_cmd
+
+    # Clean out any ref_model result first
+    try:
+        os.remove(os.path.join(test, flatbuffer_dir, "ref_model_*.npy"))
+    except FileNotFoundError:
+        pass
+
+    if not args.no_ref:
+        try:
+            ref_model_stdout, ref_model_stderr = run_sh_command(
+                ref_model_cmd, args.verbose, True
+            )
+            ref_model_rc = parse_reference_model_output(
+                ref_model_stdout, ref_model_stderr
+            )
+            if ref_model_rc != TestResult.PASS:
+                return (ref_model_rc, 0.0, "")
+        except Exception as e:
+            ref_model_rc = parse_reference_model_output("", str(e))
+            if ref_model_rc != TestResult.PASS:
+                print_color(
+                    LogColors.RED,
+                    "Results {} {}: {}".format(
+                        TestResultErrorStr[ref_model_rc], test_name, e
+                    ),
+                )
+                return (ref_model_rc, 0.0, "")
+            print_color(
+                LogColors.RED,
+                "Results REF_MODEL_RUNTIME_ERROR {}: {}".format(test_name, e),
+            )
+            return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)
+
+    if tf_result.dtype == np.float16:
+        tf_result = tf_result.astype(np.float32)
+    elif (
+        tf_result.dtype == np.uint8
+        or tf_result.dtype == np.int8
+        or tf_result.dtype == np.int16
+        or tf_result.dtype == np.int64
+    ):
+        tf_result = tf_result.astype(np.int32)
+
+    # For now, search for the output from ref_model
+    ref_model_result_files = glob.glob(
+        os.path.join(test, flatbuffer_dir, "ref_model_*.npy")
+    )
+    ref_model_result = np.load(ref_model_result_files[0])
+
+    assert (
+        tf_result.dtype == ref_model_result.dtype
+    ), "Numpy type mismatch {} != {} when comparing result".format(
+        tf_result.dtype, ref_model_result.dtype
+    )
+
+    # Size comparison
+    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
+    # >= 0, allow that special case
+    tf_result = np.squeeze(tf_result)
+    ref_model_result = np.squeeze(ref_model_result)
+
+    if np.shape(tf_result) != np.shape(ref_model_result):
+        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+        msg = "Shapes mismatch: Reference {} vs {}".format(
+            np.shape(tf_result), np.shape(ref_model_result)
+        )
+        print(msg)
+        return (TestResult.MISMATCH, 0.0, msg)
+
+    # for quantized test, allow +-(args.quantize_tolerance) error
+    if ref_model_result.dtype == np.int32:
+        assert tf_result.dtype == np.int32
+
+        if np.all(np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance):
+            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+        else:
+            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+
+            tolerance = args.quantize_tolerance + 1
+            while not np.all(
+                np.absolute(ref_model_result - tf_result) <= args.quantize_tolerance
+            ):
+                tolerance = tolerance + 1
+                if tolerance >= 10:
+                    break
+
+            msg = "Result is within {} {}".format(tolerance, test)
+            print(msg)
+
+            np.set_printoptions(threshold=128)
+            print("tf_result: {}\n".format(tf_result.shape))
+            print(tf_result)
+            print("ref_model_result: {}\n".format(ref_model_result.shape))
+            print(ref_model_result)
+            # print(tf_result - ref_model_result)
+            return (TestResult.MISMATCH, tolerance, msg)
+    else:
+        if np.allclose(
+            ref_model_result, tf_result, atol=args.tolerance, equal_nan=True
+        ):
+            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+        else:
+            print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+
+            # Many of these tests would match with a reasonable looser tolerence.
+            # Determine what would have worked.
+            tolerance = args.tolerance * 10.0
+            while not np.allclose(
+                ref_model_result, tf_result, atol=tolerance, equal_nan=True
+            ):
+                tolerance = tolerance * 10.0
+                if tolerance > 1.0e10:
+                    tolerance = math.inf
+                    break
+
+            msg = "Result is within {:.0e} {}".format(tolerance, test_name)
+            print(msg)
+
+            np.set_printoptions(precision=4, threshold=128)
+            print("tf_result: {}\n".format(tf_result.shape))
+            print(tf_result)
+            print("ref_model_result: {}\n".format(ref_model_result.shape))
+            print(ref_model_result)
+            # print(tf_result - ref_model_result)
+            return (TestResult.MISMATCH, tolerance, msg)
+
+    return (TestResult.PASS, args.tolerance, msg)
+
+
+def worker_thread(task_queue, args, result_queue):
+    while True:
+        try:
+            (test, framework) = task_queue.get(block=False)
+        except queue.Empty:
+            break
+
+        if test is None:
+            break
+
+        msg = ""
+        start_time = datetime.now()
+        try:
+            (rc, tolerance, msg) = run_test(args, test, framework)
+        except Exception as e:
+            print("Internal regression error: {}".format(e))
+            print(
+                "".join(
+                    traceback.format_exception(
+                        etype=type(e), value=e, tb=e.__traceback__
+                    )
+                )
+            )
+            rc = TestResult.INTERNAL_ERROR
+            tolerance = 0.0
+
+        end_time = datetime.now()
+
+        result_queue.put((test, framework, rc, tolerance, msg, end_time - start_time))
+        task_queue.task_done()
+
+    return True
+
+
+def getTestsInDir(directory):
+    # Recursively find any tests in this directory
+    if os.path.isfile(os.path.join(directory, "test.json")):
+        return [directory]
+    elif os.path.isdir(directory):
+        test_list = []
+        for d in glob.glob(os.path.join(directory, "*")):
+            test_list.extend(getTestsInDir(d))
+        return test_list
+    else:
+        return []
+
+
+def main():
+    args = parse_args()
+
+    set_print_in_color(not args.no_color)
+
+    if args.output_file:
+        set_print_in_color(False)
+        sys.stdout = open(args.output_file, "w")
+
+    # Disable TF info messages
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    task_queue = queue.Queue()
+    result_queue = queue.Queue()
+
+    threads = []
+
+    # Result counters for each of the TestResult return codes
+    results = [0] * len(TestResult)
+
+    for tdir in args.test:
+
+        if args.recursive_tests:
+            tdirList = getTestsInDir(tdir)
+        else:
+            tdirList = [tdir]
+
+        for t in tdirList:
+            for f in args.framework:
+                task_queue.put((t, f))
+
+    for i in range(args.jobs):
+        t = threading.Thread(
+            target=worker_thread, args=(task_queue, args, result_queue)
+        )
+        t.setDaemon(True)
+        t.start()
+        threads.append(t)
+
+    # Run until queue is empty
+    task_queue.join()
+
+    print_color(LogColors.BOLD_WHITE, "Result summary")
+
+    result_list = []
+    while True:
+        try:
+            test, framework, rc, tol, msg, time_delta = result_queue.get(block=False)
+        except queue.Empty:
+            break
+
+        result_list.append((test, framework, rc, tol, msg, time_delta))
+        results[rc] = results[rc] + 1
+
+    xunit_result = xunit_results()
+    xunit_suite = xunit_result.create_suite(args.xunit_classname_prefix)
+
+    # Sort by test name
+    for test, framework, rc, tol, err_msg, time_delta in sorted(
+        result_list, key=lambda tup: tup[0]
+    ):
+
+        test_name = os.path.basename(test)
+        class_name = f"{args.xunit_classname_prefix}.{framework}"
+
+        xt = xunit_test(test_name, class_name)
+
+        msg = TestResultErrorStr[rc]
+
+        xt.time = str(
+            float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6)
+        )
+
+        if len(msg) > 0:
+            print("{} on {} {}".format(msg, framework, test))
+
+        # Add any more verbose messaging for the xml log
+        if err_msg:
+            msg = "{} {}".format(msg, err_msg)
+
+        if rc == TestResult.PASS:
+            pass
+        elif rc == TestResult.SKIPPED:
+            xt.skipped()
+        else:
+            xt.failed(msg)
+
+        xunit_suite.tests.append(xt)
+
+        result_queue.task_done()
+
+    xunit_result.write_results(args.xunit_file)
+
+    print("Totals: ", end="")
+    for result in TestResult:
+        print("{} {}, ".format(results[result], result.name.lower()), end="")
+    print()
+
+    if not args.regression_mode and (
+        results[TestResult.COMPILER_ERROR] > 0
+        or results[TestResult.REF_MODEL_ERROR] > 0
+        or results[TestResult.MISMATCH] > 0
+    ):
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit(main())
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
new file mode 100755
index 0000000..222376e
--- /dev/null
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -0,0 +1,1426 @@
+#!/usr/bin/env python3
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+import re
+import traceback
+
+import numpy as np
+
+#  Level | Level for Humans | Level Description
+# -------|------------------|------------------------------------
+#  0     | DEBUG            | [Default] Print all messages
+#  1     | INFO             | Filter out INFO messages
+#  2     | WARNING          | Filter out INFO & WARNING messages
+#  3     | ERROR            | Filter out all messages
+# Filter tensorflow debug message except errors
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
+# Flake8 E402 - ignore imports not at top of file to allow os.environ setting
+import tensorflow as tf  # noqa: E402
+from frameworks.write_test_json import write_test_json  # noqa: E402
+from frameworks.arg_gen import ArgGen  # noqa: E402
+from frameworks.tensor_gen import TGen  # noqa: E402
+from frameworks.test_builder import TBuilder  # noqa: E402
+from frameworks.test_gen_utils import (
+    QuantType,
+    get_tf_dtype,
+    get_shape_str,
+)  # noqa: E402
+from tensorflow.lite.python.interpreter import OpResolverType  # noqa: E402
+
+# All of the supported frameworks
+ALL_FRAMEWORKS = ["tf", "tflite"]
+
+# Lists of different data types
+TYPE_F = [tf.float32]
+TYPE_I = [tf.int32]
+TYPE_FI = [tf.float32, tf.int32]
+TYPE_B = [tf.bool]
+TYPE_FIB = [tf.float32, tf.int32, tf.bool]
+TYPE_H = [tf.float16]
+TYPE_FH = [tf.float32, tf.float16]
+TYPE_FHI = [tf.float32, tf.float16, tf.int32]
+TYPE_FHIB = [tf.float32, tf.float16, tf.int32, tf.bool]
+
+# The list of operator tests
+# Each dictionary entry for an op is a dictionary with the following required members:
+#   'operands': tuple (number_of_placeholder_tensors, number_of_constant_tensors)
+#   'build_fcn: tuple (Test builder function, Tensor generator function,
+#                      Argument generator function)
+#   'types': list of Tensorflow types that should be tested for this op
+#               OR
+#            a dictionary of {'framework_name': [type_list] } for cases where only
+#            a subset of the types should be tested in each framework.  This can also
+#            be used to restrict an operator to a particular framework.
+#
+# And optional members:
+#   'template': boolean (indicates that this is a templated op which gets further
+#               processing in createDynamicOpLists)
+#   'bias':     boolean indicating that there is a bias component to be generated
+#   'qtypes':   List of QuantType quantized types to generate for this op
+
+TF_OP_LIST = {
+    "add": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Add, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_FI,
+            "tflite": list(
+                TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "sub": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Sub, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_FI,
+            "tflite": list(TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8]),
+            # QuantType.ALL_I16 fail in TFLite conversion
+        },
+    },
+    "mul": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Mul, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_FI,
+            "tflite": list(
+                TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "exp": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Exp, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "rcp": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Rcp, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "relu": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Relu, TGen.tgBasic, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "relu6": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Relu6, TGen.tgBasic, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "leaky_relu": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.LeakyRelu, TGen.tgBasic, ArgGen.agFloat),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "concat": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Concat, TGen.tgBasic, ArgGen.agAxes),
+        "types": TYPE_FI,
+    },
+    "bitwise_and": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.BitwiseAnd, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {"tf": TYPE_I},  # Not supported in TF Lite
+    },
+    "bitwise_or": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.BitwiseOr, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {"tf": TYPE_I},  # Not supported in TF Lite
+    },
+    "bitwise_not": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.BitwiseNot, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {"tf": TYPE_I},  # Not supported in TF Lite
+    },
+    "bitwise_xor": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.BitwiseXor, TGen.tgBFuzz, ArgGen.agNone),
+        "types": {"tf": TYPE_I},  # Not supported in TF Lite
+    },
+    "logical_and": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.LogicalAnd, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_B,
+    },
+    "logical_or": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.LogicalOr, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_B,
+    },
+    "logical_not": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.LogicalNot, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_B,
+    },
+    "reduce_any": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceAny, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": TYPE_B,
+    },
+    "reduce_all": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceAll, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": {"tf": TYPE_B},
+    },
+    "reduce_min": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceMin, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": {
+            "tf": TYPE_FI,
+            "tflite": list(TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8]),
+        },
+    },
+    "reduce_max": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceMax, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": {
+            "tf": TYPE_FI,
+            "tflite": list(TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8]),
+        },
+    },
+    "reduce_sum": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceSum, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": {
+            "tf": TYPE_F,
+            # v2 converter doesn't recognize quantized reduce_sum
+            # "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]),
+            "tflite": TYPE_F,
+        },
+    },
+    "reduce_mean": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceMean, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "reduce_product": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ReduceProduct, TGen.tgBasic, ArgGen.agAxesListKeepdims),
+        "types": TYPE_F,
+    },
+    "min": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Min, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "max": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Max, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "pow": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Pow, TGen.tgBFuzz, ArgGen.agNone),
+        # Technically, integer is supported, but only for positive exponents.
+        # Needs a random argument generator.
+        "types": TYPE_F,
+    },
+    "abs": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Abs, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "ceil": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Ceil, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "floor": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Floor, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "log": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Log, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "negate": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Negate, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "rsqrt": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "sigmoid": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Sigmoid, TGen.tgBasic, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "tanh": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Tanh, TGen.tgBasic, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "square": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Square, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "squared_difference": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.SquaredDifference, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "equal": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Equal, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "greater_equal": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.GreaterEqual, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "greater": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Greater, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "less": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Less, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "less_equal": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.LessEqual, TGen.tgBFuzz, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "conv2d_TEMPLATE": {
+        "operands": (1, 1),
+        "build_fcn": (TBuilder.Conv2d, TGen.tgConv2d, ArgGen.agConv2d),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    "conv2d_relu_TEMPLATE": {
+        "operands": (1, 2),
+        "build_fcn": (TBuilder.Conv2dRelu, TGen.tgConv2d, ArgGen.agNone),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    "conv2d_relu6_TEMPLATE": {
+        "operands": (1, 2),
+        "build_fcn": (TBuilder.Conv2dRelu6, TGen.tgConv2d, ArgGen.agNone),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    "conv2d_relu_n1_to_1_TEMPLATE": {
+        "operands": (1, 2),
+        "build_fcn": (TBuilder.Conv2dReluN1To1, TGen.tgConv2d, ArgGen.agNone),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    # This test is converted as:
+    # tfl.conv2d(){fused_activation_function="NONE"} + tfl.tanh()
+    # TODO: anyway to generate tfl.conv2d(){fused_activation_function="TANH"}?
+    "conv2d_tanh_TEMPLATE": {
+        "operands": (1, 2),
+        "build_fcn": (TBuilder.Conv2dTanh, TGen.tgConv2d, ArgGen.agNone),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    "conv2d_bias_TEMPLATE": {
+        "operands": (1, 2),
+        "build_fcn": (TBuilder.Conv2dWithBias, TGen.tgConv2d, ArgGen.agConv2d),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "bias": True,
+        "template": True,
+    },
+    "depthwise_conv2d_TEMPLATE": {
+        "operands": (1, 1),
+        "build_fcn": (
+            TBuilder.DepthwiseConv2d,
+            TGen.tgDepthwiseConv2d,
+            ArgGen.agDepthwiseConv2d,
+        ),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    "depthwise_conv2d_bias_TEMPLATE": {
+        "operands": (1, 2),
+        "build_fcn": (
+            TBuilder.DepthwiseConv2dWithBias,
+            TGen.tgDepthwiseConv2d,
+            ArgGen.agDepthwiseConv2d,
+        ),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "bias": True,
+        "template": True,
+    },
+    "transpose_conv2d_TEMPLATE": {
+        "operands": (1, 1),
+        "build_fcn": (
+            TBuilder.TransposeConv2d,
+            TGen.tgTransposeConv2d,
+            ArgGen.agTransposeConv2d,
+        ),
+        "types": {
+            "tf": [tf.float32],
+            "tflite": [
+                tf.float32,
+                QuantType.CONV_U8_U8,
+                QuantType.CONV_I8_I8,
+                QuantType.CONV_I16_I8,
+            ],
+        },
+        "template": True,
+    },
+    "argmax": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Argmax, TGen.tgBasic, ArgGen.agAxes),
+        "types": {"tf": TYPE_F},
+    },
+    "avg_pool2d": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.AvgPool2d, TGen.tgPooling, ArgGen.agPooling),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "max_pool2d": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.MaxPool2d, TGen.tgPooling, ArgGen.agPooling),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]),
+            # ALL_I16 not supported yet
+            # In tensorflow/compiler/mlir/lite/ir/tfl_ops.td,
+            # QI16 is missing from MaxPoolOperandAndResultConstraints
+            # If adding QI16 back this test can run through.
+        },
+    },
+    "reshape": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Reshape, TGen.tgBasic, ArgGen.agReshape),
+        "types": TYPE_FI,
+    },
+    "transpose": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Transpose, TGen.tgBasic, ArgGen.agTranspose),
+        "types": TYPE_FI,
+    },
+    "slice": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Slice, TGen.tgBasic, ArgGen.agSlice),
+        "types": TYPE_FI,
+    },
+    "strided_slice": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.StridedSlice, TGen.tgBasic, ArgGen.agStridedSlice),
+        "types": TYPE_FI,
+    },
+    "select": {
+        "operands": (3, 0),
+        "build_fcn": (TBuilder.Select, TGen.tgSelect, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "addn": {
+        "operands": (4, 0),
+        "build_fcn": (TBuilder.Addn, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "concatv2": {
+        "operands": (4, 0),
+        "build_fcn": (TBuilder.Concatv2, TGen.tgBasic, ArgGen.agAxes),
+        "types": TYPE_FI,
+    },
+    "stack": {
+        "operands": (4, 0),
+        "build_fcn": (TBuilder.Stack, TGen.tgBasic, ArgGen.agStack),
+        "types": TYPE_FI,
+    },
+    "unstack": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Unstack, TGen.tgPooling, ArgGen.agAxes),
+        "types": TYPE_F,
+    },
+    "pad": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad),
+        "types": TYPE_F,
+    },
+    "expand_dims": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ExpandDims, TGen.tgBasic, ArgGen.agStack),
+        "types": TYPE_FI,
+    },
+    "shape": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Shape, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "rank": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Rank, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_FI,
+    },
+    "fill": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Fill, TGen.tgBasic, ArgGen.agFill),
+        "types": TYPE_FI,
+    },
+    "elu": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Elu, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "softmax": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Softmax, TGen.tgBasic, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "log_softmax": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.LogSoftmax, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "matmul": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.MatMul, TGen.tgMatmul, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F
+                + [QuantType.ALL_U8, QuantType.ALL_I8]
+                # 16 bits matmul fail to convert
+            ),
+        },
+    },
+    "add_scalar": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.AddScalar, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "add_1d": {
+        "operands": (2, 0),
+        "build_fcn": (TBuilder.Add1d, TGen.tgBasic, ArgGen.agNone),
+        "types": TYPE_F,
+    },
+    "split": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Split, TGen.tgBasic, ArgGen.agSplit),
+        "types": TYPE_FI,
+    },
+    "tile": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Tile, TGen.tgBasic, ArgGen.agTile),
+        "types": TYPE_FI,
+    },
+    "reverse": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Reverse, TGen.tgBasic, ArgGen.agAxes),
+        "types": {"tf": TYPE_FI},
+    },
+    "gather": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.Gather, TGen.tgBasic, ArgGen.agGather),
+        "types": TYPE_FI,
+    },
+    "gather_nd": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.GatherNd, TGen.tgBasic, ArgGen.agGatherND),
+        "types": TYPE_FI,
+    },
+    "scatter_nd": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ScatterNd, TGen.tgBasic, ArgGen.agScatterND),
+        "types": TYPE_FI,
+    },
+    "space_to_batch": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.SpaceToBatch, TGen.tgBasic, ArgGen.agSpaceToBatch),
+        "types": TYPE_F,
+    },
+    "batch_to_space": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.BatchToSpace, TGen.tgBasic, ArgGen.agBatchToSpace),
+        "types": TYPE_F,
+    },
+    "space_to_depth": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.SpaceToDepth, TGen.tgBasic, ArgGen.agSpaceToDepth),
+        "types": TYPE_F,
+    },
+    "depth_to_space": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.DepthToSpace, TGen.tgBasic, ArgGen.agDepthToSpace),
+        "types": TYPE_F,
+    },
+    "one_hot": {
+        "operands": (3, 1),
+        "build_fcn": (TBuilder.OneHot, TGen.tgOneHot, ArgGen.agOneHot),
+        "types": TYPE_FI,
+    },
+    "fakequant": {
+        "operands": (1, 0),
+        "build_fcn": (
+            TBuilder.Fakequant,
+            TGen.tgBasic,
+            ArgGen.agFakequant,
+        ),
+        "types": {"tf": TYPE_F},
+    },
+    "resize_nearest": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ResizeNearest, TGen.tgPooling, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "resize_bilinear": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.ResizeBilinear, TGen.tgPooling, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(
+                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
+            ),
+        },
+    },
+    "left_shift": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.LeftShift, TGen.tgBasic, ArgGen.agShift),
+        "types": {"tf": [tf.int32]},
+    },
+    "right_shift": {
+        "operands": (1, 0),
+        "build_fcn": (TBuilder.RightShift, TGen.tgBasic, ArgGen.agShift),
+        "types": {
+            "tf": [
+                tf.int32,
+            ]
+        },
+    },
+}
+
+# Shapes to be tested; default can be overwritten
+shape_list = [
+    (1,),
+    (64,),
+    (14, 19),
+    (13, 21, 3),
+    (1, 4, 4, 4),
+    (1, 8, 4, 17),
+    (1, 4, 8, 19),
+    (1, 32, 32, 8),
+    (1, 7, 7, 9),
+]
+
+
+def gen_rand_shapes(args):
+    """Overwrite the global shape list with a new list of random shapes"""
+    global shape_list
+
+    rng = np.random.default_rng(args.random_seed)
+
+    # Don't let things get too big... cap the maximum volume, but let
+    # an individual dimension be 1..47
+    max_total_volume = 32 * 32 * 4
+
+    shape_list = []
+    # Only iterate over ranks 2, 3, and 4
+    for rank in range(2, 5):
+        for n in range(args.random_shapes):
+            new_shape = rng.integers(1, 48, size=rank)
+
+            # Set the batch dimension on 4D objects to 1
+            if rank == 4:
+                new_shape[0] = 1
+
+            # Limit the total shape volume and throw out any
+            # shapes that wouldn't leave at least size=2 in some non-batch dimension
+            volume = 1
+            skip_shape = False
+            for i in range(rank):
+
+                volume *= new_shape[i]
+
+                # Reduce the shape, while it's larger than the maximum volume
+                while volume > max_total_volume:
+                    new_shape[i] = new_shape[i] // 2
+                    volume = volume // 2
+
+                    # Now an untenable dimension size?  Skip this one.
+                    if new_shape[i] < 1:
+                        skip_shape = True
+
+            if not skip_shape:
+                shape_list.append(tuple(new_shape))
+
+
+# Construct, run and save a whole tensorflow tf.function to a protobuf file
+# or convert to .tflite if it's quantized unit test
+def run_unit_test(
+    op_name,
+    args,
+    test_dir,
+    curr_shape,
+    addl_args,
+    dtype,
+    excluded_framework_list,
+    quantized_inference_dtype,
+    result_name,
+    seed,
+):
+
+    try:
+        op = TF_OP_LIST[op_name]
+        op_fcn, tensor_gen_fcn, arg_gen_fcn = op["build_fcn"]
+
+        # Get and seed a random number generator for this test
+        rng = np.random.default_rng(seed)
+
+        # return placeholders=(str: name, np.array: value)
+        # consts=(str: name, np.array: value)
+        placeholders, consts = tensor_gen_fcn(op, curr_shape, dtype, rng)
+
+        # if test doesn't have any placeholders/consts, terminated
+        if len(placeholders) == 0 and len(consts) == 0:
+            return True
+
+        if not args.quiet:
+            print("   {}              ".format(test_dir))
+
+        try:
+            os.mkdir(test_dir)
+        except FileExistsError:
+            pass
+
+        const_nodes = [value for name, value in consts]
+
+        num_placeholders = len(placeholders)
+        # if test is quantized, create tensor quantization metadata info for
+        # each input tensor, based on different quantized type
+        if quantized_inference_dtype:
+            is_quantized = True
+            # TODO: support INT8 IFM x INT4 weight later
+            if quantized_inference_dtype == QuantType.ALL_U8:
+                qzero = [128] * num_placeholders
+                numpy_dtype = [np.uint8] * num_placeholders
+                tflite_inference_dtype = tf.uint8
+            elif quantized_inference_dtype == QuantType.ALL_I8:
+                qzero = [0] * num_placeholders
+                numpy_dtype = [np.int8] * num_placeholders
+                tflite_inference_dtype = tf.int8
+            elif quantized_inference_dtype == QuantType.ALL_I16:
+                qzero = [0] * num_placeholders
+                numpy_dtype = [np.int16] * num_placeholders
+                tflite_inference_dtype = tf.int16
+            elif quantized_inference_dtype == QuantType.CONV_U8_U8:
+                assert (
+                    num_placeholders == 1
+                ), "Unsupported number of placeholders for Convolution: {}".format(
+                    num_placeholders
+                )
+                qzero = [128] * num_placeholders
+                if num_placeholders == 2:
+                    numpy_dtype = [np.uint8, np.uint8]
+                else:
+                    numpy_dtype = [np.uint8, np.uint8, np.int32]
+                tflite_inference_dtype = tf.uint8
+            elif quantized_inference_dtype == QuantType.CONV_I8_I8:
+                assert (
+                    num_placeholders == 1
+                ), "Unsupported number of placeholders for Convolution: {}".format(
+                    num_placeholders
+                )
+                qzero = [0] * num_placeholders
+                if num_placeholders == 2:
+                    numpy_dtype = [np.int8, np.int8]
+                else:
+                    numpy_dtype = [np.int8, np.int8, np.int32]
+                tflite_inference_dtype = tf.int8
+            elif quantized_inference_dtype == QuantType.CONV_I16_I8:
+                assert (
+                    num_placeholders == 1
+                ), "Unsupported number of placeholders for Convolution: {}".format(
+                    num_placeholders
+                )
+                if num_placeholders == 2:
+                    qzero = [0, 0]
+                    numpy_dtype = [np.int16, np.int8]
+                else:
+                    qzero = [0, 0, 0]
+                    numpy_dtype = [
+                        np.int16,
+                        np.int8,
+                        np.int64,
+                    ]  # np.int64 to represent 40 bits accumulator
+                tflite_inference_dtype = tf.int16
+            else:
+                raise Exception(
+                    "Unsupported fakequant dtype: {}".format(quantized_inference_dtype)
+                )
+
+        else:
+            is_quantized = False
+
+        tf_model_filename = None
+        tf_result_npy_filename = None
+        tf_result_name = None
+
+        tflite_model_filename = None
+        tflite_result_npy_filename = None
+        tflite_result_name = None
+
+        placeholder_names = []
+        placeholder_vals = []
+        placeholder_signatures = ()
+        placeholder_npy_filenames = []
+        placeholder_shapes = []
+
+        for idx, (name, val) in enumerate(placeholders):
+            placeholder_names.append(name)
+            placeholder_signatures = placeholder_signatures + (
+                tf.TensorSpec(shape=val.shape, dtype=val.dtype, name=name),
+            )
+            placeholder_npy_filenames.append("{}.npy".format(name.split(":")[0]))
+            placeholder_shapes.append(val.shape)
+
+        # Get test builder class
+        fcn_node = op_fcn(*const_nodes, *addl_args, result_name)
+        concrete_function = tf.function(input_signature=placeholder_signatures)(
+            fcn_node.eval
+        ).get_concrete_function()
+
+        if is_quantized:
+
+            assert dtype is tf.float32, "quantized test must come from float32 graph"
+
+            # 1. Quantize float placeholder npy to quantized to feed the graph
+            for idx, (name, val) in enumerate(placeholders):
+
+                # we use np.amin()/np.amax() to determine dynamic range
+                # for quantized test
+                zeropoint = 0
+                scale = 1.0
+                if numpy_dtype[idx] != np.int64:
+                    qmin = np.iinfo(numpy_dtype[idx]).min
+                    qmax = np.iinfo(numpy_dtype[idx]).max
+                    num_bits = np.iinfo(numpy_dtype[idx]).bits
+                # 40 bit is represented as np.int64
+                else:
+                    num_bits = 40
+                    qmin = -(1 << num_bits)
+                    qmax = (1 << num_bits) - 1
+
+                min_val = np.amin(val)
+                max_val = np.amax(val)
+
+                # for single value tensor, we set scale equal to the abs(value),
+                # and fix zeropoint to 128
+                # if val > 0, it'll be represented as 129,
+                #    where val = (129 - 128) * val
+                # if val < 0, it'll be represented as 127,
+                #    where val = (127 - 128) * (-val)
+                # if val == 0, it'll be represted as 128, with range [-128.0, 128.0]
+                # and let quantized 1 represent the value
+                # also adjust effective min/max consequently
+                if max_val == min_val:
+                    if max_val != 0:
+                        scale = abs(max_val)
+                    else:
+                        scale = 1.0
+                    min_val = float(qmin - qzero[idx]) * scale
+                    max_val = float(qmax - qzero[idx]) * scale
+                else:
+                    scale = (max_val - min_val) / float(qmax - qmin)
+                    zeropoint = int(round((-min_val) / scale)) + qmin
+
+                # run through tf.fakequant first to assure quantization error aligned
+                fakequant_val = tf.quantization.fake_quant_with_min_max_args(
+                    val,
+                    min=min_val,
+                    max=max_val,
+                    num_bits=num_bits,
+                    name="gen_quant_npy",
+                )
+
+                quant_val = np.round(fakequant_val / scale).astype(np.int32) + zeropoint
+
+                # very few unit tests after TF hash may/2020, this quantized
+                # value for some reason exceed [0, 255] range
+                saved_val = np.clip(quant_val, qmin, qmax).astype(numpy_dtype[idx])
+
+                # saved all quantized tensor as np.int32
+                # since TOSA numpy Cpp API only supports int32
+                np.save(
+                    os.path.join(test_dir, placeholder_npy_filenames[idx]),
+                    saved_val.astype(np.int32),
+                    False,
+                )
+
+                placeholder_vals.append(tf.convert_to_tensor(saved_val))
+
+            # 2. Convert the model to quantized TFLite flatbuffer
+            module = tf.Module()
+            converter = tf.lite.TFLiteConverter.from_concrete_functions(
+                [concrete_function], module
+            )
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            converter.experimental_new_converter = True
+
+            # use MLIR-based post-quantizer
+            converter.experimental_new_quantizer = True
+
+            flag = (
+                tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8  # noqa: E501
+            )
+            if tflite_inference_dtype == tf.int16:
+                converter.target_spec.supported_ops = [flag]
+
+            def input_stats():
+                for i in range(0, args.num_samples):
+                    a = [
+                        TGen.getRand(shape, tf.float32, rng)
+                        for shape in placeholder_shapes
+                    ]
+                    yield a
+
+            converter.representative_dataset = input_stats
+            converter.inference_input_type = tflite_inference_dtype
+            converter.inference_output_type = tflite_inference_dtype
+
+            tflite_model = converter.convert()
+
+            tflite_model_filename = "model.tflite"
+
+            # Write out converted model to disk
+            with open(os.path.join(test_dir, tflite_model_filename), "wb") as f:
+                f.write(tflite_model)
+
+        else:  # is_quantized is False
+
+            # 1. Saved out numpy array directly
+            for idx, (name, val) in enumerate(placeholders):
+                placeholder_vals.append(tf.convert_to_tensor(val))
+                np.save(
+                    os.path.join(test_dir, placeholder_npy_filenames[idx]), val, False
+                )
+
+            # 2.a Saved out .pb if framework includes tensorflow
+            if "tf" not in excluded_framework_list:
+                # Write out graph as protobuf to disk
+                tf_model_filename = "model.pb"
+                tf.io.write_graph(
+                    concrete_function.graph, test_dir, tf_model_filename, True
+                )
+
+            # 2.b Saved out .tflite if framework includes tflite
+            if "tflite" not in excluded_framework_list:
+                # Convert the model to TFLite flatbuffer
+                module = tf.Module()
+                converter = tf.lite.TFLiteConverter.from_concrete_functions(
+                    [concrete_function], module
+                )
+
+                converter.experimental_new_converter = True
+
+                # Even it's non-quantized int32 test, this needs to be set to tf.float32
+                converter.inference_input_type = tf.float32
+                converter.inference_output_type = tf.float32
+                tflite_model = converter.convert()
+
+                # Write out converted model to disk
+                tflite_model_filename = "model.tflite"
+                with open(os.path.join(test_dir, tflite_model_filename), "wb") as f:
+                    f.write(tflite_model)
+
+        # Get TF reference result if .pb is specified
+        if tf_model_filename:
+            tf_result_npy_filename = "tf_result.npy"
+            tf_result = concrete_function(*placeholder_vals)
+            np.save(os.path.join(test_dir, tf_result_npy_filename), tf_result, False)
+
+            tf_result_name = result_name
+
+        # Get TFLite inference result if .tflite is specified
+        if tflite_model_filename:
+            tflite_result_npy_filename = "tflite_result.npy"
+
+            ops_with_optimized_only_kernel = ["elu", "ceil", "gather"]
+
+            if args.tflite_kernel_mode == "optimized" or (
+                op_name in ops_with_optimized_only_kernel
+            ):
+                interpreter = tf.lite.Interpreter(
+                    model_path=os.path.join(test_dir, tflite_model_filename)
+                )
+            elif args.tflite_kernel_mode == "reference":
+                interpreter = tf.lite.Interpreter(
+                    model_path=os.path.join(test_dir, tflite_model_filename),
+                    experimental_op_resolver_type=OpResolverType.BUILTIN_REF,
+                )
+            else:
+                assert 0, "unknown tflite interpreter mode {}".format(
+                    args.tflite_kernel_mode
+                )
+            interpreter.allocate_tensors()
+
+            input_details = interpreter.get_input_details()
+            output_details = interpreter.get_output_details()
+
+            assert len(input_details) == len(
+                placeholder_vals
+            ), "number of placeholder mismatch"
+
+            for idx, val in enumerate(placeholder_vals):
+                interpreter.set_tensor(input_details[idx]["index"], val.numpy())
+
+            interpreter.invoke()
+            tflite_result = interpreter.get_tensor(output_details[0]["index"])
+
+            np.save(
+                os.path.join(test_dir, tflite_result_npy_filename), tflite_result, False
+            )
+
+            # Result tensor name would change after converting to TFLite flatbuffer
+            # Overwrite the information from TFLite models directly.
+            # Assume single result tensor now
+            tflite_result_name = output_details[0]["name"]
+
+        # Write out test descriptor
+        write_test_json(
+            filename=os.path.join(test_dir, "test.json"),
+            tf_model_filename=tf_model_filename,
+            tf_result_npy_filename=tf_result_npy_filename,
+            tf_result_name=tf_result_name,
+            tflite_model_filename=tflite_model_filename,
+            tflite_result_npy_filename=tflite_result_npy_filename,
+            tflite_result_name=tflite_result_name,
+            ifm_name=placeholder_names,
+            ifm_file=placeholder_npy_filenames,
+            ifm_shape=placeholder_shapes,
+            framework_exclusions=excluded_framework_list,
+            quantized=is_quantized,
+        )
+    except Exception as e:
+        msg = "Error running task: {}".format(e)
+        print(msg)
+        print(
+            "".join(
+                traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
+            )
+        )
+        return False
+    return True
+
+
+def build_const_net(
+    args,
+    curr_shape,
+    op_name,
+    dtype,
+    excluded_framework_list,
+    quantized_inference_dtype,
+    result_name,
+    seed,
+    rng,
+    filter,
+    unit_test_args,
+):
+
+    if quantized_inference_dtype:
+        quant_dtype = get_tf_dtype(quantized_inference_dtype)
+        test_dir = "test_{}_{}".format(op_name, get_shape_str(curr_shape, quant_dtype))
+    else:
+        test_dir = "test_{}_{}".format(op_name, get_shape_str(curr_shape, dtype))
+    test_dir = os.path.join(args.output_dir, test_dir)
+
+    # If the operator has an additional function to generate arguments, call it
+    # here and iterate through the argument list that it generates
+    op = TF_OP_LIST[op_name]
+    op_fcn, tensor_gen_fcn, arg_gen_fcn = op["build_fcn"]
+
+    addl_args_tuple = arg_gen_fcn(op, curr_shape, rng)
+    for desc, addl_args in addl_args_tuple:
+        if not filter or filter.search(test_dir + desc):
+            unit_test_args.append(
+                [
+                    op_name,
+                    args,
+                    test_dir + desc,
+                    curr_shape,
+                    addl_args,
+                    dtype,
+                    excluded_framework_list,
+                    quantized_inference_dtype,
+                    result_name,
+                    seed,
+                ]
+            )
+
+
+# python hash is not reproducible, create hash for our purpose
+def op_name_hash(op_name):
+    result = 0xDEADBEEF
+    for ch in op_name:
+        if result & 1:
+            result = (ord(ch) << 24) ^ (result >> 1) ^ 0x82608EDB
+        else:
+            result = (ord(ch) << 24) ^ (result >> 1)
+
+    return result
+
+
+def generate_op_tests(args, op_name, shape_list, result_name, filter, unit_test_args):
+
+    if not args.quiet:
+        print(
+            "Generating tests for {}                                        ".format(
+                op_name
+            )
+        )
+
+    op = TF_OP_LIST[op_name]
+
+    # Seed the RNG so that we get the same random tests for each test each time
+    # If the number of tests for a given generation function changes, the tests
+    # for that operator may also change accordingly, but this will at least keep
+    # down churn across operators.
+
+    bounded_hash_val = (args.random_seed + op_name_hash(op_name)) % np.iinfo(
+        np.int32
+    ).max
+    rng = np.random.default_rng(bounded_hash_val)
+
+    # this is a dictionary with 'tf' and 'tflite' as key
+    # and value being the data types we want to test under these framework
+
+    if isinstance(op["types"], dict):
+        try:
+            tf_dtypes = op["types"]["tf"]
+        except KeyError:
+            tf_dtypes = []
+        try:
+            tflite_dtypes = op["types"]["tflite"]
+        except KeyError:
+            tflite_dtypes = []
+    elif isinstance(op["types"], list):
+        tf_dtypes = op["types"]
+        tflite_dtypes = op["types"]
+
+    tf_nonquantized_dtypes = tf_dtypes  # tf doesn't support quantized data types
+    tflite_quantized_dtypes = []
+    tflite_nonquantized_dtypes = []
+    for dtype in tflite_dtypes:
+        if isinstance(dtype, QuantType):
+            tflite_quantized_dtypes.append(dtype)
+        else:
+            tflite_nonquantized_dtypes.append(dtype)
+
+    nonquantized_dtypes_set = set(tf_nonquantized_dtypes).union(
+        set(tflite_nonquantized_dtypes)
+    )
+    nonquantized_dtypes = list(nonquantized_dtypes_set)
+    quantized_dtypes = tflite_quantized_dtypes
+
+    # populate non quantized unit test arguments
+    for dtype in nonquantized_dtypes:
+
+        excluded_framework_set = set(ALL_FRAMEWORKS)
+        if dtype in tf_nonquantized_dtypes:
+            excluded_framework_set.remove("tf")
+        if dtype in tflite_nonquantized_dtypes:
+            excluded_framework_set.remove("tflite")
+        excluded_framework_list = list(excluded_framework_set)
+
+        for curr_shape in shape_list:
+            build_const_net(
+                args,
+                curr_shape,
+                op_name,
+                dtype,
+                excluded_framework_list,
+                None,
+                result_name,
+                bounded_hash_val,
+                rng,
+                filter,
+                unit_test_args,
+            )
+
+    # populate quantized unit test arguments
+    # must exclude 'tf' and source dtype being tf.float32
+    for dtype in quantized_dtypes:
+        for curr_shape in shape_list:
+            build_const_net(
+                args,
+                curr_shape,
+                op_name,
+                tf.float32,
+                ["tf"],
+                dtype,
+                result_name,
+                bounded_hash_val,
+                rng,
+                filter,
+                unit_test_args,
+            )
+
+    return unit_test_args
+
+
+def createDynamicOpLists():
+    """The templated operators are conv2d-style operators with a number of kernel
+    sizes.  Since the operator is unchanged, we generate the range of kernel
+    sizes here in this loop and remove the original templates from the list.
+
+    This could be expanded to non-conv2d-style operators in the future."""
+
+    # Dynamically create op lists for convolutions with a list of kernel sizes
+    KERNELS = [
+        [1, 1],
+        [3, 3],
+        [5, 5],
+    ]
+
+    TEMPLATE_LIST = [
+        "conv2d",
+        "conv2d_bias",
+        "conv2d_relu",
+        "conv2d_relu6",
+        "conv2d_relu_n1_to_1",
+        "conv2d_tanh",
+        "depthwise_conv2d",
+        "depthwise_conv2d_bias",
+        "transpose_conv2d",
+    ]
+
+    for t in TEMPLATE_LIST:
+        for k in KERNELS:
+            testName = "{}_{}x{}".format(t, k[0], k[1])
+            TF_OP_LIST[testName] = TF_OP_LIST["{}_TEMPLATE".format(t)].copy()
+            TF_OP_LIST[testName]["filter"] = k
+            TF_OP_LIST[testName]["template"] = False
+
+    # Delete any templates after having created any dynamic ops
+    # This is a two-pass operation because it's bad practice to delete
+    # keys from dictionaries while iterating
+    keyList = []
+    for k in TF_OP_LIST:
+        try:
+            if TF_OP_LIST[k]["template"]:
+                keyList.append(k)
+                continue
+        except KeyError:
+            pass
+
+    for k in keyList:
+        del TF_OP_LIST[k]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--seed", dest="random_seed", default=42, type=int, help="Random seed"
+    )
+    parser.add_argument(
+        "--random-shapes",
+        dest="random_shapes",
+        default=0,
+        type=int,
+        help=(
+            "Use N random shapes of each rank for generating tests,"
+            "seeded with random seed"
+        ),
+    )
+    parser.add_argument(
+        "-o",
+        "--output-dir",
+        dest="output_dir",
+        default=".",
+        type=str,
+        help="Test output directory path prefix",
+    )
+    parser.add_argument(
+        "-q",
+        "--quiet",
+        dest="quiet",
+        default=False,
+        action="store_true",
+        help="Do not print test names",
+    )
+    parser.add_argument(
+        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
+    )
+    parser.add_argument(
+        "-m",
+        "--tflite-kernel-mode",
+        dest="tflite_kernel_mode",
+        type=str,
+        choices=["reference", "optimized"],
+        default="reference",
+        help="TFLite interpreter kernel mode",
+    )
+    parser.add_argument(
+        "--num-samples",
+        dest="num_samples",
+        default=200,
+        type=int,
+        help="Number of input samples for post-training quantization",
+    )
+    parser.add_argument(
+        "--filter",
+        dest="filter",
+        default="",
+        type=str,
+        help="Filter test names by this expression",
+    )
+    args = parser.parse_args()
+
+    # Turn the filter into a re object if present
+    filter = None
+    if args.filter != "":
+        filter = re.compile(args.filter)
+
+    # Autodetect CPU count
+    if args.jobs <= 0:
+        args.jobs = os.cpu_count()
+
+    # Disable TF info messages
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    try:
+        os.makedirs(args.output_dir)
+    except FileExistsError:
+        pass
+
+    if args.random_shapes:
+        gen_rand_shapes(args)
+
+    # Build dynamic ops
+    createDynamicOpLists()
+
+    # Generate the test list and arguments to run_unit_test()
+    unit_test_args = []
+
+    for op in TF_OP_LIST:
+        generate_op_tests(args, op, shape_list, "result", filter, unit_test_args)
+
+    errors = 0
+    for t in unit_test_args:
+        if not run_unit_test(*t):
+            errors = errors + 1
+
+    if not args.quiet:
+        print("\nAll tasks done - with {} errors".format(errors))
+
+    return 1 if errors else 0
+
+
+if __name__ == "__main__":
+    exit(main())
diff --git a/verif/frameworks/write_test_json.py b/verif/frameworks/write_test_json.py
new file mode 100644
index 0000000..68dfc8f
--- /dev/null
+++ b/verif/frameworks/write_test_json.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import json
+
+# Used by basic_test_generator to create test description
+
+
+def write_test_json(
+    filename,
+    tf_model_filename=None,
+    tf_result_npy_filename=None,
+    tf_result_name=None,
+    tflite_model_filename=None,
+    tflite_result_npy_filename=None,
+    tflite_result_name=None,
+    ifm_name=None,
+    ifm_file=None,
+    ifm_shape=None,
+    framework_exclusions=None,
+    quantized=False,
+):
+
+    test_desc = dict()
+
+    if tf_model_filename:
+        test_desc["tf_model_filename"] = tf_model_filename
+
+    if tf_result_npy_filename:
+        test_desc["tf_result_npy_filename"] = tf_result_npy_filename
+
+    if tf_result_name:
+        test_desc["tf_result_name"] = tf_result_name
+
+    if tflite_model_filename:
+        test_desc["tflite_model_filename"] = tflite_model_filename
+
+    if tflite_result_npy_filename:
+        test_desc["tflite_result_npy_filename"] = tflite_result_npy_filename
+
+    if tflite_result_name:
+        test_desc["tflite_result_name"] = tflite_result_name
+
+    if ifm_file:
+        if not isinstance(ifm_file, list):
+            ifm_file = [ifm_file]
+        test_desc["ifm_file"] = ifm_file
+
+    # Make sure these arguments are wrapped as lists
+    if ifm_name:
+        if not isinstance(ifm_name, list):
+            ifm_name = [ifm_name]
+        test_desc["ifm_name"] = ifm_name
+
+    if ifm_shape:
+        if not isinstance(ifm_shape, list):
+            ifm_shape = [ifm_shape]
+        test_desc["ifm_shape"] = ifm_shape
+
+    # Some tests cannot be used with specific frameworks.
+    # This list indicates which tests should be excluded from a given framework.
+    if framework_exclusions:
+        if not isinstance(framework_exclusions, list):
+            framework_exclusions = [framework_exclusions]
+        test_desc["framework_exclusions"] = framework_exclusions
+
+    if quantized:
+        test_desc["quantized"] = 1
+
+    with open(filename, "w") as f:
+        json.dump(test_desc, f, indent="  ")
diff --git a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py
index b9a9575..2aeb7b1 100644
--- a/verif/runner/tosa_refmodel_sut_run.py
+++ b/verif/runner/tosa_refmodel_sut_run.py
@@ -58,12 +58,8 @@
                 graphResult = TosaTestRunner.TosaGraphResult.TOSA_UNPREDICTABLE
             else:
                 graphResult = TosaTestRunner.TosaGraphResult.OTHER_ERROR
-            if (
-                self.args.verbose
-                or graphResult == TosaTestRunner.TosaGraphResult.OTHER_ERROR
-            ):
-                print(e)
-
+                if not self.args.verbose:
+                    print(e)
         except Exception as e:
             print(e)
             graphMessage = str(e)
diff --git a/verif/runner/tosa_test_runner.py b/verif/runner/tosa_test_runner.py
index 0fd7f13..d653a94 100644
--- a/verif/runner/tosa_test_runner.py
+++ b/verif/runner/tosa_test_runner.py
@@ -7,6 +7,7 @@
 
 from checker.tosa_result_checker import LogColors
 from checker.tosa_result_checker import print_color
+from checker.tosa_result_checker import set_print_in_color
 from checker.tosa_result_checker import test_check
 from json2fbbin import json2fbbin
 
@@ -39,6 +40,8 @@
         self.testDir = testDir
         self.testName = Path(self.testDir).name
 
+        set_print_in_color(not args.no_color)
+
         # Check if we want to run binary and if its already converted
         descFilePath = Path(testDir, "desc.json")
         descBinFilePath = Path(testDir, "desc_binary.json")
@@ -165,9 +168,11 @@
                 result == TosaTestRunner.Result.EXPECTED_FAILURE
                 or result == TosaTestRunner.Result.EXPECTED_PASS
             ):
-                print_color(LogColors.GREEN, "Results PASS {}".format(self.testName))
+                print_color(
+                    LogColors.GREEN, "Result code PASS {}".format(self.testName)
+                )
             else:
-                print_color(LogColors.RED, "Results FAIL {}".format(self.testName))
+                print_color(LogColors.RED, "Result code FAIL {}".format(self.testName))
 
         return result, resultMessage
 
diff --git a/verif/runner/tosa_verif_run_tests.py b/verif/runner/tosa_verif_run_tests.py
index dd86950..b400d76 100644
--- a/verif/runner/tosa_verif_run_tests.py
+++ b/verif/runner/tosa_verif_run_tests.py
@@ -119,6 +119,13 @@
         choices=["positive", "negative", "both"],
         help="Filter tests based on expected failure status (positive, negative or both)",
     )
+    parser.add_argument(
+        "--no-color",
+        "--no-colour",
+        dest="no_color",
+        action="store_true",
+        help="Disable color output",
+    )
 
     args = parser.parse_args(argv)