Fix MAXIMUM/MINIMUM handling of NaNs and zeroes

Change FP_SPECIAL testing to be used for DOT_PRODUCT cases only.
Use default EXACT matching - where zeroes of different signs will
be ignored when testing for equality

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I0461c42258611cae597f693507075b3ef15fbe19
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 8cc1319..d4a9f2f 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -411,6 +411,22 @@
         case TOSA_REF_TYPE_BF16:
         case TOSA_REF_TYPE_FP32:
         case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
+                if (isnan(a))
+                {
+                    return a;
+                }
+                else if (isnan(b))
+                {
+                    return b;
+                }
+                else
+                {
+                    return a > b ? a : b;
+                }
+            };
+            break;
+
         case TOSA_REF_TYPE_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
             break;
@@ -430,6 +446,21 @@
         case TOSA_REF_TYPE_BF16:
         case TOSA_REF_TYPE_FP32:
         case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType {
+                if (isnan(a))
+                {
+                    return a;
+                }
+                else if (isnan(b))
+                {
+                    return b;
+                }
+                else
+                {
+                    return a < b ? a : b;
+                }
+            };
+            break;
         case TOSA_REF_TYPE_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
             break;
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 88dd17a..cbac081 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -260,6 +260,11 @@
             # Data type is needed for all FP runs, as refmodel precise mode produces FP64
             "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
         }
+
+        op_compliance = op.get("compliance", {})
+        mode = None
+
+        # Check what data generation we have done
         if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
             mode = gtu.ComplianceMode.DOT_PRODUCT
             compliance_tens["dot_product_info"] = {
@@ -268,12 +273,10 @@
                     int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
                 ),
             }
-        elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
-            mode = gtu.ComplianceMode.FP_SPECIAL
-        elif "compliance" in op and "ulp" in op["compliance"]:
+        elif "ulp" in op_compliance:
             mode = gtu.ComplianceMode.ULP
             compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
-        elif "compliance" in op and "relative" in op["compliance"]:
+        elif "relative" in op_compliance:
             mode = gtu.ComplianceMode.RELATIVE
             compliance_tens["relative_info"] = {
                 "max": argsDict["max_abs_value"],
@@ -284,26 +287,30 @@
             compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
         elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
             mode = gtu.ComplianceMode.ABS_ERROR
-            if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
+            if "abs_error_lower_bound" in op_compliance:
                 compliance_tens["abs_error_info"] = {
                     "lower_bound": op["compliance"]["abs_error_lower_bound"]
                 }
         elif op["op"] in (Op.SIN, Op.COS):
             mode = gtu.ComplianceMode.ABS_ERROR
-            if "compliance" in op:
-                normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1)
-                bound_addition = op["compliance"].get("abs_error_bound_addition", 0)
-            else:
-                normal_divisor = 1
-                bound_addition = 0
+            normal_divisor = op_compliance.get("abs_error_normal_divisor", 1)
+            bound_addition = op_compliance.get("abs_error_bound_addition", 0)
 
             compliance_tens["abs_error_info"] = {
                 "normal_divisor": normal_divisor,
                 "bound_as_magnitude": True,
                 "bound_addition": bound_addition,
             }
+        elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
+            if gtu.ComplianceMode.DOT_PRODUCT in op["data_gen"][inputType]:
+                # Use special mode that only checks for matching inf/nan/zeroes
+                # as normal values need statistical analysis
+                mode = gtu.ComplianceMode.FP_SPECIAL
+            else:
+                mode = gtu.ComplianceMode.EXACT
         else:
             mode = gtu.ComplianceMode.EXACT
+
         compliance_tens["mode"] = gtu.ComplianceMode(mode).name
 
         return compliance_tens