Add positive/negative random number generator for Rsqrt

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I1e9e97ead447295e1252785106931b261df7bcea
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index c534a58..f8d50a8 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,5 +1,7 @@
 # Copyright (c) 2020-2023, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import enum
+
 import numpy as np
 import tensorflow as tf
 
@@ -17,6 +19,12 @@
 RAND_INT_MAX = 128
 
 
+class ElemSignedness(enum.Enum):
+    ALL_RANGE = 1
+    POSITIVE = 2
+    NEGATIVE = 3
+
+
 class TGen:
     """A collection of functions to build tensor value arguments for an operator"""
 
@@ -24,7 +32,14 @@
         pass
 
     @staticmethod
-    def getRand(shape, dtype, rng):
+    def getRand(shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE):
+        if elem_signedness == ElemSignedness.POSITIVE:
+            RAND_SHIFT_FACTOR = 0
+        elif elem_signedness == ElemSignedness.NEGATIVE:
+            RAND_SHIFT_FACTOR = 1
+        else:
+            RAND_SHIFT_FACTOR = 0.5
+
         if dtype == tf.float32:
             return np.float32(
                 (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
@@ -45,7 +60,11 @@
         raise Exception("Unsupported type: {}".format(dtype))
 
     @staticmethod
-    def tgBasic(op, shape, dtype, rng):
+    def tgBasicPositive(op, shape, dtype, rng, elem_signedness=ElemSignedness.POSITIVE):
+        return TGen.tgBasic(op, shape, dtype, rng, elem_signedness)
+
+    @staticmethod
+    def tgBasic(op, shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE):
         # Build random tensor placeholder node args of a given shape
         pl, const = op["operands"]
 
@@ -54,11 +73,16 @@
 
         for i in range(pl):
             tf_placeholders.append(
-                ("placeholder_{}".format(i), TGen.getRand(shape, dtype, rng))
+                (
+                    "placeholder_{}".format(i),
+                    TGen.getRand(shape, dtype, rng, elem_signedness),
+                )
             )
 
         for i in range(const):
-            tf_consts.append(("const_{}".format(i), TGen.getRand(shape, dtype, rng)))
+            tf_consts.append(
+                ("const_{}".format(i), TGen.getRand(shape, dtype, rng, elem_signedness))
+            )
 
         return tf_placeholders, tf_consts
 
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 93bdfe0..0741686 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -22,6 +22,7 @@
 from frameworks.write_test_json import write_test_json  # noqa: E402
 from frameworks.arg_gen import ArgGen  # noqa: E402
 from frameworks.tensor_gen import TGen  # noqa: E402
+from frameworks.tensor_gen import ElemSignedness  # noqa: E402
 from frameworks.test_builder import TBuilder  # noqa: E402
 from frameworks.test_gen_utils import (  # noqa: E402
     QuantType,
@@ -303,8 +304,11 @@
     },
     "rsqrt": {
         "operands": (1, 0),
-        "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone),
-        "types": TYPE_F,
+        "build_fcn": (TBuilder.Rsqrt, TGen.tgBasicPositive, ArgGen.agNone),
+        "types": {
+            "tf": TYPE_F,
+            "tflite": list(TYPE_F + [QuantType.ALL_I8]),
+        },
     },
     "sign": {
         "operands": (1, 0),
@@ -1121,9 +1125,14 @@
                 converter.target_spec.supported_ops = [flag]
 
             def input_stats():
+                ## Rsqrt can only handle positive numbers
+                elem_signedness = ElemSignedness.ALL_RANGE
+                if op_name == "rsqrt":
+                    elem_signedness = ElemSignedness.POSITIVE
+
                 for i in range(0, args.num_samples):
                     a = [
-                        TGen.getRand(shape, tf.float32, rng)
+                        TGen.getRand(shape, tf.float32, rng, elem_signedness)
                         for shape in placeholder_shapes
                     ]
                     yield a