[reference_model] Support StatefulOps and add tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ec009c6..ffe373b 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -28,6 +28,7 @@
get_tf_dtype,
get_shape_str,
) # noqa: E402
+
from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402
# All of the supported frameworks
@@ -829,6 +830,15 @@
]
},
},
+ "lstm_stateful": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
"gru": {
"operands": (1, 0),
"build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
@@ -848,6 +858,17 @@
]
},
},
+ "callonce": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.float32],
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1,)],
+ },
+ },
"rfft2d": {
"operands": (1, 0),
"build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
@@ -1219,9 +1240,15 @@
if "tflite" not in excluded_framework_list:
# Convert the model to TFLite flatbuffer
module = tf.Module()
- converter = tf.lite.TFLiteConverter.from_concrete_functions(
- [concrete_function], module
- )
+
+ if op_name == "callonce" or op_name == "lstm_stateful":
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], fcn_node
+ )
+ else:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], module
+ )
converter.experimental_new_converter = True