blob: 895b4b51d0a76052a72c6109320659f1b9ecce66 [file] [log] [blame]
SiCong Li7061eb22021-01-08 15:16:02 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
25#include "src/runtime/CL/mlgo/Utils.h"
26#include "tests/framework/Asserts.h"
27#include "tests/framework/Macros.h"
28
29using namespace arm_compute::mlgo;
30
31namespace arm_compute
32{
33namespace test
34{
35namespace validation
36{
37TEST_SUITE(CL)
38TEST_SUITE(UNIT)
39TEST_SUITE(MLGOHeuristics)
40TEST_CASE(CorrectDotMLGOShouldLoadCorrectly, framework::DatasetMode::ALL)
41{
42 std::string mlgo_str = R"_(
43 <header>
44 gemm-version, [1,2,1]
45 ip-type,gpu
46 </header>
47 <heuristics-table>
48 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
49 1, g71 , 8, f16, best-performance, static, gemm-config-reshaped-only-rhs, [m,n,k,n]
50 2, g76 , 8, f16, best-performance, static, gemm-config-reshaped, [m,n,k,n]
51 </heuristics-table>
52 <heuristic, 0>
53 b , 0, var, m, ==, num, 10., 1, 2
54 l , 1, gemm-type, reshaped
55 b , 2, var, r_mn, >=, num, 2., 3, 6
56 b , 3, var, n, >=, num, 200., 4, 5
57 l , 4, gemm-type, reshaped-only-rhs
58 l , 5, gemm-type, reshaped
59 l , 6, gemm-type, reshaped-only-rhs
60 </heuristic>
61 <heuristic, 1>
62 b ,0,var, n, >, num, 100., 1, 4
63 b ,1,var, r_mnk, <=, num, 20., 2, 3
64 l ,2,gemm-config-reshaped-only-rhs, [4, 4,4,2,1,0,1]
65 l ,3,gemm-config-reshaped-only-rhs,[ 2, 2,4,2,1,1, 1 ]
66 b ,4,var, n, >=, num, 199.12, 5, 6
67 l ,5,gemm-config-reshaped-only-rhs, [1, 4,3,4,0,0,0]
68 l ,6,gemm-config-reshaped-only-rhs, [5, 4,4,5,1,1,0]
69 </heuristic>
70 <heuristic, 2>
71 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
72 </heuristic>
73 )_";
74 std::stringstream ss(mlgo_str);
75 MLGOHeuristics heuristics;
76 heuristics.reload_from_stream(ss);
77
78 ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 10, 1024, 20, 1 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS);
79 ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 201, 5, 1 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS);
80 ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 200, 199, 16 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS);
81 ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 199, 512, 4 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS);
82
83 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }),
84 framework::LogLevel::ERRORS);
85 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }),
86 framework::LogLevel::ERRORS);
87 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 128, 101, 20, 1 }).second == GEMMConfigReshapedOnlyRHS{ 2, 2, 4, 2, true, true, true }),
88 framework::LogLevel::ERRORS);
89 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }),
90 framework::LogLevel::ERRORS);
91 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }),
92 framework::LogLevel::ERRORS);
93
94 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 100, 100, 20, 32 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }),
95 framework::LogLevel::ERRORS);
96 ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 128, 512, 1024, 1 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }),
97 framework::LogLevel::ERRORS);
98}
99
100TEST_CASE(InvalidDotmlgoSyntaxShouldReturnInvalidStatus, framework::DatasetMode::ALL)
101{
102 std::string mlgo_str = R"_(
103 <header>
104 gemm-version, [1,2,1]
105 ip-type,pu
106 </header>
107 <heuristics-table>
108 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
109 </heurist
110 <heuristic, 0>
111 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
112 </heuristic>
113 )_";
114 std::stringstream ss(mlgo_str);
115 MLGOHeuristics heuristics;
116 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
117}
118
119TEST_SUITE(InvalidDotmlgoSemanticsShouldReturnInvalidStatus)
// If the semantic errors are local to some trees rather than to the entire heuristics, an alternative would be to
// simply ignore/remove those invalid trees. However, we choose to fail instead, invalidating the entire heuristics
// (reload_from_stream returns false). The reason is that if some trees are invalid, the quality of the whole dotmlgo is
// called into question, even if the rest of the trees are semantically valid, and those trees could severely degrade
// the performance of GEMM. This "all or nothing" approach to dotmlgo correctness is therefore safer and more defensive.
125
// Also note that a semantic error of a tree here refers only to errors that obstruct its evaluation and thus its
// querying (e.g. invalid tree structure, unsupported features etc.), not to those affecting the desired outcome
// (usually in terms of final GEMM performance, e.g. the effectiveness of the decision tree).
129
// In the future we might want to check the content of the underlying errors (exceptions) as well. For now it
// suffices to know that loading fails exactly when it needs to.
132TEST_CASE(MismatchesBetweenHeuristicsTableEntriesAndHeuristicTrees, framework::DatasetMode::ALL)
133{
134 {
135 // Mismatching number of entries 1
136 std::string mlgo_str = R"_(
137 <header>
138 gemm-version, [1,2,1]
139 ip-type,gpu
140 </header>
141 <heuristics-table>
142 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
143 </heuristics-table>
144 )_";
145 std::stringstream ss(mlgo_str);
146 MLGOHeuristics heuristics;
147 // NOTE: This case might throw an internal error as the tree inserted by the heuristics-table cannot not be checked
148 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
149 }
150
151 {
152 // Mismatching number of entries 2
153 std::string mlgo_str = R"_(
154 <header>
155 gemm-version, [1,2,1]
156 ip-type,gpu
157 </header>
158 <heuristics-table>
159 </heuristics-table>
160 <heuristic, 1>
161 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
162 </heuristic>
163 )_";
164 std::stringstream ss(mlgo_str);
165 MLGOHeuristics heuristics;
166 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
167 }
168
169 {
170 // Mismatching info
171 std::string mlgo_str = R"_(
172 <header>
173 gemm-version, [1,2,1]
174 ip-type,gpu
175 </header>
176 <heuristics-table>
177 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
178 </heuristics-table>
179 <heuristic, 0>
180 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
181 </heuristic>
182 )_";
183 std::stringstream ss(mlgo_str);
184 MLGOHeuristics heuristics;
185 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
186 }
187}
188
189TEST_CASE(RepeatedHeuristicsTableEntriesId, framework::DatasetMode::ALL)
190{
191 std::string mlgo_str = R"_(
192 <header>
193 gemm-version, [1,2,1]
194 ip-type,gpu
195 </header>
196 <heuristics-table>
197 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
198 0, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
199 </heuristics-table>
200 <heuristic, 0>
201 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
202 </heuristic>
203 <heuristic, 1>
204 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
205 </heuristic>
206 )_";
207 std::stringstream ss(mlgo_str);
208 MLGOHeuristics heuristics;
209 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
210}
211
212TEST_CASE(RepeatedHeuristicsTableEntriesIndex, framework::DatasetMode::ALL)
213{
214 std::string mlgo_str = R"_(
215 <header>
216 gemm-version, [1,2,1]
217 ip-type,gpu
218 </header>
219 <heuristics-table>
220 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
221 1, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
222 </heuristics-table>
223 <heuristic, 0>
224 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
225 </heuristic>
226 <heuristic, 1>
227 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
228 </heuristic>
229 )_";
230 std::stringstream ss(mlgo_str);
231 MLGOHeuristics heuristics;
232 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
233}
234
235TEST_CASE(RepeatedHeuristicTreesId, framework::DatasetMode::ALL)
236{
237 std::string mlgo_str = R"_(
238 <header>
239 gemm-version, [1,2,1]
240 ip-type,gpu
241 </header>
242 <heuristics-table>
243 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
244 1, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
245 </heuristics-table>
246 <heuristic, 0>
247 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
248 </heuristic>
249 <heuristic, 0>
250 l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
251 </heuristic>
252 )_";
253 std::stringstream ss(mlgo_str);
254 MLGOHeuristics heuristics;
255 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
256}
257TEST_CASE(EmptyTree, framework::DatasetMode::ALL)
258{
259 std::string mlgo_str = R"_(
260 <header>
261 gemm-version, [1,2,1]
262 ip-type,gpu
263 </header>
264 <heuristics-table>
265 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
266 </heuristics-table>
267 <heuristic, 0>
268 </heuristic>
269 )_";
270 std::stringstream ss(mlgo_str);
271 MLGOHeuristics heuristics;
272 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
273}
274
275TEST_CASE(InvalidTreeMissingRoot, framework::DatasetMode::ALL)
276{
277 std::string mlgo_str = R"_(
278 <header>
279 gemm-version, [1,2,1]
280 ip-type,gpu
281 </header>
282 <heuristics-table>
283 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
284 </heuristics-table>
285 <heuristic, 0>
286 b ,2, var, m, ==, num, 10., 3, 4
287 l ,3,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
288 l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
289 </heuristic>
290 )_";
291 std::stringstream ss(mlgo_str);
292 MLGOHeuristics heuristics;
293 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
294}
295TEST_CASE(InvalidTreeMissingNodes, framework::DatasetMode::ALL)
296{
297 std::string mlgo_str = R"_(
298 <header>
299 gemm-version, [1,2,1]
300 ip-type,gpu
301 </header>
302 <heuristics-table>
303 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
304 </heuristics-table>
305 <heuristic, 0>
306 b ,0, var, m, ==, num, 10., 1, 2
307 l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
308 </heuristic>
309 )_";
310 std::stringstream ss(mlgo_str);
311 MLGOHeuristics heuristics;
312 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
313}
314TEST_CASE(InvalidTreeRepeatedNodeIds, framework::DatasetMode::ALL)
315{
316 std::string mlgo_str = R"_(
317 <header>
318 gemm-version, [1,2,1]
319 ip-type,gpu
320 </header>
321 <heuristics-table>
322 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
323 </heuristics-table>
324 <heuristic, 0>
325 b ,0, var, m, ==, num, 10., 1, 2
326 l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
327 l ,1,gemm-config-reshaped,[1,2,4,2,8,1,0,1,0]
328 l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
329 </heuristic>
330 )_";
331 std::stringstream ss(mlgo_str);
332 MLGOHeuristics heuristics;
333 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
334}
335TEST_CASE(InvalidTreeDisjointNodes, framework::DatasetMode::ALL)
336{
337 std::string mlgo_str = R"_(
338 <header>
339 gemm-version, [1,2,1]
340 ip-type,gpu
341 </header>
342 <heuristics-table>
343 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
344 </heuristics-table>
345 <heuristic, 0>
346 b ,0, var, m, ==, num, 10., 1, 2
347 l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
348 l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
349
350 b ,4, var, n, ==, num, 10., 5, 6
351 l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
352 l ,6,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
353
354 l ,7,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
355 </heuristic>
356 )_";
357 std::stringstream ss(mlgo_str);
358 MLGOHeuristics heuristics;
359 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
360}
361TEST_CASE(InvalidTreeLoop, framework::DatasetMode::ALL)
362{
363 std::string mlgo_str = R"_(
364 <header>
365 gemm-version, [1,2,1]
366 ip-type,gpu
367 </header>
368 <heuristics-table>
369 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
370 </heuristics-table>
371 <heuristic, 0>
372 b ,0, var, m, ==, num, 10., 0, 1
373 l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
374 </heuristic>
375 )_";
376 std::stringstream ss(mlgo_str);
377 MLGOHeuristics heuristics;
378 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
379}
380TEST_CASE(InvalidTreeCycle, framework::DatasetMode::ALL)
381{
382 std::string mlgo_str = R"_(
383 <header>
384 gemm-version, [1,2,1]
385 ip-type,gpu
386 </header>
387 <heuristics-table>
388 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
389 </heuristics-table>
390 <heuristic, 0>
391 b ,0, var, m, ==, num, 10., 1, 5
392 b ,1, var, n, ==, num, 10., 2, 3
393 l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
394 b ,3, var, k, ==, num, 10., 0, 4
395 l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
396 l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
397 </heuristic>
398 )_";
399 std::stringstream ss(mlgo_str);
400 MLGOHeuristics heuristics;
401 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
402}
403TEST_CASE(InvalidTreeInvalidFeatures, framework::DatasetMode::ALL)
404{
405 std::string mlgo_str = R"_(
406 <header>
407 gemm-version, [1,2,1]
408 ip-type,gpu
409 </header>
410 <heuristics-table>
411 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
412 </heuristics-table>
413 <heuristic, 0>
414 b ,0, var, magic_feature, ==, num, 10., 1, 2
415 l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
416 l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
417 </heuristic>
418 )_";
419 std::stringstream ss(mlgo_str);
420 MLGOHeuristics heuristics;
421 ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
422}
423TEST_SUITE_END() // InvalidDotmlgoSemanticsShouldReturnInvalidStatus
424
425TEST_CASE(InvalidUsageOfHeuristicsShouldReturnInvalidStatus, framework::DatasetMode::ALL)
426{
427 std::string mlgo_str = R"_(
428 <header>
429 gemm-version, [1,2,1]
430 ip-type,gpu
431 </header>
432 <heuristics-table>
433 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
434 </heuristics-table>
435 <heuristic, 0>
436 b , 0, var, m, ==, num, 10., 1, 2
437 l , 1, gemm-type, reshaped
438 b , 2, var, r_mn, >=, num, 2., 3, 6
439 b , 3, var, n, >=, num, 200., 4, 5
440 l , 4, gemm-type, reshaped-only-rhs
441 l , 5, gemm-type, reshaped
442 l , 6, gemm-type, reshaped-only-rhs
443 </heuristic>
444 )_";
445 std::stringstream ss(mlgo_str);
446 MLGOHeuristics heuristics;
447 ARM_COMPUTE_EXPECT(heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
448
449 // Querying unavailable heuristic type should return invalid Status
450 ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
451 // Querying unavailable ip target should return invalid Status
452 ARM_COMPUTE_EXPECT(!heuristics.query_gemm_type(Query{ "g77", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
453 // Querying unavailable data type should return invalid Status
454 ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g76", DataType::QASYMM8, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
455}
456TEST_SUITE_END() // MLGOHeuristics
457TEST_SUITE_END() // UNIT
458TEST_SUITE_END() // CL
459} // namespace validation
460} // namespace test
461} // namespace arm_compute