Main compliance testing support for MUL

Update verify ULP mode to allow fractions (e.g. 0.5).
Update pseudo generator to accept ranges.
Fix up pseudo random distribution based on ranges.

Change-Id: I9168c5f7d37722678c0f1f9e906953c8cec367b1
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
index 858a4b2..f234796 100644
--- a/reference_model/src/generate/generate_pseudo_random.cc
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -40,40 +40,76 @@
         constexpr auto min = std::numeric_limits<FP>::lowest() / 2;
         constexpr auto max = std::numeric_limits<FP>::max() / 2;
         static_assert(max <= std::numeric_limits<FP>::max() + min);
-        _unidis = std::uniform_real_distribution<FP>(min, max);
 
-        // Piecewise Constant distribution
-        const std::array<double, 7> intervals{ min, min + 1000, -1000.0, 0.0, 1000.0, max - 1000, max };
-        const std::array<double, 7> weights{ 1.0, 0.1, 1.0, 2.0, 1.0, 0.1, 1.0 };
-        _pwcdis = std::piecewise_constant_distribution<FP>(intervals.begin(), intervals.end(), weights.begin());
+        setDistribution(min, max);
     }
 
-    FP getRandomUniformFloat()
+    PseudoRandomGeneratorFloat(uint64_t seed, FP min, FP max)
+        : _gen(seed)
     {
-        return _unidis(_gen);
+        setDistribution(min, max);
     }
 
-    FP getRandomPWCFloat()
+    FP getRandomFloat()
     {
-        return _pwcdis(_gen);
+        if (_useUniform)
+            return _unidis(_gen);
+        else
+            return _pwcdis(_gen);
     }
 
 private:
+    void setDistribution(FP min, FP max)
+    {
+        _unidis = std::uniform_real_distribution<FP>(min, max);
+
+        // Piecewise Constant distribution for larger ranges
+        double range = std::abs(max - min);
+        double mid;
+        if (max == -min)
+            mid = 0.f;
+        else
+            mid = (range / 2) + min;
+        double segment = std::min<double>(1000.0, range / 5);
+
+        const std::array<double, 7> intervals{
+            min, min + segment, mid - segment, mid, mid + segment, max - segment, max
+        };
+        const std::array<double, 7> weights{ 1.0, 0.1, 1.0, 2.0, 1.0, 0.1, 1.0 };
+        _pwcdis = std::piecewise_constant_distribution<FP>(intervals.begin(), intervals.end(), weights.begin());
+
+        // Uniform distribution works well on smaller ranges
+        _useUniform = (range < 2000.0);
+    }
+
     std::mt19937 _gen;
     std::uniform_real_distribution<FP> _unidis;
     std::piecewise_constant_distribution<FP> _pwcdis;
+    bool _useUniform;
 };
 
 bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t size)
 {
     const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
-    PseudoRandomGeneratorFloat<float> generator(prinfo.rngSeed);
+
+    PseudoRandomGeneratorFloat<float>* generator;
+
+    if (prinfo.range.size() == 2)
+    {
+        const float min = std::stof(prinfo.range[0]);
+        const float max = std::stof(prinfo.range[1]);
+        generator       = new PseudoRandomGeneratorFloat<float>(prinfo.rngSeed, min, max);
+    }
+    else
+    {
+        generator = new PseudoRandomGeneratorFloat<float>(prinfo.rngSeed);
+    }
 
     float* a     = reinterpret_cast<float*>(data);
     const auto T = TosaReference::numElementsFromShape(cfg.shape);
     for (auto t = 0; t < T; ++t)
     {
-        a[t] = generator.getRandomPWCFloat();
+        a[t] = generator->getRandomFloat();
     }
     return true;
 }
@@ -90,6 +126,10 @@
         WARNING("[Generator][PR] Unknown operator.");
         return false;
     }
+    if (cfg.pseudoRandomInfo.range.size() != 0 || cfg.pseudoRandomInfo.range.size() != 2)
+    {
+        WARNING("[Generator][PR] Invalid range.");
+    }
 
     switch (cfg.dataType)
     {
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index d3bb076..ae6dfcb 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -38,10 +38,11 @@
 NLOHMANN_JSON_SERIALIZE_ENUM(Op,
                              {
                                  { Op::Op_UNKNOWN, "UNKNOWN" },
+                                 { Op::Op_CONV2D, "CONV2D" },
                                  { Op::Op_MATMUL, "MATMUL" },
                                  { Op::Op_MAX_POOL2D, "MAX_POOL2D" },
+                                 { Op::Op_MUL, "MUL" },
                                  { Op::Op_PAD, "PAD" },
-                                 { Op::Op_CONV2D, "CONV2D" },
                              })
 
 }    // namespace tosa
@@ -84,6 +85,10 @@
 void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
 {
     j.at("rng_seed").get_to(pseudoRandomInfo.rngSeed);
+    if (j.contains("range"))
+    {
+        j.at("range").get_to(pseudoRandomInfo.range);
+    }
 }
 
 void from_json(const nlohmann::json& j, GenerateConfig& cfg)
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 7c55f1d..8d0f654 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -61,7 +61,7 @@
     PseudoRandomInfo() = default;
 
     int64_t rngSeed;
-    // TODO: Add range support
+    std::vector<std::string> range;
 };
 
 /// \brief Generator configuration
diff --git a/reference_model/src/verify/verifiers.h b/reference_model/src/verify/verifiers.h
index dd97122..fcfb3b3 100644
--- a/reference_model/src/verify/verifiers.h
+++ b/reference_model/src/verify/verifiers.h
@@ -58,7 +58,7 @@
 /// \param ulp    The ULP tolerence for the comparison of the two tensors
 ///
 /// \return True if compliant else false
-bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, uint64_t ulp);
+bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, const UlpInfo& ulpInfo);
 
 };    // namespace TosaReference
 
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index 233c072..863640f 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -14,7 +14,6 @@
 
 #include "func_debug.h"
 #include "verifiers.h"
-#include "verify_utils.h"
 
 #include <cmath>
 #include <numeric>
diff --git a/reference_model/src/verify/verify_entry.cc b/reference_model/src/verify/verify_entry.cc
index 67eb7df..4da3bde 100644
--- a/reference_model/src/verify/verify_entry.cc
+++ b/reference_model/src/verify/verify_entry.cc
@@ -38,7 +38,7 @@
             return verifyReduceProduct(ref, imp, cfg.reduceProductInfo.m, cfg.reduceProductInfo.n);
         }
         case VerifyMode::Ulp: {
-            return verifyULP(ref, imp, cfg.ulpInfo.ulp);
+            return verifyULP(ref, imp, cfg.ulpInfo);
         }
         default: {
             WARNING("[Verifier] Unsupported verification mode.");
diff --git a/reference_model/src/verify/verify_ulp.cc b/reference_model/src/verify/verify_ulp.cc
index 486c0ff..8c27191 100644
--- a/reference_model/src/verify/verify_ulp.cc
+++ b/reference_model/src/verify/verify_ulp.cc
@@ -31,7 +31,7 @@
               "TOSA Reference Model has not been built with standard IEE574 64-bit float support; ULP based "
               "verifcation is invalid");
 
-bool tosaCheckULP(float testValue, double referenceValue, int64_t ulpCount)
+bool tosaCheckULP(double referenceValue, float testValue, double ulpNum)
 {
 
     // Start by sanitizing the input.
@@ -71,57 +71,55 @@
     else
     {
         // Find the exponent of the reference value.
-        int referenceExponent;
-        std::frexp(referenceValue, &referenceExponent);
+        int32_t referenceExponent = ilog2(referenceValue);
 
         // Work out the values magnitude - by raising 2 to the power of the
         // exponent and taking the normalized minimum for denormal values
-        const double referencePower2 =
-            std::max(std::ldexp(1.0, referenceExponent), static_cast<double>(std::numeric_limits<float>::min()));
+        const double referencePower2 = std::max(exp2(referenceExponent), AccPrecision<float>::normal_min);
         // Get the value of changing the last bit - by shifting the least significant bit to this magnitude
         // i.e. the ULP.
-        double ulpValue = referencePower2 * std::ldexp(1.0, -23);
-
-        // It is possible that within one ULP we cross a boundary where we need to change the exponent,
-        // if this happens we will take the ULP for the larger exponent.
-        if (referenceValue + ulpValue > 2 * referencePower2)
-        {
-            ulpValue = 2 * ulpValue;
-        }
+        double ulpValue = referencePower2 * exp2(-AccPrecision<float>::normal_frac);
 
         // Scale by the number of ULPs requested by the user.
-        referenceMax = referenceValue + ulpValue * ulpCount;
-        referenceMin = referenceValue - ulpValue * ulpCount;
+        referenceMax = referenceValue + ulpValue * ulpNum;
+        referenceMin = referenceValue - ulpValue * ulpNum;
 
         // Handle the overflow cases.
-        if (referenceMax > std::numeric_limits<float>::max())
+        if (referenceMax > AccPrecision<float>::normal_max)
         {
             referenceMax = std::numeric_limits<float>::infinity();
         }
 
-        if (referenceMin > std::numeric_limits<float>::max())
+        if (referenceMin > AccPrecision<float>::normal_max)
         {
             referenceMin = std::numeric_limits<float>::infinity();
         }
 
         // And the underflow cases.
-        if (referenceMax < std::numeric_limits<float>::min())
+        if (referenceMax < AccPrecision<float>::normal_min)
         {
-            referenceMax = std::numeric_limits<float>::min();
+            referenceMax = AccPrecision<float>::normal_min;
         }
 
-        if (referenceMin < std::numeric_limits<float>::min())
+        if (referenceMin < AccPrecision<float>::normal_min)
         {
-            referenceMin = 0;
+            referenceMin = 0.0;
         }
     }
 
     // And finally... Do the comparison.
-    return static_cast<double>(testValue) >= referenceMin && static_cast<double>(testValue) <= referenceMax;
+    double testValue64 = static_cast<double>(testValue);
+    bool withinUlp     = testValue64 >= referenceMin && testValue64 <= referenceMax;
+    if (!withinUlp)
+    {
+        WARNING("[Verfier][ULP] value (%10f) is not in ULP %g range (%10f <= ref (%10f) <= %10f).", testValue64, ulpNum,
+                referenceMin, referenceValue, referenceMax);
+    }
+    return withinUlp;
 }
 }    // namespace
 
-bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, uint64_t ulp)
+bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTensor, const UlpInfo& ulpInfo)
 {
     // Validate that tensors are provided
     TOSA_REF_REQUIRE(referenceTensor != nullptr, "[ULP] Reference tensor is missing");
@@ -132,10 +130,11 @@
         numElements(std::vector<int32_t>(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims));
     TOSA_REF_REQUIRE(elementCount > 0, "[ULP] Invalid shape for reference tensor");
 
+    const double ulp = ulpInfo.ulp;
     switch (implementationTensor->data_type)
     {
         case tosa_datatype_fp32_t: {
-            const auto* refData = reinterpret_cast<const float*>(referenceTensor->data);
+            const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
             TOSA_REF_REQUIRE(refData != nullptr, "[ULP] Missing data for reference");
             const auto* impData = reinterpret_cast<const float*>(implementationTensor->data);
             TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation");
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 43ecbe7..99cb0c1 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -50,7 +50,6 @@
                                  { VerifyMode::DotProduct, "DOT_PRODUCT" },
                                  { VerifyMode::ReduceProduct, "REDUCE_PRODUCT" },
                                  { VerifyMode::FpSpecial, "FP_SPECIAL" },
-                                 { VerifyMode::Round, "ROUND" },
                              })
 
 void from_json(const nlohmann::json& j, UlpInfo& ulpInfo)
@@ -144,7 +143,24 @@
 // Like const_exp2 but for use during runtime
 double exp2(int32_t n)
 {
-    TOSA_REF_REQUIRE(-1022 <= n && n <= 1023, " Invalid exponent value (%d)", n);
+    TOSA_REF_REQUIRE(-1022 <= n && n <= 1023, " Invalid exponent value (%d) in exp2", n);
     return const_exp2(n);
 }
+
+int32_t ilog2(double v)
+{
+    TOSA_REF_REQUIRE(0.0 < v && v < std::numeric_limits<double>::infinity(), " Value out of range (%g) in ilog2", v);
+    int32_t n = 0;
+    while (v >= 2.0)
+    {
+        v = v / 2.0;
+        n++;
+    }
+    while (v < 1.0)
+    {
+        v = v * 2.0;
+        n--;
+    }
+    return n;
+}
 }    // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
index 486ce19..15d7ba5 100644
--- a/reference_model/src/verify/verify_utils.h
+++ b/reference_model/src/verify/verify_utils.h
@@ -44,8 +44,7 @@
     Ulp,
     DotProduct,
     ReduceProduct,
-    FpSpecial,
-    Round
+    FpSpecial
 };
 
 /// \brief ULP verification meta-data
@@ -53,7 +52,7 @@
 {
     UlpInfo() = default;
 
-    uint64_t ulp;
+    double ulp;
 };
 
 /// \brief Dot-product verification meta-data
@@ -95,7 +94,7 @@
 /// \brief Map API data-type to DType
 DType mapToDType(tosa_datatype_t dataType);
 
-/// \brief Raise a value by the power of N or -N
+/// \brief Return 2 to the power of N or -N
 // For use during compile time - as no range check
 constexpr double const_exp2(int32_t n)
 {
@@ -116,6 +115,9 @@
 /// \brief Same as const_exp2 but with runtime range check of N
 double exp2(int32_t n);
 
+/// \brief Return the base-2 exponent of V
+int32_t ilog2(double v);
+
 /// \brief Accuracy precision information
 template <typename T>
 struct AccPrecision;
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp
index 369a8cd..e7d6c4e 100644
--- a/reference_model/test/verify_tests.cpp
+++ b/reference_model/test/verify_tests.cpp
@@ -392,29 +392,37 @@
     const auto elementCount = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<>());
 
     // Generate some random floats using the full range of fp32.
-    auto data = generateRandomTensorData<float>(elementCount, false);
+    auto data_fp32 = generateRandomTensorData<float>(elementCount, false);
+    std::vector<double> data_fp64(data_fp32.begin(), data_fp32.end());
+
     SUBCASE("same")
     {
         // Generate some data that meets the ULP requirements of the result.
-        auto otherData = data;
-        std::for_each(std::begin(otherData), std::end(otherData), [](auto& value) { value = increment(value, 5); });
+        auto otherData_fp32 = data_fp32;
+        std::for_each(std::begin(otherData_fp32), std::end(otherData_fp32), [](auto& value) {
+            if (std::abs(value) != 0.0 && !std::isinf(value))
+                value = increment(value, 5);
+        });
         const auto referenceTensor =
-            TosaTensor("out1", tosa_datatype_fp64_t, shape, reinterpret_cast<uint8_t*>(data.data()));
+            TosaTensor("out1", tosa_datatype_fp64_t, shape, reinterpret_cast<uint8_t*>(data_fp64.data()));
         const auto implementationTensor =
-            TosaTensor("out1", tosa_datatype_fp32_t, shape, reinterpret_cast<uint8_t*>(otherData.data()));
+            TosaTensor("out1", tosa_datatype_fp32_t, shape, reinterpret_cast<uint8_t*>(otherData_fp32.data()));
         REQUIRE(tvf_verify_data(referenceTensor.cTensor(), nullptr, implementationTensor.cTensor(), jsonCfg.c_str()));
     }
 
     SUBCASE("different")
     {
         // Generate some data that exceeds a specified number of ULP for each value in the tensor.
-        auto otherData = std::vector<float>(elementCount);
-        std::for_each(std::begin(otherData), std::end(otherData), [](auto& value) { value = increment(value, 6); });
+        auto otherData_fp32 = data_fp32;
+        std::for_each(std::begin(otherData_fp32), std::end(otherData_fp32), [](auto& value) {
+            if (std::abs(value) != 0.0 && !std::isinf(value))
+                value = increment(value, 6);
+        });
 
         const auto referenceTensor =
-            TosaTensor("out1", tosa_datatype_fp64_t, shape, reinterpret_cast<uint8_t*>(data.data()));
+            TosaTensor("out1", tosa_datatype_fp64_t, shape, reinterpret_cast<uint8_t*>(data_fp64.data()));
         const auto implementationTensor =
-            TosaTensor("out1", tosa_datatype_fp32_t, shape, reinterpret_cast<uint8_t*>(otherData.data()));
+            TosaTensor("out1", tosa_datatype_fp32_t, shape, reinterpret_cast<uint8_t*>(otherData_fp32.data()));
         REQUIRE_FALSE(
             tvf_verify_data(referenceTensor.cTensor(), nullptr, implementationTensor.cTensor(), jsonCfg.c_str()));
     }
diff --git a/scripts/schemavalidation/compliance-config.schema.json b/scripts/schemavalidation/compliance-config.schema.json
index 570c88f..e78d385 100644
--- a/scripts/schemavalidation/compliance-config.schema.json
+++ b/scripts/schemavalidation/compliance-config.schema.json
@@ -35,8 +35,8 @@
                             "properties":
                             {
                                 "ulp": {
-                                    "description": "ulp range limit - positive number",
-                                    "type": "integer",
+                                    "description": "ulp range limit - positive float",
+                                    "type": "number",
                                     "minimum": 0
                                 }
                             },
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index a090479..4256bfb 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -1484,7 +1484,7 @@
                         "--target-dtype",
                         "bf16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--target-shape",
                         "1,47,37,25",
                         "--target-shape",
@@ -1495,7 +1495,7 @@
                         "--target-dtype",
                         "fp32",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--target-shape",
                         "1,65534,4,1",
                         "--target-shape",
@@ -1613,7 +1613,7 @@
                         "--target-dtype",
                         "bf16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--tensor-dim-range",
                         "1,65",
                         "--target-rank",
@@ -1627,7 +1627,7 @@
                         "--target-dtype",
                         "fp16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--tensor-dim-range",
                         "1,17",
                         "--target-rank",
@@ -1637,7 +1637,7 @@
                         "--target-dtype",
                         "bf16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--tensor-dim-range",
                         "1,16",
                         "--target-rank",
@@ -1647,7 +1647,7 @@
                         "--target-dtype",
                         "fp32",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--target-shape",
                         "1,1,65539,1"
                     ]
@@ -2312,6 +2312,7 @@
         "profile": [
             "tosa-mi"
         ],
+        "support_for": [ "lazy_data_gen" ],
         "generation": {
             "standard": {
                 "negative_dim_range": "1,10",
@@ -2324,7 +2325,7 @@
                         "--target-dtype",
                         "bf16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--tensor-dim-range",
                         "16,64",
                         "--target-rank",
@@ -2338,7 +2339,7 @@
                         "--target-dtype",
                         "fp16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--tensor-dim-range",
                         "1,16",
                         "--target-rank",
@@ -2350,7 +2351,7 @@
                         "--target-dtype",
                         "bf16",
                         "--fp-values-range",
-                        "-2.0,2.0",
+                        "-max,max",
                         "--target-shape",
                         "1,1,3,65534",
                         "--target-shape",
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 32f4341..94b7172 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -628,6 +628,13 @@
 
         return tens
 
+    # Default high value for random numbers
+    TVG_FLOAT_HIGH_VALUE = {
+        DType.FP32: (1 << 128) - (1 << (127 - 23)),
+        DType.FP16: (1 << 16) - (1 << (15 - 10)),
+        DType.BF16: (1 << 128) - (1 << (127 - 7)),
+    }
+
     @staticmethod
     def tvgLazyGenDefault(
         testGen, opName, dtypeList, shapeList, argsDict, error_name=None
@@ -684,10 +691,13 @@
                 info = {}
                 # TODO - generate seed for this generator based on test
                 info["rng_seed"] = 42
-                info["range"] = [
-                    str(v)
-                    for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True)
-                ]
+                if "data_range" in argsDict:
+                    data_range = argsDict["data_range"]
+                else:
+                    data_range = testGen.getDTypeRange(
+                        dtypeList[idx], high_inclusive=True
+                    )
+                info["range"] = [str(v) for v in data_range]
                 tens_meta["pseudo_random_info"] = info
             elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                 info = {}
@@ -950,79 +960,96 @@
                 testGen, op, dtypeList, shapeList, testArgs, error_name
             )
 
+    # Set the data range to the square root of the largest value
+    TVG_FLOAT_HIGH_VALUE_MUL = {
+        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
+        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
+        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
+    }
+
     @staticmethod
-    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
-        if error_name is None:
+    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+        if error_name is not None or dtypeList[0] in (
+            DType.FP16,
+            DType.BF16,
+            DType.FP32,
+        ):
+            # ERROR_IF or floating point test
+            if dtypeList[0] in TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL:
+                data_range = testGen.getDTypeRange(dtypeList[0], high_inclusive=True)
+                high_val = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL[dtypeList[0]]
+                # Set the values to something that won't produce infinity whilst
+                # respecting the default ranges if less than the high value
+                argsDict["data_range"] = [
+                    max(-high_val, data_range[0]),
+                    min(high_val, data_range[1]),
+                ]
+            return TosaTensorValuesGen.tvgLazyGenDefault(
+                testGen, opName, dtypeList, shapeList, argsDict, error_name
+            )
+        else:
+            # Integer test
+            op = testGen.TOSA_OP_LIST[opName]
             pCount, cCount = op["operands"]
             assert (
                 pCount == 2 and cCount == 0
             ), "Op.MUL must have 2 placeholders, 0 consts"
 
-            tens = []
-            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
-                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
+            tens_ser_list = []
+
+            # Make sure multiply result in int32 range
+            shift = argsDict["shift"]
+            if dtypeList[0] == DType.INT8:
+                num_bits = 8
+            elif dtypeList[0] == DType.INT16:
+                num_bits = 16
+            elif dtypeList[0] == DType.INT32:
+                num_bits = 32
+            elif error_name == ErrorIf.WrongInputType:
+                num_bits = 8
             else:
-                placeholders = []
+                raise Exception("OpMul: invalid input dtype")
 
-                # Make sure multiply result in int32 range
-                shift = testArgs[0]
-                if dtypeList[0] == DType.INT8:
-                    num_bits = 8
-                elif dtypeList[0] == DType.INT16:
-                    num_bits = 16
-                elif dtypeList[0] == DType.INT32:
-                    num_bits = 32
-                elif error_name == ErrorIf.WrongInputType:
-                    num_bits = 8
+            for idx, shape in enumerate(shapeList[:]):
+                low = -(2 ** (num_bits - 1))
+                high = (2 ** (num_bits - 1)) - 1
+
+                a_arr = np.int32(
+                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
+                )
+                b_arr = np.int32(
+                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
+                )
+
+            i = 0
+            while True:
+
+                a_arr_64 = a_arr.astype(np.int64)
+                b_arr_64 = b_arr.astype(np.int64)
+
+                if shift > 0:
+                    rounding = 1 << (shift - 1)
+                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                 else:
-                    raise Exception("OpMul: invalid input dtype")
+                    result_arr = a_arr_64 * b_arr_64
 
-                for idx, shape in enumerate(shapeList[:]):
-                    low = -(2 ** (num_bits - 1))
-                    high = (2 ** (num_bits - 1)) - 1
+                if (result_arr > -(2**31)).all() and (
+                    result_arr <= ((2**31) - 1)
+                ).all():
+                    break
 
-                    a_arr = np.int32(
-                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
-                    )
-                    b_arr = np.int32(
-                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
-                    )
+                i = i + 1
+                a_arr = a_arr // 2
+                b_arr = b_arr // 2
 
-                i = 0
-                while True:
-
-                    a_arr_64 = a_arr.astype(np.int64)
-                    b_arr_64 = b_arr.astype(np.int64)
-
-                    if shift > 0:
-                        rounding = 1 << (shift - 1)
-                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
-                    else:
-                        result_arr = a_arr_64 * b_arr_64
-
-                    if (result_arr > -(2**31)).all() and (
-                        result_arr <= ((2**31) - 1)
-                    ).all():
-                        break
-
-                    i = i + 1
-                    a_arr = a_arr // 2
-                    b_arr = b_arr // 2
-
-                placeholders.append(
-                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
-                )
-                placeholders.append(
-                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
-                )
-
-                tens.extend(placeholders)
-
-            return tens
-        else:
-            return TosaTensorValuesGen.tvgDefault(
-                testGen, op, dtypeList, shapeList, testArgs, error_name
+            tens_ser_list.append(
+                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
             )
+            tens_ser_list.append(
+                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+            )
+
+            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
 
     @staticmethod
     def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
@@ -2076,11 +2103,18 @@
             for p in range(testGen.args.num_rand_permutations):
 
                 shift = testGen.randInt(0, 32)
-
-                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
+                arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
         else:
-            arg_list.append(("perm0_shift0", [0]))
+            arg_list.append(("perm0_shift0", {"shift": 0}))
 
+        arg_list = TosaArgGen._add_data_generators(
+            testGen,
+            opName,
+            dtype,
+            arg_list,
+            error_name,
+        )
+        # Return list of tuples: (arg_str, args_dict)
         return arg_list
 
     @staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 54b624e..1995cbc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -51,15 +51,31 @@
         self.quantGen = TosaQuantGen()
         # Force makeShape to do a specific starting shape
         self.targetted_shape = None
-        # Work out floating point range
-        self.random_fp_low = min(args.tensor_fp_value_range)
-        self.random_fp_high = max(args.tensor_fp_value_range)
         # JSON schema validation
         self.descSchemaValidator = TestDescSchemaValidator()
         # Data generator library is sometimes needed for compliance set up
         # even if we are generating the data later (lazy_data_generation)
         self.dgl = GenerateLibrary(args.generate_lib_path)
 
+        # Work out floating point range
+        def convertFPRange(rangeFP, maxFP):
+            # Converts program arguments of max/-max to FP max
+            vals = []
+            for v in rangeFP:
+                if v == "max":
+                    v = maxFP
+                elif v == "-max":
+                    v = -maxFP
+                vals.append(v)
+            return tuple(sorted(vals))
+
+        self.random_float_range = {}
+        for dtype in (DType.FP32, DType.FP16, DType.BF16):
+            self.random_float_range[dtype] = convertFPRange(
+                args.tensor_fp_value_range,
+                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
+            )
+
     def createSerializer(self, opName, testPath):
         self.testPath = os.path.join(opName, testPath)
 
@@ -130,9 +146,8 @@
         # Returns dtype value range boundaries (low, high)
         # The high boundary is excluded in the range
         # unless high_inclusive is True
-
         if dtype in (DType.FP32, DType.FP16, DType.BF16):
-            return (self.random_fp_low, self.random_fp_high)
+            return self.random_float_range[dtype]
         elif dtype == DType.BOOL:
             rng = (0, 2)
         elif dtype == DType.UINT8:
@@ -318,8 +333,6 @@
             compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
         elif op["op"] == Op.REDUCE_PRODUCT:
             mode = gtu.ComplianceMode.REDUCE_PRODUCT
-        elif op["op"] in (Op.ADD, Op.MUL, Op.SUB, Op.CEIL, Op.FLOOR, Op.CAST):
-            mode = gtu.ComplianceMode.ROUND
         else:
             mode = gtu.ComplianceMode.EXACT
         compliance_tens["mode"] = gtu.ComplianceMode(mode).name
@@ -466,23 +479,29 @@
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
-    def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
-        result_tens = OutputShaper.binaryBroadcastOp(
+    def build_mul(
+        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+    ):
+        assert len(inputs) == 2
+        a, b = inputs
+        shift = args_dict["shift"]
+
+        result_tensor = OutputShaper.binaryBroadcastOp(
             self.ser, self.rng, a, b, error_name
         )
 
-        # Special for multiply:
-        # Force the result to INT32 for INT types
+        # Special for multiply: Force the result to INT32 for INT types
         if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
-            result_tens.setDtype(DType.INT32)
+            result_tensor.setDtype(DType.INT32)
+
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
             outputDType = self.rng.choice(all_dtypes)
-            result_tens.setDtype(outputDType)
+            result_tensor.setDtype(outputDType)
 
         # Invalidate Input/Output list for error if checks.
         input_list = [a.name, b.name]
-        output_list = [result_tens.name]
+        output_list = [result_tensor.name]
         pCount, cCount = op["operands"]
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -497,8 +516,8 @@
             input1=a,
             input2=b,
             input_dtype=a.dtype,
-            output_dtype=result_tens.dtype,
-            result_tensors=[result_tens],
+            output_dtype=result_tensor.dtype,
+            result_tensors=[result_tensor],
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
@@ -509,7 +528,12 @@
         attr.MulAttribute(shift)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
-        return result_tens
+
+        compliance = self.tensorComplianceMetaData(
+            op, a.dtype, args_dict, result_tensor, error_name
+        )
+
+        return TosaTestGen.BuildInfo(result_tensor, compliance)
 
     def build_table(self, op, a, table, validator_fcns=None, error_name=None):
         result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
@@ -3456,6 +3480,10 @@
                 TosaErrorValidator.evDimensionMismatch,
                 TosaErrorValidator.evBroadcastShapesMismatch,
             ),
+            "data_gen": {
+                "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+            },
+            "compliance": {"ulp": 0.5},
         },
         "pow": {
             "op": Op.POW,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 7fc5b52..3b487de 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -38,7 +38,6 @@
     ULP = 2
     FP_SPECIAL = 3
     REDUCE_PRODUCT = 4
-    ROUND = 5
 
 
 class DataGenType(IntEnum):
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 954c6e9..d6598fb 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -13,14 +13,18 @@
 OPTION_FP_VALUES_RANGE = "--fp-values-range"
 
 
-# Used for parsing a comma-separated list of integers in a string
-# to an actual list of integers
+# Used for parsing a comma-separated list of integers/floats in a string
+# to an actual list of integers/floats with special case max
 def str_to_list(in_s, is_float=False):
-    """Converts a comma-separated list of string integers to a python list of ints"""
+    """Converts a comma-separated list string to a python list of numbers."""
     lst = in_s.split(",")
     out_list = []
     for i in lst:
-        val = float(i) if is_float else int(i)
+        # Special case for allowing maximum FP numbers
+        if is_float and i in ("-max", "max"):
+            val = i
+        else:
+            val = float(i) if is_float else int(i)
         out_list.append(val)
     return out_list