TatWai Chong | 6a46b25 | 2024-01-12 13:13:22 -0800 | [diff] [blame] | 1 | # Copyright (c) 2020-2024, ARM Limited. |
Jeremy Johnson | 015c355 | 2022-02-23 12:15:03 +0000 | [diff] [blame] | 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | import json |
| 4 | |
| 5 | # Used by basic_test_generator to create test description |
| 6 | |
| 7 | |
def write_test_json(
    filename,
    tf_model_filename=None,
    tf_result_npy_filename=None,
    tf_result_name=None,
    tflite_model_filename=None,
    tflite_result_npy_filename=None,
    tflite_result_name=None,
    ifm_name=None,
    ifm_file=None,
    ifm_shape=None,
    ifm_dynamic=False,
    framework_exclusions=None,
    quantized=False,
    test_name=None,
    num_variables=None,
):
    """Write a JSON test description file for a generated test.

    Only truthy arguments are recorded; any falsy argument (None, False,
    0, empty string or empty list) is omitted from the output entirely.

    Args:
        filename: Path of the JSON file to create (overwritten if present).
        tf_model_filename: Stored verbatim under "tf_model_filename".
        tf_result_npy_filename: Stored verbatim under "tf_result_npy_filename".
        tf_result_name: Stored verbatim under "tf_result_name".
        tflite_model_filename: Stored verbatim under "tflite_model_filename".
        tflite_result_npy_filename: Stored verbatim under
            "tflite_result_npy_filename".
        tflite_result_name: Stored verbatim under "tflite_result_name".
        ifm_name: Input name(s); a value that is not already a ``list`` is
            wrapped in a one-element list.
        ifm_file: Input file(s); wrapped like ``ifm_name``.
        ifm_shape: Input shape(s); wrapped like ``ifm_name``. Note that a
            single shape given as a ``list`` is NOT wrapped.
        ifm_dynamic: When truthy, records ``"ifm_dynamic": true``.
        framework_exclusions: Frameworks the test should be excluded from;
            wrapped like ``ifm_name``.
        quantized: When truthy, records ``"quantized": 1`` (an integer).
        test_name: Stored verbatim under "name".
        num_variables: Stored verbatim under "num_variables".
    """

    def keep(value):
        return value

    def as_list(value):
        # Lists pass through untouched; anything else becomes [value].
        return value if isinstance(value, list) else [value]

    # (json key, raw argument, converter). The tuple order fixes both the
    # key order in the written JSON and the order truthiness is checked.
    entries = (
        ("name", test_name, keep),
        ("tf_model_filename", tf_model_filename, keep),
        ("tf_result_npy_filename", tf_result_npy_filename, keep),
        ("tf_result_name", tf_result_name, keep),
        ("tflite_model_filename", tflite_model_filename, keep),
        ("tflite_result_npy_filename", tflite_result_npy_filename, keep),
        ("tflite_result_name", tflite_result_name, keep),
        ("ifm_file", ifm_file, as_list),
        ("ifm_name", ifm_name, as_list),
        ("ifm_shape", ifm_shape, as_list),
        ("ifm_dynamic", ifm_dynamic, lambda _flag: True),
        # Some tests cannot be used with specific frameworks; this list
        # records which frameworks a test should be excluded from.
        ("framework_exclusions", framework_exclusions, as_list),
        ("quantized", quantized, lambda _flag: 1),
        ("num_variables", num_variables, keep),
    )

    description = {
        key: convert(raw) for key, raw, convert in entries if raw
    }

    with open(filename, "w") as out_file:
        json.dump(description, out_file, indent=" ")