blob: 4f3d7fdad0f5921f2089fcbf78624a606440f4f3 [file] [log] [blame]
"""Tests for the Python interface to the data generator library."""
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import copy
import tempfile
from pathlib import Path

import numpy as np
import pytest
from generator.datagenerator import GenerateError
from generator.datagenerator import GenerateLibrary
10
# NOTE: These tests are marked as POST COMMIT.
# To run them, first build the reference_model in a local "build" directory
# (as per the README), then run: pytest -m "postcommit"
14
# Location of reference model binaries: the "build/reference_model"
# directory two levels above the directory holding this test file
# (presumably the repository root - confirm against the README).
REF_MODEL_BUILD_PATH = (
    Path(__file__).resolve().parents[2].joinpath("build", "reference_model")
)
GENERATE_LIB = "libtosa_reference_generate_lib.so"
GENERATE_LIB_PATH = REF_MODEL_BUILD_PATH.joinpath(GENERATE_LIB)

# Directory containing this test file
TEST_DIR = Path(__file__).parent
21
22
@pytest.mark.postcommit
def test_generate_lib_built():
    """Check the generate library exists in the build directory (run first)."""
    lib_path = GENERATE_LIB_PATH
    assert lib_path.is_file()
27
28
@pytest.mark.postcommit
def test_checker_generate_load_fail():
    """Loading the generate library from a missing path raises GenerateError."""
    missing_path = Path("/place-that-does-not-exist")
    with pytest.raises(GenerateError) as excinfo:
        GenerateLibrary(missing_path)
    message = str(excinfo.value)
    assert message.startswith("Could not find generate library")
34
35
@pytest.mark.postcommit
def test_checker_generate_load():
    """Loading the generate library from the build directory succeeds."""
    assert GenerateLibrary(GENERATE_LIB_PATH)
40
41
def _matmul_dot_product_tensor(shape, input_pos):
    """Return the data_gen description of one FP32 MATMUL input tensor.

    Both MATMUL inputs share every field except their shape and their input
    position. A fresh dict (including a fresh ``dot_product_info``) is built
    on each call, so the two entries never alias each other.
    """
    return {
        "generator": "DOT_PRODUCT",
        "data_type": "FP32",
        "input_type": "VARIABLE",
        "shape": shape,
        "input_pos": input_pos,
        "op": "MATMUL",
        # NOTE(review): "ks" (4) equals the shared inner dimension of the two
        # MATMUL shapes below; presumably the dot-product length - confirm.
        "dot_product_info": {"s": 0, "ks": 4, "acc_type": "FP32"},
    }


# Generator config for two FP32 MATMUL inputs of shapes [3, 5, 4] and
# [3, 4, 6], both produced by the DOT_PRODUCT data generator.
JSON_DATAGEN_DOT_PRODUCT = {
    "tosa_file": "test.json",
    "ifm_name": ["input-0", "input-1"],
    "ifm_file": ["input-0.npy", "input-1.npy"],
    "ofm_name": ["result-0"],
    "ofm_file": ["result-0.npy"],
    "meta": {
        "data_gen": {
            "version": "0.1",
            "tensors": {
                "input-0": _matmul_dot_product_tensor([3, 5, 4], 0),
                "input-1": _matmul_dot_product_tensor([3, 4, 6], 1),
            },
        }
    },
}
74
75
@pytest.mark.postcommit
def test_generate_dot_product_check():
    """Generate the DOT_PRODUCT input tensors to files and check them.

    Configures the generate library with ``JSON_DATAGEN_DOT_PRODUCT``, writes
    the input tensors as numpy files, and checks that each expected file
    exists with the configured shape and a float32 dtype.

    The files are written to a temporary directory that is removed on exit,
    even when an assertion fails. They are not written next to this test
    file, so a failing run cannot leave stale ``.npy`` files in the source
    tree.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    json_config = JSON_DATAGEN_DOT_PRODUCT
    glib.set_config(json_config)
    tensors = json_config["meta"]["data_gen"]["tensors"]

    with tempfile.TemporaryDirectory() as tmp:
        out_dir = Path(tmp)
        glib.write_numpy_files(out_dir)

        # Test the files exist and are the expected numpy files
        for f, n in zip(json_config["ifm_file"], json_config["ifm_name"]):
            file = out_dir / f
            assert file.is_file()
            arr = np.load(file)
            assert arr.shape == tuple(tensors[n]["shape"])
            assert arr.dtype == np.float32
95 file.unlink()
96
97
@pytest.mark.postcommit
def test_generate_dot_product_check_fail_names():
    """Unknown input tensor names are reported and no files are written.

    The ``ifm_name`` entries are replaced with names that have no matching
    entry under ``meta.data_gen.tensors``. ``write_numpy_files`` must then
    raise ``GenerateError``, with one error line per input name, and must
    not create any of the ``ifm_file`` outputs.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    # Fix up the JSON to have the wrong names. Deep copy so nothing this
    # test (or the library) changes can leak into the shared module config.
    json_config = copy.deepcopy(JSON_DATAGEN_DOT_PRODUCT)
    json_config["ifm_name"] = ["not-input0", "not-input1"]
    glib.set_config(json_config)

    # Use a fresh, empty directory so the "no files written" check below is
    # not confused by leftovers from other tests or earlier runs.
    with tempfile.TemporaryDirectory() as tmp:
        out_dir = Path(tmp)

        with pytest.raises(GenerateError) as excinfo:
            glib.write_numpy_files(out_dir)
        info = str(excinfo.value).split("\n")

        # Expect (at least) one error line per unknown input tensor name
        assert len(info) >= len(json_config["ifm_name"]), info
        for i, n in enumerate(json_config["ifm_name"]):
            assert info[i].startswith(f"ERROR: Failed to create data for tensor {n}")

        for f in json_config["ifm_file"]:
            assert not (out_dir / f).is_file()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100117
118
@pytest.mark.postcommit
def test_generate_tensor_data_check():
    """Generate each DOT_PRODUCT input tensor in memory and check shape/dtype."""
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    data_gen = JSON_DATAGEN_DOT_PRODUCT["meta"]["data_gen"]
    tensor_specs = data_gen["tensors"]

    for name in JSON_DATAGEN_DOT_PRODUCT["ifm_name"]:
        tensor = glib.get_tensor_data(name, data_gen)

        expected_shape = tuple(tensor_specs[name]["shape"])
        assert tensor.shape == expected_shape
        assert tensor.dtype == np.float32
130 assert arr.dtype == np.float32