Jeremy Johnson | 65ba809 | 2023-10-09 16:31:13 +0100 | [diff] [blame] | 1 | """Tests for the python interface to the data generator library.""" |
| 2 | # Copyright (c) 2023, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | from pathlib import Path |
| 5 | |
| 6 | import numpy as np |
| 7 | import pytest |
| 8 | from generator.datagenerator import GenerateError |
| 9 | from generator.datagenerator import GenerateLibrary |
| 10 | |
| 11 | # NOTE: These tests are marked as POST COMMIT |
| 12 | # To run them, please build the reference_model in a local "build" directory |
| 13 | # (as per the README) and run them using: pytest -m "postcommit" |
| 14 | |
# Location of reference model binaries: <dir>/build/reference_model, where
# <dir> is two directory levels above the folder containing this test file.
REF_MODEL_BUILD_PATH = Path(__file__).resolve().parents[2].joinpath(
    "build", "reference_model"
)

# Shared library produced by the reference model build, and its full path.
GENERATE_LIB = "libtosa_reference_generate_lib.so"
GENERATE_LIB_PATH = REF_MODEL_BUILD_PATH.joinpath(GENERATE_LIB)

# Directory containing this test module; used as the .npy output directory.
TEST_DIR = Path(__file__).parent
| 21 | |
| 22 | |
@pytest.mark.postcommit
def test_generate_lib_built():
    """Sanity check that the generate shared library exists in the build tree.

    Meant as the first test, so a missing build is reported up front.
    """
    lib_path = GENERATE_LIB_PATH
    assert lib_path.is_file()
| 27 | |
| 28 | |
@pytest.mark.postcommit
def test_checker_generate_load_fail():
    """Loading the generate library from a non-existent path raises GenerateError."""
    missing_path = Path("/place-that-does-not-exist")
    with pytest.raises(GenerateError) as excinfo:
        GenerateLibrary(missing_path)
    message = str(excinfo.value)
    assert message.startswith("Could not find generate library")
| 34 | |
| 35 | |
@pytest.mark.postcommit
def test_checker_generate_load():
    """The generate library loads from the built shared object and is truthy."""
    assert GenerateLibrary(GENERATE_LIB_PATH)
| 40 | |
| 41 | |
def _dot_product_matmul_tensor(shape, input_pos):
    """Return a fresh data_gen entry for one FP32 MATMUL input tensor.

    A new dict, including the nested ``dot_product_info``, is built on every
    call, so separate tensor entries never share mutable state.
    """
    return {
        "generator": "DOT_PRODUCT",
        "data_type": "FP32",
        "input_type": "VARIABLE",
        "shape": shape,
        "input_pos": input_pos,
        "op": "MATMUL",
        # NOTE(review): "ks" appears to be the dot-product length, i.e. the
        # shared K=4 dimension of the two MATMUL inputs -- confirm.
        "dot_product_info": {"s": 0, "ks": 4, "acc_type": "FP32"},
    }


# Data generator test config: two DOT_PRODUCT-generated FP32 MATMUL inputs
# of shapes [3, 5, 4] and [3, 4, 6], plus their output file names.
JSON_DATAGEN_DOT_PRODUCT = {
    "tosa_file": "test.json",
    "ifm_name": ["input-0", "input-1"],
    "ifm_file": ["input-0.npy", "input-1.npy"],
    "ofm_name": ["result-0"],
    "ofm_file": ["result-0.npy"],
    "meta": {
        "data_gen": {
            "version": "0.1",
            "tensors": {
                "input-0": _dot_product_matmul_tensor([3, 5, 4], 0),
                "input-1": _dot_product_matmul_tensor([3, 4, 6], 1),
            },
        }
    },
}
| 74 | |
| 75 | |
@pytest.mark.postcommit
def test_generate_dot_product_check():
    """Generate the DOT_PRODUCT input tensors as .npy files and check them.

    Each file listed in ``ifm_file`` must exist in TEST_DIR, load as a numpy
    array with the shape configured for its tensor, and have dtype float32.
    The generated files are always removed afterwards, even when an
    assertion fails, so no stale outputs are left in the test directory.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    json_config = JSON_DATAGEN_DOT_PRODUCT
    tensors = json_config["meta"]["data_gen"]["tensors"]
    glib.set_config(json_config)

    output_files = [TEST_DIR / f for f in json_config["ifm_file"]]
    try:
        glib.write_numpy_files(TEST_DIR)

        # Test the files exist and are the expected numpy files
        for file, name in zip(output_files, json_config["ifm_name"]):
            assert file.is_file()
            arr = np.load(file)
            assert arr.shape == tuple(tensors[name]["shape"])
            assert arr.dtype == np.float32
    finally:
        # Clean up on every path: the files are written into the source tree
        # (TEST_DIR), and leftovers would confuse later tests that check that
        # no files were written.
        for file in output_files:
            if file.is_file():
                file.unlink()
| 96 | |
| 97 | |
@pytest.mark.postcommit
def test_generate_dot_product_check_fail_names():
    """write_numpy_files fails, and writes no files, for unknown tensor names.

    ``ifm_name`` is replaced with names that have no entry in the config's
    ``meta.data_gen.tensors``. The library must raise GenerateError whose
    first message lines report each unknown name, in order, and must not
    create any of the ``ifm_file`` outputs.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    # Fix up the JSON to have the wrong names. A shallow copy is enough: only
    # the top-level "ifm_name" key is replaced, so the shared module-level
    # config is not modified for other tests.
    json_config = JSON_DATAGEN_DOT_PRODUCT.copy()
    json_config["ifm_name"] = ["not-input0", "not-input1"]
    glib.set_config(json_config)

    output_files = [TEST_DIR / f for f in json_config["ifm_file"]]
    # Remove stale outputs (e.g. from an earlier failed run) so the
    # "nothing written" check below only reflects this call.
    for file in output_files:
        if file.is_file():
            file.unlink()

    with pytest.raises(GenerateError) as excinfo:
        glib.write_numpy_files(TEST_DIR)
    info = str(excinfo.value).split("\n")
    # Assert the count first, so a short message fails clearly instead of
    # raising IndexError.
    assert len(info) >= len(json_config["ifm_name"])
    for line, name in zip(info, json_config["ifm_name"]):
        assert line.startswith(f"ERROR: Failed to create data for tensor {name}")

    written = [file for file in output_files if file.is_file()]
    # Clean up before asserting so a failure does not leave files behind.
    for file in written:
        file.unlink()
    assert not written
Jeremy Johnson | d1a08ce | 2023-10-18 17:22:21 +0100 | [diff] [blame] | 117 | |
| 118 | |
@pytest.mark.postcommit
def test_generate_tensor_data_check():
    """Generate each DOT_PRODUCT input tensor in memory via get_tensor_data.

    Each returned array must have the shape configured for its tensor and
    dtype float32.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    data_gen_config = JSON_DATAGEN_DOT_PRODUCT["meta"]["data_gen"]

    for tensor_name in JSON_DATAGEN_DOT_PRODUCT["ifm_name"]:
        arr = glib.get_tensor_data(tensor_name, data_gen_config)
        expected_shape = tuple(data_gen_config["tensors"][tensor_name]["shape"])
        assert arr.shape == expected_shape
        assert arr.dtype == np.float32