# blob: cd42198b66ffa9e97c0d4e6a57872482ab4879be [file] [log] [blame]
# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import json
# Used by basic_test_generator to create test description
def write_test_json(
    filename,
    tf_model_filename=None,
    tf_result_npy_filename=None,
    tf_result_name=None,
    tflite_model_filename=None,
    tflite_result_npy_filename=None,
    tflite_result_name=None,
    ifm_name=None,
    ifm_file=None,
    ifm_shape=None,
    ifm_dynamic=False,
    framework_exclusions=None,
    quantized=False,
    test_name=None,
    num_variables=None,
):
    """Write a JSON test description to ``filename``.

    Only arguments with a truthy value produce an entry in the JSON object;
    anything left as ``None``/``False``/``0``/empty is omitted altogether.

    Args:
        filename: Path of the JSON file to create (opened in ``"w"`` mode,
            so an existing file is overwritten).
        tf_model_filename, tf_result_npy_filename, tf_result_name,
        tflite_model_filename, tflite_result_npy_filename,
        tflite_result_name: Stored unchanged under the key of the same name.
        ifm_name, ifm_file, ifm_shape: Stored under the key of the same
            name; a value that is not already a ``list`` is wrapped in a
            single-element list first.
        ifm_dynamic: When truthy, ``"ifm_dynamic"`` is written as ``True``.
        framework_exclusions: Frameworks this test must be excluded from;
            wrapped in a list when not already one.
        quantized: When truthy, ``"quantized"`` is written as ``1``.
        test_name: Stored under the ``"name"`` key.
        num_variables: Stored unchanged under ``"num_variables"``.
    """
    desc = {}

    # Entries copied through verbatim. The tuple order is the key order of
    # the emitted JSON object.
    passthrough = (
        ("name", test_name),
        ("tf_model_filename", tf_model_filename),
        ("tf_result_npy_filename", tf_result_npy_filename),
        ("tf_result_name", tf_result_name),
        ("tflite_model_filename", tflite_model_filename),
        ("tflite_result_npy_filename", tflite_result_npy_filename),
        ("tflite_result_name", tflite_result_name),
    )
    for key, value in passthrough:
        if value:
            desc[key] = value

    # Make sure these arguments are wrapped as lists
    list_valued = (
        ("ifm_file", ifm_file),
        ("ifm_name", ifm_name),
        ("ifm_shape", ifm_shape),
    )
    for key, value in list_valued:
        if value:
            desc[key] = value if isinstance(value, list) else [value]

    if ifm_dynamic:
        desc["ifm_dynamic"] = True

    # Some tests cannot be used with specific frameworks.
    # This list indicates which tests should be excluded from a given framework.
    if framework_exclusions:
        desc["framework_exclusions"] = (
            framework_exclusions
            if isinstance(framework_exclusions, list)
            else [framework_exclusions]
        )

    if quantized:
        desc["quantized"] = 1

    if num_variables:
        desc["num_variables"] = num_variables

    with open(filename, "w") as out:
        json.dump(desc, out, indent=" ")