blob: 503ecfeb650b0838fa97fe4c3ba985d3f53aa994 [file] [log] [blame]
// Copyright (c) 2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "generate.h"

#include <doctest.h>

#include <array>
#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>
22
Jeremy Johnson59b307d2023-10-04 14:17:26 +010023namespace
24{
// Replace, in place, every occurrence of 'find' within 'str' by 'change'.
//
// The search resumes immediately after each inserted 'change', so the
// function terminates even when 'change' itself contains 'find'.
// An empty 'find' matches nothing and leaves 'str' untouched.
void update_json_template(std::string& str, const std::string& find, const std::string& change)
{
    if (find.empty())
    {
        // std::string::find("") always succeeds at the start position, which
        // would otherwise insert 'change' forever.
        return;
    }

    auto pos = str.find(find);
    while (pos != std::string::npos)
    {
        str.replace(pos, find.length(), change);
        // Skip over the text just inserted so it is never re-scanned.
        pos = str.find(find, pos + change.length());
    }
}
35
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010036void check_value(bool match, uint32_t result, uint32_t expected, uint32_t idx)
37{
38 std::stringstream msg;
39 msg << "index: " << idx << " expected: " << std::hex << expected << " got: " << result;
40 if (match)
41 {
42 REQUIRE_MESSAGE(expected == result, msg.str());
43 }
44 else
45 {
46 REQUIRE_MESSAGE(expected != result, msg.str());
47 }
48}
49
Jeremy Johnson59b307d2023-10-04 14:17:26 +010050template <typename T>
51void check_output(const std::vector<T>& results, const std::vector<uint32_t>& expected)
52{
53 for (size_t idx = 0; idx < expected.size(); ++idx)
54 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010055 check_value(true, *(uint32_t*)&results[idx], expected[idx], idx);
Jeremy Johnson59b307d2023-10-04 14:17:26 +010056 }
57}
58
59} // namespace
60
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010061TEST_SUITE_BEGIN("generate");
62
63TEST_CASE("negative - api")
64{
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010065 std::string templateJsonCfg = R"({
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010066 "tensors" : {
67 "in1" : {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010068 "generator": "_GENERATOR_",
69 "data_type": "_TYPE_",
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010070 "input_type": "VARIABLE",
71 "shape" : [ 4, 8, 8 ],
72 "input_pos": 0,
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010073 "op" : "_OP_",
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010074 "dot_product_info": {
75 "s": 0,
Jeremy Johnson59b307d2023-10-04 14:17:26 +010076 "ks": 8,
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010077 "acc_type": "_TYPE_"
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010078 }
79 }
80 }
81 })";
82
83 const std::string tosaName = "in1";
84 const size_t tosaElements = 4 * 8 * 8;
85 const size_t tosaSize = tosaElements * 4;
86
87 SUBCASE("missing input")
88 {
89 REQUIRE_FALSE(tgd_generate_data(NULL, NULL, NULL, 0));
90 }
91 SUBCASE("invalid json")
92 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +010093 std::string invalidJsonCfg = R"({
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +010094 "tensors" : {
95 "in1" : {
96 "generator": DOT_PRODUCT,
97 },
98 }
99 })";
100
101 std::vector<float> buffer(tosaElements);
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100102 REQUIRE_FALSE(tgd_generate_data(invalidJsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100103 }
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100104 SUBCASE("unknown generator")
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100105 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100106 std::string jsonCfg = templateJsonCfg;
107 update_json_template(jsonCfg, "_GENERATOR_", "SOLAR");
108 update_json_template(jsonCfg, "_TYPE_", "FP32");
109 update_json_template(jsonCfg, "_OP_", "MATMUL");
110 std::vector<float> buffer(tosaElements);
111 REQUIRE_FALSE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
112 }
113 SUBCASE("unknown op")
114 {
115 std::string jsonCfg = templateJsonCfg;
116 update_json_template(jsonCfg, "_GENERATOR_", "DOT_PRODUCT");
117 update_json_template(jsonCfg, "_TYPE_", "FP32");
118 update_json_template(jsonCfg, "_OP_", "GREEN");
119
120 std::vector<float> buffer(tosaElements);
121 REQUIRE_FALSE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
122 }
123 SUBCASE("unknown type")
124 {
125 std::string jsonCfg = templateJsonCfg;
126 update_json_template(jsonCfg, "_GENERATOR_", "DOT_PRODUCT");
127 update_json_template(jsonCfg, "_TYPE_", "WATT");
128 update_json_template(jsonCfg, "_OP_", "MATMUL");
129
130 std::vector<float> buffer(tosaElements);
131 REQUIRE_FALSE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
132 }
133 SUBCASE("mismatching name")
134 {
135 std::string jsonCfg = templateJsonCfg;
136 update_json_template(jsonCfg, "_GENERATOR_", "DOT_PRODUCT");
137 update_json_template(jsonCfg, "_TYPE_", "FP32");
138 update_json_template(jsonCfg, "_OP_", "MATMUL");
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100139 std::string invalidName = "notFound1";
140
141 std::vector<float> buffer(tosaElements);
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100142 REQUIRE_FALSE(tgd_generate_data(jsonCfg.c_str(), invalidName.c_str(), (void*)buffer.data(), tosaSize));
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100143 }
144 SUBCASE("mismatching size")
145 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100146 std::string jsonCfg = templateJsonCfg;
147 update_json_template(jsonCfg, "_GENERATOR_", "DOT_PRODUCT");
148 update_json_template(jsonCfg, "_TYPE_", "FP32");
149 update_json_template(jsonCfg, "_OP_", "MATMUL");
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100150 size_t smallElements = 4 * 8 * 7;
151 size_t smallSize = smallElements * 4;
152
153 std::vector<float> buffer(smallElements);
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100154 REQUIRE_FALSE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), smallSize));
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100155 }
156}
157
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100158void matmul_test_FP32(const std::string tosaName[2],
159 const size_t tosaElements[2],
160 const std::string templateJsonCfg,
161 const std::string setStr,
162 int32_t param,
163 const std::vector<uint32_t> expected)
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100164{
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100165 std::string jsonCfg = templateJsonCfg;
166 update_json_template(jsonCfg, "_SET_", setStr);
167 std::vector<float> buffer(tosaElements[param]);
168 REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName[param].c_str(), (void*)buffer.data(), tosaElements[param] * 4));
169 check_output<float>(buffer, expected);
170}
171
172TEST_CASE("positive - FP32 matmul dot product (first 3 values)")
173{
174 std::string templateJsonCfg = R"({
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100175 "tensors" : {
176 "in1" : {
177 "generator": "DOT_PRODUCT",
178 "data_type": "FP32",
179 "input_type": "VARIABLE",
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100180 "shape" : [ 4, 8, 2 ],
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100181 "input_pos": 0,
182 "op" : "MATMUL",
183 "dot_product_info": {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100184 "s": _SET_,
185 "ks": 2,
186 "acc_type": "FP32"
187 }
188 },
189 "in2" : {
190 "generator": "DOT_PRODUCT",
191 "data_type": "FP32",
192 "input_type": "VARIABLE",
193 "shape" : [ 4, 2, 5 ],
194 "input_pos": 1,
195 "op" : "MATMUL",
196 "dot_product_info": {
197 "s": _SET_,
198 "ks": 2,
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100199 "acc_type": "FP32"
200 }
201 }
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100202
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100203 }
204 })";
205
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100206 const std::string tosaName[2] = { "in1", "in2" };
207 const size_t tosaElements[2] = { (4 * 8 * 2), (4 * 2 * 5) };
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100208
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100209 SUBCASE("matmul, set 0, param 0")
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100210 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100211 std::vector<uint32_t> expected = { 0xbf665aa4, 0xbf736bd3, 0x0 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100212 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "0", 0, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100213 }
214 SUBCASE("matmul, set 0, param 1")
215 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100216 std::vector<uint32_t> expected = { 0x0, 0x0, 0x3f34f2dd };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100217 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "0", 1, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100218 }
219 SUBCASE("matmul, set 1, param 0")
220 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100221 std::vector<uint32_t> expected = { 0x5e97f1b0, 0x5ea6a18e, 0x5eb811af };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100222 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "1", 0, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100223 }
224 SUBCASE("matmul, set 1, param 1")
225 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100226 std::vector<uint32_t> expected = { 0x5f128bb1, 0x5ef54579, 0x5ebd65b8 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100227 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "1", 1, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100228 }
229 SUBCASE("matmul, set 2, param 0")
230 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100231 std::vector<uint32_t> expected = { 0x3f800000, 0x3e66ed53, 0x3f800000 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100232 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "2", 0, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100233 }
234 SUBCASE("matmul, set 2, param 1")
235 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100236 std::vector<uint32_t> expected = { 0x3f800000, 0x3f800000, 0x3f800000 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100237 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "2", 1, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100238 }
239 SUBCASE("matmul, set 3, param 0")
240 {
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100241 // NOTE: Python test script produced 0xbf256686 - so off by 1
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100242 std::vector<uint32_t> expected = { 0x41800000, 0xbf256685, 0x41800000 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100243 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "3", 0, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100244 }
245 SUBCASE("matmul, set 3, param 1")
246 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100247 std::vector<uint32_t> expected = { 0x41800000, 0x41800000, 0x41800000 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100248 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "3", 1, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100249 }
250 SUBCASE("matmul, set 4, param 0")
251 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100252 std::vector<uint32_t> expected = { 0x0, 0x3f000000, 0x5f14e80c };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100253 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "4", 0, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100254 }
255 SUBCASE("matmul, set 4, param 1")
256 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100257 std::vector<uint32_t> expected = { 0x5d5d0db2, 0xdf2c82a8, 0x0 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100258 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "4", 1, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100259 }
260 SUBCASE("matmul, set 5, param 0")
261 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100262 std::vector<uint32_t> expected = { 0x5df6c4b3, 0x5e6b4088, 0x5ed0fe71 };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100263 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 0, expected);
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100264 }
265 SUBCASE("matmul, set 5, param 1")
266 {
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100267 std::vector<uint32_t> expected = { 0xde086d85, 0x5e630878, 0x5eba5c7b };
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100268 matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 1, expected);
Jeremy Johnsonb20b0c92023-10-04 14:17:55 +0100269 }
270}
Jeremy Johnson59b307d2023-10-04 14:17:26 +0100271TEST_SUITE_END(); // generate