blob: 272c12458337e5a43945b0a2f691697a3202352f [file] [log] [blame]
#!/usr/bin/env python3

# Copyright (c) 2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple test script that exercises NumPy (.npy) file write/read."""
18
import argparse
import random
import shlex
import subprocess
import sys
from datetime import datetime
from enum import IntEnum, unique
from pathlib import Path

from xunit.xunit import xunit_results, xunit_test
27
28
@unique
class TestResult(IntEnum):
    """Outcome code for a single npy serialization test.

    Members are IntEnum values, so they compare equal to the plain
    integers 0 through 3.
    """

    # The test command (and removal of its output file) completed cleanly.
    PASS = 0
    # Running the test command, or removing its output file, raised an error.
    COMMAND_ERROR = 1
    # Not set by the code shown here; presumably for data mismatches.
    MISMATCH = 2
    # Not set by the code shown here; presumably for skipped tests.
    SKIPPED = 3
35
36
def parseArgs(argv=None):
    """Parse and validate the command-line options for this script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the original
            behavior.

    Returns:
        argparse.Namespace with ``cmd`` (str), ``seed`` (int),
        ``verbose`` (bool) and ``xunit_file`` (str).

    Exits the process with status 1, after printing the help text, when
    the path given by ``--cmd`` does not exist.
    """
    baseDir = (Path(__file__).parent / "../..").resolve()
    buildDir = (baseDir / "build").resolve()
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-c",
        "--cmd",
        default=str(buildDir / "serialization_npy_test"),
        help="Command to write/read test file",
    )
    # type=int makes "-s 1" seed the generator exactly like the default 1.
    # Without it a command-line seed arrives as the string "1", which
    # random.seed() converts differently and so yields different shapes.
    parser.add_argument(
        "-s", "--seed", default=1, type=int, help="Random number seed"
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="verbose", default=False
    )
    parser.add_argument(
        "--xunit-file", default="npy-result.xml", help="xunit result output file"
    )
    args = parser.parse_args(argv)

    # check that required files exist
    if not Path(args.cmd).exists():
        print("command not found at location " + args.cmd)
        parser.print_help()
        # sys.exit rather than the site-module exit(), which is intended
        # for the interactive interpreter only.
        sys.exit(1)
    return args
63
64
def run_sh_command(full_cmd, verbose=False, capture_output=False):
    """Run an external command and raise if it exits with a non-zero status.

    Args:
        full_cmd: argument vector handed to ``subprocess.run`` (no shell).
        verbose: when true, echo the shell-quoted command before running it.
        capture_output: when true, capture the child's stdout and stderr and
            return them. Otherwise the child writes to this process's
            streams.

    Returns:
        ``(stdout_bytes, stderr_bytes)`` when ``capture_output`` is true,
        otherwise ``None``.

    Raises:
        Exception: if the command's return code is non-zero. In capture mode
            the captured stdout and stderr are printed first.
    """
    # Shell-quoted rendering, used only for log and error messages.
    printable = " ".join(shlex.quote(arg) for arg in full_cmd)

    if verbose:
        print("### Running {}".format(printable))

    if not capture_output:
        completed = subprocess.run(full_cmd)
        if completed.returncode != 0:
            raise Exception("Error running command: {}".format(printable))
        return None

    completed = subprocess.run(
        full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out, err = completed.stdout, completed.stderr
    if completed.returncode != 0:
        print(out.decode("utf-8"))
        err_text = err.decode("utf-8")
        print(err_text)
        raise Exception(
            "Error running command: {}.\n{}".format(printable, err_text)
        )
    return (out, err)
90
91
def runTest(args, dtype, shape):
    """Run one write/read round trip of the npy test command.

    Args:
        args: parsed options; ``args.cmd`` is the executable to run and
            ``args.verbose`` enables command echoing.
        dtype: data type name passed to the command's ``-d`` option.
        shape: sequence of dimension sizes. Each entry is converted with
            ``str()`` and the entries are comma-joined for ``-t``.

    Returns:
        ``(result, message, elapsed)``: a TestResult, the error text
        (``""`` on success) and a ``datetime.timedelta`` run time.
    """
    start_time = datetime.now()
    result = TestResult.PASS
    message = ""

    # This draws from the module-level random generator, so it also advances
    # the global sequence that main() seeds and draws shapes from.
    target = Path(f"npytest-{random.randint(0, 10000)}.npy")
    shape_str = ",".join(str(dim) for dim in shape)
    # Remove any stale file left behind by a previous run.
    if target.exists():
        target.unlink()

    try:
        cmd = [args.cmd, "-d", dtype, "-f", str(target), "-t", shape_str]
        run_sh_command(cmd, args.verbose)
        # The command is expected to leave the file behind. If it is missing,
        # unlink() raises FileNotFoundError, reported as a failure below.
        target.unlink()
    except Exception as e:
        message = str(e)
        result = TestResult.COMMAND_ERROR
        # Best-effort cleanup so a failed run does not leave a stray .npy
        # file in the working directory; a cleanup error must not mask the
        # original failure.
        try:
            if target.exists():
                target.unlink()
        except OSError:
            pass
    end_time = datetime.now()
    return result, message, end_time - start_time
113
114
def main():
    """Run the npy serialization test once per supported data type.

    Each data type gets a random 4-dimensional shape whose dimensions are
    drawn from [1, max_size] using the generator seeded with ``--seed``.
    Results are written as an xunit report to ``--xunit-file``.

    Returns:
        0 if every test passed, 1 otherwise, so that ``exit(main())``
        reports failures through the process exit status.
    """
    args = parseArgs()

    suitename = "basic_serialization"
    classname = "npy_test"

    xunit_result = xunit_results()
    xunit_suite = xunit_result.create_suite(suitename)

    max_size = 128
    datatypes = ["int32", "int64", "float", "bool", "double"]
    random.seed(args.seed)

    failed = 0
    for test in datatypes:
        shape = [str(random.randint(1, max_size)) for _ in range(4)]
        result, message, time_delta = runTest(args, test, shape)
        xt = xunit_test(str(test), f"{suitename}.{classname}")
        # total_seconds() also accounts for the days component of the delta.
        xt.time = str(time_delta.total_seconds())
        if result != TestResult.PASS:
            xt.failed(message)
            failed += 1
        xunit_suite.tests.append(xt)

    xunit_result.write_results(args.xunit_file)
    print(f"Total tests run: {len(datatypes)} failures: {failed}")
    # Non-zero exit status lets CI detect failures; previously this
    # returned None, so the script always exited 0.
    return 1 if failed else 0
149
150
if __name__ == "__main__":
    # sys.exit is the supported way to set the exit status from a program;
    # the site-module exit() is meant for the interactive interpreter only.
    sys.exit(main())