Compliance testing support for MAX_POOL2D & PAD

Added Pseudo Random number generator in generate library.
Enabled MAX_POOL2D, PAD FP32 tests to use new generator and compliance.
Fixed verify library exact mode to expect reference data as FP64.
Simplified tosa_verif_build_tests internal interfaces for new tests.

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Icc0ffa924cf38107c3a212efd452c47a650c9d98
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 0b6dc79..9c18879 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -1471,6 +1471,7 @@
         "profile": [
             "tosa-mi"
         ],
+        "support_for": [ "lazy_data_gen" ],
         "generation": {
             "standard": {
                 "generator_args": [
@@ -1599,6 +1600,7 @@
         "profile": [
             "tosa-mi"
         ],
+        "support_for": [ "lazy_data_gen" ],
         "generation": {
             "standard": {
                 "generator_args": [
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 475f062..f7837a0 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -635,7 +635,11 @@
         # Variable inputs versus constants
         pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
 
-        if error_name is not None or not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
+        if (
+            error_name is not None
+            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
+            or opName in ("avg_pool2d",)
+        ):
             # Fall back to original path when dealing with unsupported types
 
             # First turn off lazy data gen so we always produce data
@@ -678,7 +682,7 @@
             if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                 info = {}
                 # TODO - generate seed for this generator based on test
-                info["rng_seed"] = -1
+                info["rng_seed"] = 42
                 info["range"] = [
                     str(v)
                     for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True)
@@ -1107,7 +1111,7 @@
         pass
 
     @staticmethod
-    def _add_data_generators(testGen, opName, dtype, arg_list, error_name, **kwargs):
+    def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
         """Add extra tests for each type of data generator for this op."""
         if (
             error_name is None
@@ -1125,32 +1129,28 @@
         # Expand arg list with other data generator types
         new_arg_list = []
         for dg_type in dataGenTypesList:
-            for arg_str, arg_attrs in arg_list:
-                arg_dict = arg_attrs[0]
-                arg_dict["dg_type"] = dg_type
-
+            for arg_str, args_dict in arg_list:
+                args_dict["dg_type"] = dg_type
                 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                     # Default test
-                    new_arg_list.append((arg_str, [arg_dict]))
+                    new_arg_list.append((arg_str, args_dict))
 
                 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                     # Extra tests for each dot product test set
-                    dot_products = kwargs["dot_products"]
+                    dot_products = args_dict["dot_products"]
                     if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                         print(
                             f"Skipping {opName} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                         )
                         continue
-                    arg_dict["ks"] = kwargs["ks"]
-                    for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
-                        if key in kwargs:
-                            arg_dict[key] = kwargs[key]
+                    # KS is required by all dot product generators
+                    assert "ks" in args_dict
 
                     for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
                         new_arg_str = f"{arg_str}_s{s}"
-                        new_arg_dict = arg_dict.copy()
-                        new_arg_dict["s"] = s
-                        new_arg_list.append((new_arg_str, [new_arg_dict]))
+                        new_args_dict = args_dict.copy()
+                        new_args_dict["s"] = s
+                        new_arg_list.append((new_arg_str, new_args_dict))
 
         return new_arg_list
 
@@ -1421,9 +1421,21 @@
             # Pick some potentially correct output dtype if input type is incorrect
             accum_dtypes = [DType.INT32]
 
-        arg_list = [
-            (f"acc{testGen.typeStr(a)}", [{"acc_type": a}]) for a in accum_dtypes
-        ]
+        # Set up compliance info
+        args_dict = {
+            "ks": int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
+            # Set dot_products = N*H*W
+            "dot_products": gtu.product(
+                (shapeList[0][0], shapeList[0][1], shapeList[1][2])
+            ),
+        }
+
+        # Create arg tuple of string and dict
+        arg_list = []
+        for a in accum_dtypes:
+            d = args_dict.copy()
+            d["acc_type"] = a
+            arg_list.append((f"acc{testGen.typeStr(a)}", d))
 
         arg_list = TosaArgGen._add_data_generators(
             testGen,
@@ -1431,12 +1443,8 @@
             dtype,
             arg_list,
             error_name,
-            ks=int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
-            # Set dot_products = N*H*W
-            dot_products=gtu.product(
-                (shapeList[0][0], shapeList[0][1], shapeList[1][2])
-            ),
         )
+        # Return list of tuples: (arg_str, args_dict)
         return arg_list
 
     @staticmethod
@@ -1574,7 +1582,6 @@
 
     @staticmethod
     def agPad(testGen, opName, shapeList, dtype, error_name=None):
-        arg_list = []
         rank = len(shapeList[0])
 
         # Exhaustively test combinations of padding on each side of each dimension
@@ -1606,6 +1613,8 @@
         else:
             sparsity = 1
 
+        # Build arg list
+        arg_list = []
         for n, paddings in enumerate(list_shape_pad_values):
             paddings = list(paddings)
             args_valid = True
@@ -1625,13 +1634,25 @@
                 for r in range(rank):
                     before, after = paddings[r]
                     name = f"{name}{before}{after}"
-                arg_list.append(
-                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
-                )
+                    args_dict = {
+                        "pad": np.array(paddings),
+                        "pad_const_int": pad_const_int,
+                        "pad_const_fp": pad_const_fp,
+                    }
+                arg_list.append((name, args_dict))
 
         if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
             warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
 
+        arg_list = TosaArgGen._add_data_generators(
+            testGen,
+            opName,
+            dtype,
+            arg_list,
+            error_name,
+        )
+
+        # Return list of tuples: (arg_str, args_dict)
         return arg_list
 
     @staticmethod
@@ -1735,9 +1756,9 @@
             else "st{}_kern{}_pad{}"
         )
 
-        def get_arg_list_element(accum, stride, pad, kern):
+        def get_arg_list_element(accum, stride, pad, kern, dot_products=0):
             # Return tuple containing the formatted argument string and
-            # the corresponding argument values
+            # the corresponding argument values in a dictionary
 
             # Support for larger values than 9 needs different delimiter
             delim = "" if max(stride + kern + pad) <= 9 else "x"
@@ -1746,13 +1767,18 @@
                 delim.join([str(x) for x in kern]),
                 delim.join([str(x) for x in pad]),
             ]
-            # Note: different order to string
-            arg_val_elems = [stride, pad, kern]
+            args_dict = {
+                "stride": stride,
+                "pad": pad,
+                "kernel": kern,
+                "dot_products": dot_products,  # Ignored for error tests
+                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
+            }
 
             if accum is not None:
                 arg_str_elems.insert(0, testGen.typeStr(accum))
-                arg_val_elems.insert(0, accum)
-            return (arg_str.format(*arg_str_elems), arg_val_elems)
+                args_dict["acc_type"] = accum
+            return (arg_str.format(*arg_str_elems), args_dict)
 
         n = 0
         for a in accum_dtypes:
@@ -1769,8 +1795,9 @@
                                 testGen, error_name, s, p, k
                             )
                             if None not in [sNew, pNew, kNew] and n % sparsity == 0:
-                                arg_vals = [a, sNew, pNew, kNew]
-                                arg_list.append(get_arg_list_element(*arg_vals))
+                                arg_list.append(
+                                    get_arg_list_element(a, sNew, pNew, kNew)
+                                )
                         elif (
                             n % sparsity == 0
                             # padding must not exceed the kernel size
@@ -1804,10 +1831,23 @@
                                 ):
                                     # Test will consume too much memory - skip it
                                     continue
-                                arg_vals = [a, s, p, k]
-                                arg_list.append(get_arg_list_element(*arg_vals))
+                                # Dot products = N*OH*OW*C
+                                dp = gtu.product(
+                                    (shape[0], output_h, output_w, shape[3])
+                                )
+                                arg_list.append(get_arg_list_element(a, s, p, k, dp))
                         n += 1
 
+        # Now add data generator types
+        arg_list = TosaArgGen._add_data_generators(
+            testGen,
+            opName,
+            dtype,
+            arg_list,
+            error_name,
+        )
+
+        # Return list of tuples: (arg_str, args_dict)
         return arg_list
 
     @staticmethod
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index d490cf2..ed1a941 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -2653,16 +2653,28 @@
 
         args = kwargs["args"]
 
-        # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
-        stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
+        if isinstance(args, dict):
+            args_dict = args
+        else:
+            # Create args_dict from list elements
+            # TODO - Remove this once all NHWC operators' agFunctions have been
+            # converted to args_dict output
+
+            # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
+            stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
+            args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
+            # Alias different info for each op
+            args_dict["kernel"] = args[pad_idx + 1]
+            args_dict["out_shape"] = args[pad_idx + 1]
+            args_dict["dilation"] = args[pad_idx + 1]
 
         # Common info for all ops
-        strides = args[stride_idx]
-        padding = args[pad_idx]
+        strides = args_dict["stride"]
+        padding = args_dict["pad"]
 
         if opName.endswith("pool2d"):
             # avg_pool2d, max_pool2d
-            kernel_shape = args[pad_idx + 1]
+            kernel_shape = args_dict["kernel"]
             h = (
                 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
             ) // strides[0]
@@ -2674,7 +2686,7 @@
 
         if opName.startswith("transpose_conv2d"):
             # transpose_conv2d
-            output_shape = args[pad_idx + 1]
+            output_shape = args_dict["out_shape"]
             filter_shape = inputShapes[1]
             kernel_shape = filter_shape[1:-1]
 
@@ -2703,7 +2715,7 @@
 
         if "conv2d" in opName or "conv3d" in opName:
             # conv2d, conv3d, depthwise_conv2d
-            dilations = args[pad_idx + 1]
+            dilations = args_dict["dilation"]
             filter_shape = inputShapes[1]
             kernel_shape = (
                 filter_shape[0:2]
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 8fcea29..17cbd8f 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -658,15 +658,22 @@
     def build_pool2d(
         self,
         op,
-        input,
-        accum_dtype,
-        stride,
-        pad,
-        kernel,
+        inputs,
+        args_dict,
         validator_fcns=None,
         error_name=None,
         qinfo=None,
     ):
+        assert len(inputs) == 1
+        input = inputs[0]
+        # max_pool has no accum_dtype
+        accum_dtype = (
+            args_dict["acc_type"] if "acc_type" in args_dict else DType.UNKNOWN
+        )
+        stride = args_dict["stride"]
+        pad = args_dict["pad"]
+        kernel = args_dict["kernel"]
+
         result_tens = OutputShaper.pool2dOp(
             self.ser, self.rng, input, kernel, stride, pad, error_name
         )
@@ -720,27 +727,28 @@
     def build_maxpool2d(
         self,
         op,
-        input,
-        stride,
-        pad,
-        kernel,
+        inputs,
+        args_dict,
         validator_fcns=None,
         error_name=None,
         qinfo=None,
     ):
-        # Same as build_pool2d but manually sets accum_dtype value
-        # (maxpool has no accum_dtype)
-        return self.build_pool2d(
+        result_tensor = self.build_pool2d(
             op,
-            input,
-            DType.UNKNOWN,
-            stride,
-            pad,
-            kernel,
+            inputs,
+            args_dict,
             validator_fcns,
             error_name,
             qinfo,
         )
+        if gtu.dtypeIsSupportedByCompliance(inputs[0].dtype):
+            compliance = self.tensorComplianceMetaData(
+                op, args_dict, result_tensor, error_name
+            )
+        else:
+            compliance = None
+
+        return TosaTestGen.BuildInfo(result_tensor, compliance)
 
     def build_conv2d(
         self,
@@ -1070,8 +1078,10 @@
         return result_tens
 
     def build_matmul(
-        self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
+        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
     ):
+        assert len(inputs) == 2
+        a, b = inputs
         accum_dtype = args_dict["acc_type"]
         result_tensor = OutputShaper.matmulOp(
             self.ser, self.rng, a, b, accum_dtype, error_name
@@ -1372,15 +1382,19 @@
     def build_pad(
         self,
         op,
-        a,
-        padding,
-        pad_const_int,
-        pad_const_float,
+        inputs,
+        args_dict,
         validator_fcns=None,
         error_name=None,
         qinfo=None,
     ):
-        result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
+        assert len(inputs) == 1
+        a = inputs[0]
+        padding = args_dict["pad"]
+        pad_const_int = args_dict["pad_const_int"]
+        pad_const_float = args_dict["pad_const_fp"]
+
+        result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
 
         attr = ts.TosaSerializerAttribute()
         attr.PadAttribute(
@@ -1389,7 +1403,7 @@
 
         # Invalidate Input/Output list for error if checks.
         input_list = [a.name]
-        output_list = [result_tens.name]
+        output_list = [result_tensor.name]
         pCount, cCount = op["operands"]
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1402,12 +1416,12 @@
             error_name,
             op=op,
             input_shape=a.shape,
-            output_shape=result_tens.shape,
+            output_shape=result_tensor.shape,
             input_dtype=a.dtype,
-            output_dtype=result_tens.dtype,
+            output_dtype=result_tensor.dtype,
             pad=padding,
             qinfo=qinfo,
-            result_tensors=[result_tens],
+            result_tensors=[result_tensor],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -1416,7 +1430,15 @@
             return None
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
-        return result_tens
+
+        if gtu.dtypeIsSupportedByCompliance(a.dtype):
+            compliance = self.tensorComplianceMetaData(
+                op, args_dict, result_tensor, error_name
+            )
+        else:
+            compliance = None
+
+        return TosaTestGen.BuildInfo(result_tensor, compliance)
 
     def build_dim(
         self,
@@ -2609,8 +2631,9 @@
         tensMeta = {}
 
         # Check we are using the new testArgs interface with an argsDict dictionary
-        if len(testArgs) == 1 and isinstance(testArgs[0], dict):
-            argsDict = testArgs[0]
+        if isinstance(testArgs, dict):
+            # New interface with args info in dictionary
+            argsDict = testArgs
             assert "dg_type" in argsDict
             tvgInfo = tvgen_fcn(
                 self, opName, dtypeList, shapeList, argsDict, error_name
@@ -2618,38 +2641,49 @@
             if tvgInfo.dataGenDict:
                 tensMeta["data_gen"] = tvgInfo.dataGenDict
             tens = tvgInfo.tensorList
+
+            result = build_fcn(
+                self,
+                op,
+                tens,
+                argsDict,
+                validator_fcns=error_if_validators,
+                error_name=error_name,
+                qinfo=qinfo,
+            )
         else:
+            # Old interface with args info in a list
             tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
 
-        try:
-            if error_if_validators is None:
-                if qinfo is not None:
-                    result = build_fcn(self, op, *tens, *testArgs, qinfo)
+            try:
+                if error_if_validators is None:
+                    if qinfo is not None:
+                        result = build_fcn(self, op, *tens, *testArgs, qinfo)
+                    else:
+                        result = build_fcn(self, op, *tens, *testArgs)
                 else:
-                    result = build_fcn(self, op, *tens, *testArgs)
-            else:
-                if qinfo is not None:
-                    result = build_fcn(
-                        self,
-                        op,
-                        *tens,
-                        *testArgs,
-                        validator_fcns=error_if_validators,
-                        error_name=error_name,
-                        qinfo=qinfo,
-                    )
-                else:
-                    result = build_fcn(
-                        self,
-                        op,
-                        *tens,
-                        *testArgs,
-                        validator_fcns=error_if_validators,
-                        error_name=error_name,
-                    )
-        except TypeError as e:
-            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
-            raise e
+                    if qinfo is not None:
+                        result = build_fcn(
+                            self,
+                            op,
+                            *tens,
+                            *testArgs,
+                            validator_fcns=error_if_validators,
+                            error_name=error_name,
+                            qinfo=qinfo,
+                        )
+                    else:
+                        result = build_fcn(
+                            self,
+                            op,
+                            *tens,
+                            *testArgs,
+                            validator_fcns=error_if_validators,
+                            error_name=error_name,
+                        )
+            except TypeError as e:
+                print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
+                raise e
 
         if result:
             # The test is valid, serialize it
@@ -2847,7 +2881,7 @@
             "build_fcn": (
                 build_pool2d,
                 TosaTensorGen.tgNHWC,
-                TosaTensorValuesGen.tvgDefault,
+                TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agPooling,
             ),
             "qgen": TosaQuantGen.qgUnary,
@@ -3004,7 +3038,6 @@
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
-                "int": (gtu.DataGenType.PSEUDO_RANDOM,),
             },
         },
         "max_pool2d": {
@@ -3014,7 +3047,7 @@
             "build_fcn": (
                 build_maxpool2d,
                 TosaTensorGen.tgNHWC,
-                TosaTensorValuesGen.tvgDefault,
+                TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agPooling,
             ),
             "types": TYPE_NARROW_INT_FP,
@@ -3032,6 +3065,9 @@
                 TosaErrorValidator.evPoolingOutputShapeMismatch,
                 TosaErrorValidator.evPoolingOutputShapeNonInteger,
             ),
+            "data_gen": {
+                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+            },
         },
         # Templated operator.  Filled in by createDynamicOpLists
         "transpose_conv2d_TEMPLATE": {
@@ -3909,7 +3945,7 @@
             "build_fcn": (
                 build_pad,
                 TosaTensorGen.tgBasic,
-                TosaTensorValuesGen.tvgDefault,
+                TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agPad,
             ),
             "types": TYPE_FIB,
@@ -3923,6 +3959,9 @@
                 TosaErrorValidator.evRankMismatch,
                 TosaErrorValidator.evWrongRank,
             ),
+            "data_gen": {
+                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+            },
         },
         "dim": {
             "op": Op.DIM,