blob: ef1f31493ebcb066a09faffe1aeb39a20a42b2fc [file] [log] [blame]
Viet-Hoa Dobb1ab052022-12-23 14:48:33 +00001# Copyright (c) 2019-2020 Arm Limited.
SiCong Lie36b5262019-10-01 19:26:00 +01002#
3# SPDX-License-Identifier: MIT
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to
7# deal in the Software without restriction, including without limitation the
8# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9# sell copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in all
13# copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
22
23#!/usr/bin/python3
24
25import argparse
26import csv
27import json
28import logging
29import math
30import os
31from collections import Counter, defaultdict, deque, namedtuple
32from enum import Enum
33from pathlib import Path
SiCong Li75041a12019-11-05 10:43:06 +000034from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union
SiCong Lie36b5262019-10-01 19:26:00 +010035
36################################################################################
37# Types
38################################################################################
39
40# Gemm strategy
class Strategy(Enum):
    """GEMM strategies that can be tuned; each is benchmarked by its own example binary."""

    Native = 1
    ReshapedOnlyRHS = 2
    Reshaped = 3
42
43# Gemm parameter
Eren Kopuza0bf9132020-06-24 17:29:38 +010044
45
class GEMMParam(NamedTuple):
    """Shape (M, N, K, B) and data type of a single benchmarked GEMM problem."""

    M: int  # Number of lhs matrix rows
    N: int  # Number of rhs matrix columns
    K: int  # Number of lhs matrix columns/rhs matrix rows
    B: int  # Batch size
    data_type: str  # Data type

    @classmethod
    def parse_from_strs(cls, *M_N_K_B, data_type):
        """Build a GEMMParam from the string forms of M, N, K, B plus a data type."""
        dims = [int(value) for value in M_N_K_B]
        return cls(*dims, str(data_type))

    def __str__(self):
        # Comma-separated field values; this string is what identifies the GEMM shape in output JSON
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +000059
SiCong Lie36b5262019-10-01 19:26:00 +010060
61# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
    """Kernel configuration for the Native GEMM strategy."""

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication

    @classmethod
    def parse_from_strs(cls, *args):
        """Build a config from the string forms of m0, n0 and k0."""
        return cls(*[int(arg) for arg in args])

    def __str__(self):
        # Comma-separated field values, in field order
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +000074
SiCong Lie36b5262019-10-01 19:26:00 +010075
76# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
    """Kernel configuration for the ReshapedOnlyRHS GEMM strategy."""

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool
    # Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)
    export_to_cl_image_rhs: bool

    @classmethod
    def parse_from_strs(cls, *args):
        """Build a config from strings: m0, n0, k0, h0, then three flags (a flag is True only when it is 1)."""
        *block_sizes, interleave, transpose, export = [int(arg) for arg in args]
        return cls(
            *block_sizes,
            interleave == 1,
            transpose == 1,
            export == 1,
        )

    def __str__(self):
        # Comma-separated field values, in field order
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +0000100
SiCong Lie36b5262019-10-01 19:26:00 +0100101
102# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
    """Kernel configuration for the Reshaped GEMM strategy."""

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    v0: int  # Number of vertical blocks of size (m0xk0) stored on the same output row
    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
    interleave_lhs: bool  # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool
    # Export rhs matrix to cl_image (1) / Do not export rhs matrix to cl_image (0)
    export_to_cl_image_rhs: bool

    @classmethod
    def parse_from_strs(cls, *args):
        """Build a config from strings: m0, n0, k0, v0, h0, then four flags (a flag is True only when it is 1)."""
        *block_sizes, lhs_interleave, rhs_interleave, rhs_transpose, rhs_export = [
            int(arg) for arg in args
        ]
        return cls(
            *block_sizes,
            lhs_interleave == 1,
            rhs_interleave == 1,
            rhs_transpose == 1,
            rhs_export == 1,
        )

    def __str__(self):
        # Comma-separated field values, in field order
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +0000131
SiCong Lie36b5262019-10-01 19:26:00 +0100132
133# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
    """OpenCL timer results (milliseconds) for one benchmark run.

    The reshape time stays 0 when a benchmark reports no reshape measurement.
    """

    opencl_timer_ms_reshape: float  # Time of the reshape measurement
    opencl_timer_ms_kernel: float  # Time of the GEMM kernel measurement

    def get_total_ms(self):
        """Return the total time (reshape + kernel) in milliseconds."""
        return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel

    def is_close_to(self, other, tol):
        """Return True if the total times of self and other differ by strictly less than tol ms."""
        return math.fabs(self.get_total_ms() - other.get_total_ms()) < tol

    def is_better_than(self, other, tol):
        """Return True if self is faster than other and the two are not within tol ms of each other."""
        # tol must be forwarded: is_close_to has no default for it (omitting it raised TypeError)
        return self.get_total_ms() < other.get_total_ms() and not self.is_close_to(
            other, tol
        )

    # Element-wise arithmetic on (reshape, kernel) pairs. These override tuple's
    # concatenation (+) and repetition (*) semantics.
    def __add__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
        )

    def __sub__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
        )

    def __mul__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
        )

    def __floordiv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
        )

    def __truediv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
        )

    def __pow__(self, power):
        return Measurement(
            self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
        )

    def __str__(self):
        return ",".join(map(str, self))
SiCong Lie36b5262019-10-01 19:26:00 +0100186
187
188# GEMMConfig Type
# Union of all strategy-specific GEMM configuration types
GEMMConfigT = Union[NativeGEMMConfig, ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
SiCong Lie36b5262019-10-01 19:26:00 +0100191
192
193# Representation of the benchmark result from a single experiment
class BenchmarkResult(NamedTuple):
    """One benchmark run: the GEMM problem, the strategy and config used, and the measured time."""

    gemm_param: GEMMParam  # GEMM problem shape and data type
    strategy: Strategy  # Strategy that was benchmarked
    gemm_config: GEMMConfigT  # Strategy-specific kernel configuration
    measurement: Measurement  # OpenCL timer results
198 measurement: Measurement
199
200
class GEMMBenchmarkResultRecorder:
    """A recorder that records and organises GEMM benchmark results, and produces various reports on the record."""

    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])

    def __init__(self, tol=0.01):
        """Initializer.

        Args:
            tol: Tolerance (OpenCL timer, milliseconds) within which two GEMM configs are treated as
                equally fast when selecting the best configs; passed to Measurement.is_close_to.
        """
        # All recorded results, in insertion order
        self._benchmark_result_record: List[BenchmarkResult] = []
        # Strategies recorded
        self._strategies = set()
        self._tol = tol

    def add(self, benchmark_result: BenchmarkResult):
        """Add a benchmark result to the record."""
        # Unpacking also checks that the result has exactly four fields
        _, strategy, _, _ = benchmark_result
        # Update strategies encountered
        self._strategies.add(strategy)
        self._benchmark_result_record.append(benchmark_result)

    def get_record(self) -> Generator[BenchmarkResult, None, None]:
        """Return an iterator that iterates over the record."""
        yield from self._benchmark_result_record

    def get_best_gemm_configs(self):
        """Get the best GEMMConfig set per GEMMParam per Strategy.

        Returns:
            A dict mapping (gemm_param, strategy) to a list of (gemm_config, measurement) pairs.
            The first pair is the fastest config (by total time). It is followed by every other
            config whose total time is within the recorder's tolerance of it, in ascending order
            of total time. Ties keep their recording order.
        """
        best_gc_sets: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        # Group all results first, then sort and filter each group once. The previous version
        # re-sorted the group after every insertion, which was O(n^2 log n) per group.
        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            best_gc_sets[(gemm_param, strategy)].append((gemm_config, measurement))

        for key, gc_set in best_gc_sets.items():
            # sorted() is stable, so equally fast configs keep their recording order
            gc_set = sorted(gc_set, key=lambda gc_and_m: gc_and_m[1].get_total_ms())
            best_gc, best_m = gc_set[0]
            # Keep the best config, plus the configs within tolerance of its measurement
            best_gc_sets[key] = [(best_gc, best_m)] + [
                (gemm_config, measurement)
                for gemm_config, measurement in gc_set[1:]
                if measurement.is_close_to(best_m, self._tol)
            ]

        return best_gc_sets

    def get_best_gemm_configs_as_sequence(self):
        """Get the best GEMMConfig set per GEMMParam per Strategy, and flatten it into a sequence
        of BenchmarkResults.
        """
        for (
            (gemm_param, strategy),
            best_gc_sets,
        ) in self.get_best_gemm_configs().items():
            for best_gemm_config, best_measurement in best_gc_sets:
                yield BenchmarkResult(
                    gemm_param, strategy, best_gemm_config, best_measurement
                )

    def get_config_distributions(self):
        """Return a GEMMConfigDistribution of the best configs for each strategy."""
        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
            GEMMConfigDistribution
        )
        for benchmark_result in self.get_best_gemm_configs_as_sequence():
            _, strategy, _, _ = benchmark_result
            gemm_config_distributions[strategy].add(benchmark_result)

        return gemm_config_distributions

    def get_best_gemm_strategies(self):
        """Get the best Strategy per GEMMParam.

        Returns:
            A dict mapping each GEMMParam to the Strategy of its fastest recorded result (by
            total time). Ties go to the earliest recorded result.
        """
        all_results: Dict[GEMMParam, List[Tuple[Strategy, Measurement]]] = defaultdict(
            list
        )
        for gemm_param, strategy, _, measurement in self.get_record():
            all_results[gemm_param].append((strategy, measurement))

        # min() returns the first of several minimal items, the same as a stable sort's head
        return {
            gemm_param: min(results_set, key=lambda s_and_m: s_and_m[1].get_total_ms())[0]
            for gemm_param, results_set in all_results.items()
        }

    def save_to_jsons(self, out_dir, only_best_config=True):
        """Save records to an output directory of JSON files.

        Each strategy gets its own JSON file. The directory also includes a JSON file that
        defines the best strategy per GEMM Param. Existing files are only overwritten after
        confirmation (see check_out_path).

        Args:
            out_dir: Output directory. It is created, including missing parents, if it does not exist.
            only_best_config: If True, save only the best config set per (GEMMParam, Strategy);
                otherwise save every recorded result.
        """
        if not os.path.exists(out_dir):
            logging.info(
                "Output directory {} does not exist. Creating...".format(
                    out_dir)
            )
            # makedirs also creates missing parent directories, on which os.mkdir would fail
            os.makedirs(out_dir, exist_ok=True)

        out_json_path = os.path.join(out_dir, "gemm_type_selection.json")
        if check_out_path(out_json_path):
            results = self.get_best_gemm_strategies()
            results = {str(key): value.name for key, value in results.items()}
            dump_json(out_json_path, results)

        # Compute the record once, instead of once per strategy
        record = list(
            self.get_best_gemm_configs_as_sequence()
            if only_best_config
            else self.get_record()
        )
        # Sort so the files are prompted for and written in a deterministic order
        for strategy in sorted(self._strategies, key=lambda s: s.value):
            out_json_path = os.path.join(
                out_dir, ("gemm_config_" + strategy.name.lower() + ".json")
            )
            if check_out_path(out_json_path):
                results = defaultdict(list)
                for res in record:
                    if res.strategy == strategy:
                        results[str(res.gemm_param)].append(
                            {
                                "GEMMConfig": str(res.gemm_config),
                                "OpenCL_Timer_ms_reshape": str(
                                    res.measurement.opencl_timer_ms_reshape
                                ),
                                "OpenCL_Timer_ms_kernel": str(
                                    res.measurement.opencl_timer_ms_kernel
                                ),
                            }
                        )
                dump_json(out_json_path, results)

    def summary(self, sum_level=SummaryLevel.Short):
        """Return the summary string of the record.

        Args:
            sum_level: SummaryLevel.Detailed additionally lists the GEMM params of each strategy.
        """
        num_raw_records = sum(1 for _ in self.get_record())
        gemm_params_per_strategy = defaultdict(list)
        for gemm_param, strategy in self.get_best_gemm_configs().keys():
            gemm_params_per_strategy[strategy].append(gemm_param)

        strategies = sorted(self._strategies, key=lambda s: s.value)
        # Build the text line by line. Triple-quoted f-strings leaked source indentation into the output.
        lines = [
            "",
            f"=== {self.__class__.__name__} Summary ===",
            "[Global]",
            "Strategies recorded: {}".format(", ".join(s.name for s in strategies)),
            f"Total number of results recorded: {num_raw_records}",
            "",
            "[Per strategy]",
        ]
        for strategy, gemm_params in gemm_params_per_strategy.items():
            lines += [
                "",
                f"Strategy {strategy.name}:",
                "GEMM parameters:",
                f"    Number of: {len(gemm_params)}",
            ]
            if sum_level == self.__class__.SummaryLevel.Detailed:
                lines.append(f"    Content: {gemm_params}")
        return "\n".join(lines) + "\n"
375
SiCong Lie36b5262019-10-01 19:26:00 +0100376
class GEMMConfigDistribution:
    """A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.

    Every added benchmark result counts as one vote for its GEMM config.
    """

    def __init__(self):
        """Initializer."""
        # GEMM config -> every (GEMMParam, Measurement) recorded with it.
        # (Annotated with GEMMConfigT; the previous annotation named an undefined `GEMMConfig`.)
        self._gemm_config_dist: Dict[
            GEMMConfigT, List[Tuple[GEMMParam, Measurement]]
        ] = defaultdict(list)
        # GEMM config -> number of results (votes) recorded with it
        self._gemm_config_freq = Counter()

    def add(self, benchmark_result: "BenchmarkResult"):
        """Add a benchmark result (gemm_param, strategy, gemm_config, measurement) to the distribution."""
        gemm_param, _, gemm_config, measurement = benchmark_result
        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
        self._gemm_config_freq[gemm_config] += 1

    def distribution(self):
        """Return the mapping from each GEMM config to its list of (GEMMParam, Measurement) pairs."""
        return self._gemm_config_dist

    def frequency(self):
        """Get the frequency of each (best) gemm config recorded.

        Returns:
            A list of (gemm_config, votes) pairs, sorted by descending vote count.
        """
        return self._gemm_config_freq.most_common()

    def best_config(self):
        """Get the overall best config, as voted by all benchmark results.

        Returns:
            A list holding at most one (gemm_config, votes) pair. It is empty if nothing was added.
        """
        return self._gemm_config_freq.most_common(1)

    def std(self):
        """Get the (population) standard deviation of the vote counts, as a measure of dispersion.

        We should aim for higher values: they indicate high variation in the distribution, so
        the evidence for the best config is stronger. Returns 0 if nothing was added.
        """
        freqs = list(self._gemm_config_freq.values())
        if not freqs:
            return 0
        mean_freq = sum(freqs) / len(freqs)
        return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))
418
419
420################################################################################
421# Globals
422################################################################################
423
# Gemm config type factory
# Maps each Strategy to the GEMMConfig type describing its kernel configuration
GEMM_CONFIG_FACTORY = {
    Strategy.Native: NativeGEMMConfig,
    Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
    Strategy.Reshaped: ReshapedGEMMConfig,
}

# Benchmark example binary name -> Strategy (assumed to be a 1-to-1 mapping)
EXAMPLE_FILE_2_STRATEGY = dict(
    benchmark_cl_gemm_native=Strategy.Native,
    benchmark_cl_gemm_reshaped_rhs_only=Strategy.ReshapedOnlyRHS,
    benchmark_cl_gemm_reshaped=Strategy.Reshaped,
)

# Gemm example arguments type factory
# Maps each Strategy to a namedtuple type for the comma-separated "example_args" of its binary.
# The fields are the GEMMParam fields followed by the strategy's GEMMConfig fields, in that order.
# GEMMParam's data type field is left out because it is passed through a separate option.
# For example, the example args of running a reshaped rhs only example could be:
# 100,100,100,1, 4, 4, 4, 1, 1, 1, 0
# M ,N ,K, B,m0,n0,k0,h0,interleave_rhs,transpose_rhs,export_to_cl_image_rhs
# <-GEMMParam-><-------------GEMMConfig--------------------------------------->
# Iterating over the Strategy enum yields canonical members only, so enum aliases are skipped.
GEMM_EXAMPLE_ARGS_FACTORY = {
    strategy: namedtuple(
        "{}_Gemm_Example_Args".format(strategy.name),
        GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields,
    )
    for strategy in Strategy
}

# File extension used for benchmark result json files
BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
462
463################################################################################
464# Functions
465################################################################################
466
467
def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
    """Parse the benchmark example command-line string into a dictionary of command-line arguments.

    The first whitespace-separated token (the program name) is discarded. Every other token is
    expected to look like ``--name=value``; leading dashes are stripped from the name. Only the
    first '=' separates name from value, so values may contain '='. A token without '=' maps to "".

    Example:
        "./bench --example_args=1,2,3,4,--type=F32" -> {"example_args": "1,2,3,4", "type": "F32"}
    """
    # Separate the data type option from the example_args portion of the string
    commandline = commandline.replace(",--type=", " --type=")

    parsed = {}
    # [1:] discards the program name
    for token in commandline.split()[1:]:
        # partition() splits on the first '=' only. The old split("=") raised ValueError for
        # values containing '=' and for tokens without '='.
        name, _, value = token.partition("=")
        # Strip '-'/"--" if it exists
        parsed[name.lstrip("-")] = value
    return parsed
486
487
def extract_benchmark_results(
    json_results: Dict, measurement_method="avg"
) -> Generator[BenchmarkResult, None, None]:
    """Parse the benchmark results and extract the relevant information, namely:
        GEMM param,
        Strategy,
        GEMM config,
        Measurements

    Args:
        json_results: Iterable of parsed benchmark result JSON objects (e.g. as yielded by parse_json).
        measurement_method: "avg" to use the mean of the raw timer samples, "min" to use their minimum.

    Yields:
        One BenchmarkResult per JSON object.

    Raises:
        ValueError: if measurement_method is unknown, or a result does not have the expected layout
            (not exactly one test, a non-OpenCLTimer instrument, or no raw samples).
        KeyError: if an expected JSON field or command-line option is missing, or the example
            binary is not a known strategy.
    """
    # Validate once up front. Previously an invalid method went unnoticed when there were no measurements.
    if measurement_method not in ("min", "avg"):
        raise ValueError(
            "Invalid measurement method: {}".format(measurement_method)
        )

    # Example args hold the GEMMParam fields first, then the GEMMConfig fields. The data type
    # is passed via a separate option, so it is excluded here.
    gemm_param_fields_len = len(GEMMParam._fields) - 1

    for json_res in json_results:
        # Get example test and test data. There should only be 1 test per run.
        # Raise explicitly, because asserts are stripped under `python -O`.
        example_tests = list(json_res["tests"].items())
        if len(example_tests) != 1:
            raise ValueError(
                "Expected exactly 1 test per benchmark result, found {}".format(
                    len(example_tests))
            )
        example_fn, example_test_data = example_tests[0]

        # The example binary's file name identifies the strategy
        strategy = EXAMPLE_FILE_2_STRATEGY[os.path.basename(example_fn)]

        # Get gemm params + gemm configs from example args
        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
        example_args = Gemm_Example_Args_T(
            *(benchmark_args["example_args"].split(",")))
        gemm_param = GEMMParam.parse_from_strs(
            *example_args[:gemm_param_fields_len],
            data_type=benchmark_args["type"])
        gemm_config = GEMM_CONFIG_FACTORY[strategy].parse_from_strs(
            *example_args[gemm_param_fields_len:])

        # Reshape and GEMM kernel timings are reported as separate measurements and kept apart.
        # The reshape time stays 0 when none is reported. If several measurements of one kind
        # are present, the last one wins.
        measurement_ms_reshape = 0
        measurement_ms_kernel = 0
        for measurement_instrument, data in example_test_data["measurements"].items():
            # Instrument keys are "<instrument name>/<measurement type>"
            instrument_name, _, measurement_type = measurement_instrument.partition("/")
            if instrument_name != "OpenCLTimer":
                raise ValueError(
                    "Unexpected measurement instrument: {}".format(
                        measurement_instrument)
                )
            raw = data["raw"]
            if not raw:
                raise ValueError(
                    "No raw measurements for {}".format(measurement_instrument)
                )
            # Take either the minimum or the average of the raw data as the measurement value
            if measurement_method == "min":
                measurement_val = min(raw)
            else:
                measurement_val = sum(raw) / len(raw)

            if "reshape" in measurement_type.split("_"):
                measurement_ms_reshape = measurement_val
            else:
                measurement_ms_kernel = measurement_val

        measurement = Measurement(
            measurement_ms_reshape, measurement_ms_kernel)

        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
556
557
def parse_json(dir_name):
    """Glob all benchmark result json files under dir_name (recursively) and parse them into json objects (dicts).

    Files are visited in sorted path order. rglob's order depends on the filesystem, and the
    order of results can decide ties between equally fast results downstream, so sorting keeps
    repeated runs reproducible.
    """
    pattern = "*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)
    for res_fn in sorted(Path(dir_name).rglob(pattern)):
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform default encoding
        with open(res_fn, encoding="utf-8") as res_fp:
            yield json.load(res_fp)
564
565
def check_out_path(out_path):
    """Decide whether an output JSON file may be written to out_path.

    If out_path already exists, the user is asked on stdin whether to overwrite it. "y" or "yes"
    is accepted, case-insensitively and ignoring surrounding whitespace.

    Returns:
        True if out_path does not exist or the user agreed to overwrite it, False otherwise.
    """
    if os.path.exists(out_path):
        answer = input(
            "Output JSON {} already exists. Overwrite? [Y/N]: ".format(
                out_path)
        )
        # strip() so stray whitespace (e.g. " y") is not treated as a refusal
        if answer.strip().lower() not in ("y", "yes"):
            logging.info("Skipping {}".format(out_path))
            return False
    logging.info("Saving JSON file to {}".format(out_path))
    return True
580
581
def dump_json(out_path, dict):
    """Write the JSON-serialisable object `dict` to out_path (replacing any existing file), then log "Saved".

    The parameter name shadows the builtin `dict`; it is kept so existing callers are unaffected.
    """
    payload = dict
    with open(out_path, "w") as out_file:
        json.dump(payload, out_file)
    logging.info("Saved")
586
587
SiCong Lie36b5262019-10-01 19:26:00 +0100588################################################################################
589# Main
590################################################################################
591
592
def main(args):
    """Find the best GEMM configurations from a directory of benchmark results.

    Reads every benchmark result under args.benchmark_results_dir and records them with
    tolerance args.tolerance. Logs a summary plus the voted best config per strategy. If
    args.output_dir is set, saves JSON files there: every result in debug mode, otherwise
    only the best configs.
    """
    results_dir = args.benchmark_results_dir
    logging.info("Searching best gemm configurations from {}".format(results_dir))

    results = extract_benchmark_results(parse_json(results_dir))

    # Add all benchmark results to the recorder
    recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
    for result in results:
        recorder.add(result)

    levels = GEMMBenchmarkResultRecorder.SummaryLevel
    sum_level = levels.Detailed if args.debug else levels.Short

    # Print overall summary of the recorded results
    logging.info(recorder.summary(sum_level=sum_level))

    # Get GEMM configuration distributions for each strategy
    config_dists = recorder.get_config_distributions()

    logging.info("=== Result ===")
    for strategy, dist in config_dists.items():
        logging.info("Strategy: {}".format(strategy.name))
        logging.debug("GEMM Config, Votes")
        for config, votes in dist.frequency():
            logging.debug("{}, {}".format(config, votes))
        logging.info(
            "Best GEMM Config: {} with std: {}".format(dist.best_config(), dist.std())
        )

    # Save the recorded results to JSON files in output directory
    if args.output_dir is None:
        return
    recorder.save_to_jsons(args.output_dir, only_best_config=not args.debug)
SiCong Lie36b5262019-10-01 19:26:00 +0100637
638
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
    # Long help texts use implicit string-literal concatenation. The previous backslash
    # continuations inside the literals embedded the source indentation into the help text.
    parser.add_argument(
        "-b",
        "--benchmark_results",
        dest="benchmark_results_dir",
        metavar="PATH",
        action="store",
        type=str,
        help=(
            "Path to benchmark result directory, where benchmark result json files have a file "
            "extension of '{}'"
        ).format(BENCHMARK_RESULT_JSON_EXTENSION),
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        dest="output_dir",
        metavar="PATH",
        action="store",
        type=str,
        help=(
            "Path to directory that holds output JSON files. One for strategy selection and one "
            "per strategy for GEMM config selection"
        ),
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        action="store",
        type=float,
        default=0.01,
        help=(
            "For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance "
            "is OpenCL timer in milliseconds. Recommended value: <= 0.1 ms"
        ),
    )
    parser.add_argument(
        "-D",
        "--debug",
        dest="debug",
        action="store_true",
        help="Enable script debugging output",
    )
    args = parser.parse_args()
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)
    logging.debug("Arguments: {}".format(args))
    main(args)