Eric Kunze | 2364dcd | 2021-04-26 11:06:57 -0700 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | # Copyright (c) 2021, ARM Limited. |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | # you may not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | |
| 17 | """ Simple test script which tests numpy file read/write""" |
| 18 | |
import argparse
import random
import shlex
import subprocess
import sys
from datetime import datetime
from enum import IntEnum, unique
from pathlib import Path

from xunit.xunit import xunit_results, xunit_test
| 27 | |
| 28 | |
@unique
class TestResult(IntEnum):
    """Integer outcome codes for a single npy read/write test.

    ``@unique`` makes the class definition fail if two members ever share
    a value.
    """

    # The test command completed successfully.
    PASS = 0
    # Running the test command raised (e.g. it exited with a non-zero status).
    COMMAND_ERROR = 1
    # NOTE(review): not assigned anywhere in the visible script.
    MISMATCH = 2
    # NOTE(review): not assigned anywhere in the visible script.
    SKIPPED = 3
| 36 | |
| 37 | def parseArgs(): |
| 38 | baseDir = (Path(__file__).parent / "../..").resolve() |
| 39 | buildDir = (baseDir / "build").resolve() |
| 40 | parser = argparse.ArgumentParser() |
| 41 | |
| 42 | parser.add_argument( |
| 43 | "-c", |
| 44 | "--cmd", |
| 45 | default=str(buildDir / "serialization_npy_test"), |
| 46 | help="Command to write/read test file", |
| 47 | ) |
| 48 | parser.add_argument("-s", "--seed", default=1, help="Random number seed") |
| 49 | parser.add_argument( |
| 50 | "-v", "--verbose", action="store_true", help="verbose", default=False |
| 51 | ) |
| 52 | parser.add_argument( |
| 53 | "--xunit-file", default="npy-result.xml", help="xunit result output file" |
| 54 | ) |
| 55 | args = parser.parse_args() |
| 56 | |
| 57 | # check that required files exist |
| 58 | if not Path(args.cmd).exists(): |
| 59 | print("command not found at location " + args.cmd) |
| 60 | parser.print_help() |
| 61 | exit(1) |
| 62 | return args |
| 63 | |
| 64 | |
def run_sh_command(full_cmd, verbose=False, capture_output=False):
    """Execute ``full_cmd`` (an argv list) directly, without a shell.

    If ``verbose`` is true, the shell-quoted command line is printed first.
    With ``capture_output`` the child's stdout and stderr are captured and
    returned as a ``(stdout, stderr)`` tuple of bytes; otherwise the child
    inherits this process's streams and ``None`` is returned.

    Raises:
        Exception: when the command exits with a non-zero status. In capture
            mode the captured stdout/stderr are printed first and the decoded
            stderr is appended to the message.
    """
    # Shell-quoted form, used only for human-readable messages.
    printable = " ".join(shlex.quote(part) for part in full_cmd)

    if verbose:
        print("### Running {}".format(printable))

    if not capture_output:
        completed = subprocess.run(full_cmd)
        if completed.returncode != 0:
            raise Exception("Error running command: {}".format(printable))
        return None

    completed = subprocess.run(
        full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    if completed.returncode == 0:
        return (completed.stdout, completed.stderr)

    # Failure: echo everything the child produced, then raise.
    out_text = completed.stdout.decode("utf-8")
    print(out_text)
    err_text = completed.stderr.decode("utf-8")
    print(err_text)
    raise Exception("Error running command: {}.\n{}".format(printable, err_text))
| 90 | |
| 91 | |
def runTest(args, dtype, shape):
    """Run the npy write/read binary once for one dtype and shape.

    Args:
        args: Parsed CLI namespace; ``args.cmd`` is the binary to run and
            ``args.verbose`` makes the command line be echoed.
        dtype: Datatype name passed to the binary via ``-d``.
        shape: Sequence of dimension sizes as strings; joined with commas
            and passed via ``-t``.

    Returns:
        ``(result, message, elapsed)`` where ``result`` is a TestResult,
        ``message`` is empty on success or the error text on failure, and
        ``elapsed`` is the wall-clock ``datetime.timedelta`` of the run.
    """
    start_time = datetime.now()
    result = TestResult.PASS
    message = ""

    # Note: this draws from the module-level `random` state, so it advances
    # the same sequence a caller seeding `random` may be relying on.
    target = Path(f"npytest-{random.randint(0,10000)}.npy")
    shape_str = ",".join(shape)
    # Remove any previous files
    if target.exists():
        target.unlink()

    try:
        cmd = [args.cmd, "-d", dtype, "-f", str(target), "-t", shape_str]
        run_sh_command(cmd, args.verbose)
        # Presumably the binary leaves the file on disk; if it was never
        # written, this unlink raises and the test is reported as an error.
        target.unlink()
    except Exception as e:
        message = str(e)
        result = TestResult.COMMAND_ERROR
    finally:
        # Best-effort cleanup so a failed run does not leave a stray .npy
        # file in the working directory; must not mask the test outcome.
        try:
            if target.exists():
                target.unlink()
        except OSError:
            pass
    end_time = datetime.now()
    return result, message, end_time - start_time
| 113 | |
| 114 | |
def main():
    """Run the npy write/read test once per datatype and write an xunit report.

    Each datatype gets a random 4-D shape (dimensions in ``[1, 128]``) drawn
    from an RNG seeded with ``--seed``. Results go to ``args.xunit_file``,
    and a summary line is printed.

    Returns:
        0 if every test passed, 1 otherwise, so ``exit(main())`` reports
        failures through the process exit status.
    """
    args = parseArgs()

    suitename = "basic_serialization"
    classname = "npy_test"

    xunit_result = xunit_results()
    xunit_suite = xunit_result.create_suite(suitename)

    max_size = 128
    datatypes = ["int32", "int64", "float", "bool", "double"]
    # Seed once up front so the generated shapes are reproducible per seed.
    random.seed(args.seed)

    failed = 0
    for dtype in datatypes:
        shape = [str(random.randint(1, max_size)) for _ in range(4)]
        (result, message, time_delta) = runTest(args, dtype, shape)
        xt = xunit_test(str(dtype), f"{suitename}.{classname}")
        # total_seconds() also includes the days component of the delta.
        xt.time = str(time_delta.total_seconds())
        if result != TestResult.PASS:
            xt.failed(message)
            failed += 1
        xunit_suite.tests.append(xt)

    xunit_result.write_results(args.xunit_file)
    print(f"Total tests run: {len(datatypes)} failures: {failed}")
    return 1 if failed else 0
| 149 | |
| 150 | |
if __name__ == "__main__":
    # sys.exit (not the site-provided exit()) propagates main()'s status code.
    sys.exit(main())