blob: 684bea38713f12cfcfb2158b8ff53e0c8e7634c3 [file] [log] [blame]
Jeremy Johnson48df8c72023-09-12 14:52:34 +01001# Copyright (c) 2023, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3"""Verfier library interface."""
4import ctypes as ct
5import json
6from pathlib import Path
Jeremy Johnson48df8c72023-09-12 14:52:34 +01007
8import numpy as np
9import schemavalidation.schemavalidation as sch
10
Jeremy Johnson48df8c72023-09-12 14:52:34 +010011# Type conversion from numpy to tosa_datatype_t
12# "type" matches enum - see include/types.h
13# "size" is size in bytes per value of this datatype
NUMPY_DATATYPE_TO_CLIENTTYPE = {
    np.dtype(dtype_name): {"type": client_type, "size": bytes_per_value}
    for dtype_name, client_type, bytes_per_value in (
        # tosa_datatype_int32_t (all integer types are this!)
        ("int32", 5, 4),
        # tosa_datatype_int48_t (or SHAPE)
        ("int64", 6, 8),
        # tosa_datatype_fp16_t
        ("float16", 2, 2),
        # tosa_datatype_fp32_t (bf16 stored as this)
        ("float32", 3, 4),
        # tosa_datatype_fp64_t (for precise refmodel data)
        ("float64", 99, 8),
        # tosa_datatype_bool_t
        ("bool", 1, 1),
    )
}
28
29
class TosaTensor(ct.Structure):
    """ctypes mirror of the verify library's tosa_tensor_t structure.

    The order and C types of the fields below define the memory layout that
    is handed to the C library, so they must match the C definition exactly
    (NOTE(review): presumably declared alongside the type enum in
    include/types.h - confirm before changing anything here).
    """

    _fields_ = [
        ("name", ct.c_char_p),  # tensor name as a NUL-terminated byte string
        ("shape", ct.POINTER(ct.c_int32)),  # array of num_dims dimension sizes
        ("num_dims", ct.c_int32),  # number of entries pointed to by shape
        ("data_type", ct.c_int),  # "type" value from NUMPY_DATATYPE_TO_CLIENTTYPE
        ("data", ct.POINTER(ct.c_uint8)),  # raw tensor data, viewed as bytes
        ("size", ct.c_size_t),  # length of data in bytes
    ]
39
40
class VerifierError(Exception):
    """Exception raised for errors performing data verification."""
43
44
class VerifierLibrary:
    """Python interface to the C verify library."""

    def __init__(self, verify_lib_path):
        """Find the library and set up the interface.

        Args:
            verify_lib_path: path (``pathlib.Path`` or ``str``) to the TOSA
                verify shared library.

        Raises:
            VerifierError: if ``verify_lib_path`` is not an existing file.
        """
        # Accept plain strings as well as Path objects.
        self.lib_path = Path(verify_lib_path)
        if not self.lib_path.is_file():
            raise VerifierError(f"Could not find verify library - {self.lib_path}")

        # Pass a str: older Python versions do not accept path-like objects
        # for ctypes library loading on every platform.
        self.lib = ct.cdll.LoadLibrary(str(self.lib_path))

        self.tvf_verify_data = self.lib.tvf_verify_data
        self.tvf_verify_data.argtypes = [
            ct.POINTER(TosaTensor),  # ref
            ct.POINTER(TosaTensor),  # ref_bnd
            ct.POINTER(TosaTensor),  # imp
            ct.c_char_p,  # config_json
        ]
        self.tvf_verify_data.restype = ct.c_bool

    def _get_tensor_data(self, name, array):
        """Set up a tosa_tensor_t (TosaTensor) describing the given numpy array.

        Args:
            name: tensor name, encoded as UTF-8 for the C side.
            array: numpy array (or array-like) holding the tensor values; its
                dtype must be a key of NUMPY_DATATYPE_TO_CLIENTTYPE.

        Returns:
            A TosaTensor whose ``data`` field points at a C-contiguous buffer
            holding the array values.

        Raises:
            VerifierError: if the array's dtype is not supported.
        """
        # The C side reads ``data`` as one dense row-major block, so make sure
        # the buffer is C-contiguous. This copies only when needed and, unlike
        # np.ascontiguousarray, keeps rank-0 arrays rank-0.
        array = np.asarray(array, order="C")

        try:
            client_type = NUMPY_DATATYPE_TO_CLIENTTYPE[array.dtype]
        except KeyError:
            raise VerifierError(
                f"Unsupported numpy data type {array.dtype} for tensor {name}"
            ) from None

        shape = (ct.c_int32 * array.ndim)(*array.shape)
        size_in_bytes = array.size * client_type["size"]

        tensor = TosaTensor(
            ct.c_char_p(bytes(name, "utf8")),
            ct.cast(shape, ct.POINTER(ct.c_int32)),
            ct.c_int32(array.ndim),
            ct.c_int(client_type["type"]),
            array.ctypes.data_as(ct.POINTER(ct.c_uint8)),
            ct.c_size_t(size_in_bytes),
        )
        # The raw data pointer copied into the structure does not keep the
        # numpy buffer alive, so hold a reference on the tensor itself; this
        # matters when np.asarray above had to make a temporary copy.
        tensor._data_array = array
        return tensor

    def verify_data(
        self,
        output_name,
        compliance_json_config,
        imp_result_array,
        ref_result_array,
        bnd_result_array=None,
    ):
        """Verify the data using the verification library.

        Args:
            output_name: name of the output tensor being checked.
            compliance_json_config: compliance meta-data dictionary; it is
                validated against the compliance schema and then passed to
                the C library as JSON.
            imp_result_array: numpy array produced by the implementation.
            ref_result_array: numpy array produced by the reference model.
            bnd_result_array: optional numpy array of reference bounds results.

        Returns:
            The boolean result of the C ``tvf_verify_data`` call; True means
            the implementation result was accepted.
        """
        sch.TestDescSchemaValidator().validate_config(
            compliance_json_config, sch.TD_SCHEMA_COMPLIANCE
        )
        jsb = bytes(json.dumps(compliance_json_config), "utf8")

        imp = self._get_tensor_data(output_name, imp_result_array)
        ref = self._get_tensor_data(output_name, ref_result_array)
        # A missing bounds tensor is passed to the C library as NULL.
        ref_bnd = (
            self._get_tensor_data(output_name, bnd_result_array)
            if bnd_result_array is not None
            else None
        )

        return self.tvf_verify_data(ref, ref_bnd, imp, ct.c_char_p(jsb))
104
105
def main(argv=None):
    """Simple command line interface for the verifier library.

    Args:
        argv: optional list of command line arguments; argparse falls back to
            sys.argv[1:] when this is None.

    Returns:
        Exit code: 0 when the result is compliant, 1 when it is not, and 2 on
        a usage or input error.
    """
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--verify-lib-path",
        type=Path,
        help="Path to TOSA verify lib",
    )
    parser.add_argument(
        "--test-desc",
        type=Path,
        help="Path to test description file: desc.json",
    )
    parser.add_argument(
        "-n",
        "--ofm-name",
        dest="ofm_name",
        type=str,
        help="output tensor name to check (defaults to only ofm_name in desc.json)",
    )
    parser.add_argument(
        "--bnd-result-path",
        type=Path,
        help="path to the reference bounds result numpy file",
    )

    parser.add_argument(
        "ref_result_path", type=Path, help="path to the reference result numpy file"
    )
    parser.add_argument(
        "imp_result_path",
        type=Path,
        help="path to the implementation result numpy file",
    )
    args = parser.parse_args(argv)

    if args.verify_lib_path is None:
        # Try to work out ref model directory and find the verify library
        # but this default only works for the python developer environment
        # i.e. when using the scripts/py-dev-env.* scripts
        # otherwise use the command line option --verify-lib-path to specify path
        # Imported here so that an explicit --verify-lib-path does not need
        # the conformance package at all.
        import conformance.model_files as cmf

        ref_model_dir = Path(__file__).absolute().parents[2]
        args.verify_lib_path = cmf.find_tosa_file(
            cmf.TosaFileType.VERIFY_LIBRARY, ref_model_dir, False
        )
        if args.verify_lib_path is None:
            print("ERROR: could not find verify library, please use --verify-lib-path")
            return 2

    if args.test_desc:
        json_path = args.test_desc
    else:
        # Assume it's with the reference file
        json_path = args.ref_result_path.parent / "desc.json"

    print("Load test description")
    try:
        with json_path.open("r") as fd:
            test_desc = json.load(fd)
    except (OSError, json.JSONDecodeError) as err:
        print(f"ERROR: could not read test description {json_path} - {err}")
        return 2

    if args.ofm_name is None:
        ofm_names = test_desc.get("ofm_name", [])
        if len(ofm_names) != 1:
            print("ERROR: ambiguous output to check, please specify output tensor name")
            return 2
        output_name = ofm_names[0]
    else:
        output_name = args.ofm_name

    if "meta" not in test_desc or "compliance" not in test_desc["meta"]:
        print(f"ERROR: no compliance meta-data found in {json_path}")
        return 2

    print("Load numpy data")
    # Order matches VerifierLibrary.verify_data: implementation, reference, bounds.
    paths = (args.imp_result_path, args.ref_result_path, args.bnd_result_path)
    try:
        arrays = [np.load(path) if path is not None else None for path in paths]
    except (OSError, ValueError) as err:
        print(f"ERROR: could not load numpy data - {err}")
        return 2

    try:
        print("Load verifier library")
        vlib = VerifierLibrary(args.verify_lib_path)

        print("Verify data")
        compliant = vlib.verify_data(
            output_name, test_desc["meta"]["compliance"], *arrays
        )
    except VerifierError as err:
        print(f"ERROR: {err}")
        return 2

    if compliant:
        print("SUCCESS")
        return 0
    print("FAILURE - NOT COMPLIANT")
    return 1
198
199
if __name__ == "__main__":
    import sys

    # sys.exit is always available; the bare exit() builtin is only installed
    # by the site module (missing under `python -S` or in frozen apps).
    sys.exit(main())