blob: d52a777820d3161accbc403bb0a8bfc1dafd7a01 [file] [log] [blame]
Jeremy Johnson015c3552022-02-23 12:15:03 +00001# Copyright (c) 2020-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3import json
4
5# Used by basic_test_generator to create test description
6
7
def write_test_json(
    filename,
    tf_model_filename=None,
    tf_result_npy_filename=None,
    tf_result_name=None,
    tflite_model_filename=None,
    tflite_result_npy_filename=None,
    tflite_result_name=None,
    ifm_name=None,
    ifm_file=None,
    ifm_shape=None,
    framework_exclusions=None,
    quantized=False,
    test_name=None,
):
    """Write a JSON test description to ``filename``.

    Every argument is optional except ``filename``. Only truthy arguments are
    recorded; falsy ones (``None``, ``""``, empty containers, ``False``) are
    left out of the output entirely.

    Args:
        filename: Path of the JSON file to write. It is opened in ``"w"``
            mode, so any existing file is overwritten.
        tf_model_filename, tf_result_npy_filename, tf_result_name,
        tflite_model_filename, tflite_result_npy_filename,
        tflite_result_name: Stored as-is under the key of the same name.
        ifm_name, ifm_file, ifm_shape, framework_exclusions: Stored under the
            key of the same name. A value that is not already a ``list`` is
            wrapped in a one-element list, so these are always JSON arrays.
        quantized: When truthy, ``"quantized": 1`` is recorded.
        test_name: Stored under the ``"name"`` key.
    """

    def as_list(value):
        # Wrap non-list values so these fields always serialize as arrays.
        return value if isinstance(value, list) else [value]

    # Fields copied verbatim. The tuple order is the key order in the output.
    verbatim_fields = (
        ("name", test_name),
        ("tf_model_filename", tf_model_filename),
        ("tf_result_npy_filename", tf_result_npy_filename),
        ("tf_result_name", tf_result_name),
        ("tflite_model_filename", tflite_model_filename),
        ("tflite_result_npy_filename", tflite_result_npy_filename),
        ("tflite_result_name", tflite_result_name),
    )

    # Fields that must always be lists. They follow the verbatim fields.
    listed_fields = (
        ("ifm_file", ifm_file),
        ("ifm_name", ifm_name),
        ("ifm_shape", ifm_shape),
        # Some tests cannot be used with specific frameworks; this entry
        # lists the frameworks a test should be excluded from.
        ("framework_exclusions", framework_exclusions),
    )

    description = {key: value for key, value in verbatim_fields if value}
    description.update(
        (key, as_list(value)) for key, value in listed_fields if value
    )

    if quantized:
        description["quantized"] = 1

    with open(filename, "w") as out_file:
        json.dump(description, out_file, indent=" ")