blob: 1361f9092976825f46467588d752bd1cd7f2a8d8 [file] [log] [blame]
Eren Kopuza0bf9132020-06-24 17:29:38 +01001# Copyright (c) 2019-2020 ARM Limited.
SiCong Lie36b5262019-10-01 19:26:00 +01002#
3# SPDX-License-Identifier: MIT
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to
7# deal in the Software without restriction, including without limitation the
8# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9# sell copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in all
13# copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
22
23#!/usr/bin/python3
24
25import argparse
26import csv
27import json
28import logging
29import math
30import os
31from collections import Counter, defaultdict, deque, namedtuple
32from enum import Enum
33from pathlib import Path
SiCong Li75041a12019-11-05 10:43:06 +000034from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union
SiCong Lie36b5262019-10-01 19:26:00 +010035
36################################################################################
37# Types
38################################################################################
39
# Gemm strategy
# Functional Enum API with whitespace-separated member names: values are 1, 2, 3 in order.
Strategy = Enum("Strategy", "Native ReshapedOnlyRHS Reshaped")
42
# Gemm parameter
class GEMMParam(NamedTuple):
    """Shape of a (batched) GEMM problem together with the data type it runs with."""

    M: int  # Number of lhs matrix rows
    N: int  # Number of rhs matrix columns
    K: int  # Number of lhs matrix columns/rhs matrix rows
    B: int  # Batch size
    data_type: str  # Data type

    @staticmethod
    def parse_from_strs(*M_N_K_B, data_type):
        """Build a GEMMParam from M, N, K, B (each converted with int()) and a data type
        (converted with str())."""
        dims = [int(dim) for dim in M_N_K_B]
        return GEMMParam(*dims, str(data_type))

    def __str__(self):
        """Comma-separated field values in declaration order."""
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +000059
SiCong Lie36b5262019-10-01 19:26:00 +010060
# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
    """Block sizes used by the Native GEMM strategy."""

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication

    @staticmethod
    def parse_from_strs(*args):
        """Build a NativeGEMMConfig from int-convertible values m0, n0, k0."""
        block_sizes = [int(arg) for arg in args]
        return NativeGEMMConfig(*block_sizes)

    def __str__(self):
        """Comma-separated field values in declaration order."""
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +000074
SiCong Lie36b5262019-10-01 19:26:00 +010075
# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
    """Block sizes and rhs layout flags used by the Reshaped Only RHS GEMM strategy."""

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool

    @staticmethod
    def parse_from_strs(*args):
        """Build a config from int-convertible values m0, n0, k0, h0, interleave_rhs,
        transpose_rhs. Each of the two trailing flags becomes True only when it equals 1."""
        ints = [int(arg) for arg in args]
        *block_sizes, rhs_interleave, rhs_transpose = ints
        return ReshapedOnlyRHSGEMMConfig(
            *block_sizes, rhs_interleave == 1, rhs_transpose == 1
        )

    def __str__(self):
        """Comma-separated field values in declaration order (flags print as True/False)."""
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +000097
SiCong Lie36b5262019-10-01 19:26:00 +010098
# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
    """Block sizes and lhs/rhs layout flags used by the Reshaped GEMM strategy."""

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of vertical blocks of size (m0xk0) stored on the same output row
    v0: int
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
    interleave_lhs: bool
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool

    @staticmethod
    def parse_from_strs(*args):
        """Build a config from int-convertible values m0, n0, k0, v0, h0, interleave_lhs,
        interleave_rhs, transpose_rhs. Each of the three trailing flags becomes True only
        when it equals 1."""
        ints = [int(arg) for arg in args]
        *block_sizes, lhs_interleave, rhs_interleave, rhs_transpose = ints
        flags = (lhs_interleave == 1, rhs_interleave == 1, rhs_transpose == 1)
        return ReshapedGEMMConfig(*block_sizes, *flags)

    def __str__(self):
        """Comma-separated field values in declaration order (flags print as True/False)."""
        return ",".join(str(field) for field in self)
SiCong Li75041a12019-11-05 10:43:06 +0000125
SiCong Lie36b5262019-10-01 19:26:00 +0100126
# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
    """OpenCL timer measurement of one benchmark run, in milliseconds, split into the time
    spent in reshape kernel(s) and the time spent in the GEMM kernel.

    The arithmetic operators below work element-wise on both fields and return a new
    Measurement.
    """

    opencl_timer_ms_reshape: float  # Reshape time in ms
    opencl_timer_ms_kernel: float  # GEMM kernel time in ms

    def get_total_ms(self):
        """Return the total (reshape + kernel) time in milliseconds."""
        return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel

    def is_close_to(self, other, tol):
        """Return True if the total times of self and other differ by strictly less than tol ms."""
        return math.fabs(self.get_total_ms() - other.get_total_ms()) < tol

    def is_better_than(self, other, tol):
        """Return True if self is faster than other (in total time) by at least tol ms."""
        # tol must be forwarded: is_close_to has no default for it, and omitting it
        # raised a TypeError on every call.
        return self.get_total_ms() < other.get_total_ms() and not self.is_close_to(
            other, tol
        )

    def __add__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
        )

    def __sub__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
        )

    def __mul__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
        )

    def __floordiv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
        )

    def __truediv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
        )

    def __pow__(self, power):
        # power is a scalar applied to both fields
        return Measurement(
            self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
        )

    def __str__(self):
        """Comma-separated field values: "<reshape_ms>,<kernel_ms>"."""
        return ",".join(map(str, self))
SiCong Lie36b5262019-10-01 19:26:00 +0100180
181
# GEMMConfig Type
# Any of the per-strategy GEMM configuration types (GEMM_CONFIG_FACTORY maps each
# Strategy to the concrete one).
GEMMConfigT = Union[NativeGEMMConfig,
                    ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
SiCong Lie36b5262019-10-01 19:26:00 +0100185
186
# Representation of the benchmark result from a single experiment
class BenchmarkResult(NamedTuple):
    """One benchmark run: the GEMM problem, the strategy and config it ran with, and its timing."""

    gemm_param: GEMMParam  # GEMM problem: M, N, K, B and data type
    strategy: Strategy  # Strategy of the example binary that produced this result
    gemm_config: GEMMConfigT  # Strategy-specific configuration the run used
    measurement: Measurement  # OpenCL timer measurement (reshape + kernel)
193
194
class GEMMBenchmarkResultRecorder:
    """A recorder that records and organises GEMM benchmark results, and produces various
    reports on the record.
    """

    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])

    def __init__(self, tol=0.01):
        """Initializer.

        Args:
            tol: Tolerance in milliseconds. Configs whose total time differs from the best
                config's total time by less than this are treated as equally good.
        """
        self._benchmark_result_record: List[BenchmarkResult] = []
        # Strategies recorded
        self._strategies = set()
        self._tol = tol

    def add(self, benchmark_result: BenchmarkResult):
        """Add a benchmark result to the record."""
        _, strategy, _, _ = benchmark_result
        # Update strategies encountered
        self._strategies.add(strategy)
        self._benchmark_result_record.append(benchmark_result)

    def get_record(self) -> Generator[BenchmarkResult, None, None]:
        """Return an iterator that iterates over the record in insertion order."""
        yield from self._benchmark_result_record

    def get_best_gemm_configs(self):
        """Get the best GEMMConfig set per GEMMParam per Strategy.

        Returns:
            A defaultdict(list) keyed by (GEMMParam, Strategy). Each value is a list of
            (GEMMConfig, Measurement) pairs: the fastest config first, then every other config
            whose total time is within the recorder's tolerance of it, ordered by total time.
        """
        # Group all (config, measurement) pairs per (gemm_param, strategy) first...
        grouped: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            grouped[(gemm_param, strategy)].append((gemm_config, measurement))

        # ...then sort and filter each group once, instead of after every insertion.
        best_gc_sets = defaultdict(list)
        for key, gc_set in grouped.items():
            # sorted() is stable: configs with equal time keep their recording order
            gc_set = sorted(gc_set, key=lambda gc_and_m: gc_and_m[1].get_total_ms())
            best_gc, best_m = gc_set[0]
            # Keep the best config plus those within tolerance of its measurement
            best_gc_sets[key] = [(best_gc, best_m)] + [
                (gemm_config, measurement)
                for gemm_config, measurement in gc_set[1:]
                if measurement.is_close_to(best_m, self._tol)
            ]
        return best_gc_sets

    def get_best_gemm_configs_as_sequence(self):
        """Get the best GEMMConfig set per GEMMParam per Strategy, flattened into a sequence
        of BenchmarkResults.
        """
        for (gemm_param, strategy), best_gc_set in self.get_best_gemm_configs().items():
            for best_gemm_config, best_measurement in best_gc_set:
                yield BenchmarkResult(
                    gemm_param, strategy, best_gemm_config, best_measurement
                )

    def get_config_distributions(self):
        """Return a GEMMConfigDistribution of the best configs for each strategy."""
        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
            GEMMConfigDistribution
        )
        for benchmark_result in self.get_best_gemm_configs_as_sequence():
            gemm_config_distributions[benchmark_result.strategy].add(benchmark_result)
        return gemm_config_distributions

    def get_best_gemm_strategies(self):
        """Get the best Strategy per GEMMParam, i.e. the strategy of the fastest recorded run."""
        all_results: Dict[GEMMParam, List[Tuple[Strategy, Measurement]]] = defaultdict(
            list
        )
        for gemm_param, strategy, _, measurement in self.get_record():
            all_results[gemm_param].append((strategy, measurement))

        best_strategies: Dict[GEMMParam, Strategy] = {}
        for gemm_param, results_set in all_results.items():
            # min() returns the first minimum, matching a stable sort's first element
            best_s, _ = min(results_set, key=lambda s_and_m: s_and_m[1].get_total_ms())
            best_strategies[gemm_param] = best_s
        return best_strategies

    def save_to_jsons(self, out_dir, only_best_config=True):
        """Save records to an output directory of JSON files.

        The directory is organized such that each strategy gets its own JSON file
        ("gemm_config_<strategy>.json"). The directory also includes a JSON file
        ("gemm_type_selection.json") defining the best strategy per GEMM Param.
        Existing files are only overwritten if the user confirms (see check_out_path).

        Args:
            out_dir: Output directory; created (including parents) if missing.
            only_best_config: Export only the best config set per GEMMParam instead of
                every recorded result.
        """
        if not os.path.exists(out_dir):
            logging.info(
                "Output directory {} does not exist. Creating...".format(out_dir)
            )
            # makedirs also creates missing parent directories
            os.makedirs(out_dir, exist_ok=True)

        out_json_path = os.path.join(out_dir, "gemm_type_selection.json")
        if check_out_path(out_json_path):
            results = {
                str(gemm_param): strategy.name
                for gemm_param, strategy in self.get_best_gemm_strategies().items()
            }
            dump_json(out_json_path, results)

        # The exported record does not depend on the strategy, so build it only once.
        record = (
            list(self.get_best_gemm_configs_as_sequence())
            if only_best_config
            else self._benchmark_result_record
        )
        for strategy in self._strategies:
            out_json_path = os.path.join(
                out_dir, ("gemm_config_" + strategy.name.lower() + ".json")
            )
            if check_out_path(out_json_path):
                results = defaultdict(list)
                for res in record:
                    if res.strategy == strategy:
                        results[str(res.gemm_param)].append(
                            {
                                "GEMMConfig": str(res.gemm_config),
                                "OpenCL_Timer_ms_reshape": str(
                                    res.measurement.opencl_timer_ms_reshape
                                ),
                                "OpenCL_Timer_ms_kernel": str(
                                    res.measurement.opencl_timer_ms_kernel
                                ),
                            }
                        )
                dump_json(out_json_path, results)

    def summary(self, sum_level=SummaryLevel.Short):
        """Return the summary string of the record.

        Args:
            sum_level: SummaryLevel.Detailed additionally lists the GEMM params per strategy.
        """
        num_raw_records = sum(1 for _ in self.get_record())
        gemm_params_per_strategy = defaultdict(list)
        for gemm_param, strategy in self.get_best_gemm_configs().keys():
            gemm_params_per_strategy[strategy].append(gemm_param)
        strategy_names = ", ".join(s.name for s in self._strategies)
        global_summary = f"""
=== {self.__class__.__name__} Summary ===
[Global]
Strategies recorded: {strategy_names}
Total number of results recorded: {num_raw_records}

[Per strategy]
"""
        strategy_summaries = []
        for strategy, gemm_params in gemm_params_per_strategy.items():
            summary = f"""
Strategy {strategy.name}:
GEMM parameters:
    Number of: {len(gemm_params)}
"""
            if sum_level == self.__class__.SummaryLevel.Detailed:
                summary += f"""
    Content: {gemm_params}
"""
            strategy_summaries.append(summary)
        return global_summary + "".join(strategy_summaries)
369
SiCong Lie36b5262019-10-01 19:26:00 +0100370
class GEMMConfigDistribution:
    """Distribution of the (best) GEMM configurations produced by GEMMBenchmarkResultRecorder.

    For every GEMMConfig it keeps the (GEMMParam, Measurement) pairs it was recorded for, and a
    vote count equal to the number of such pairs.
    """

    def __init__(self):
        """Create an empty distribution."""
        self._gemm_config_dist: Dict[
            GEMMConfigT, List[Tuple[GEMMParam, Measurement]]
        ] = defaultdict(list)
        self._gemm_config_freq = Counter()

    def add(self, benchmark_result: BenchmarkResult):
        """File a benchmark result under its GEMMConfig and count one vote for that config."""
        gemm_param, _, gemm_config, measurement = benchmark_result
        self._gemm_config_freq[gemm_config] += 1
        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))

    def distribution(self):
        """Return the mapping GEMMConfig -> list of (GEMMParam, Measurement)."""
        return self._gemm_config_dist

    def frequency(self):
        """Return (GEMMConfig, votes) pairs for every recorded config, most voted first."""
        return self._gemm_config_freq.most_common()

    def best_config(self):
        """Return a list holding the single most voted (GEMMConfig, votes) pair
        (empty if nothing was recorded)."""
        return self._gemm_config_freq.most_common(1)

    def std(self):
        """Population standard deviation of the vote counts, as a measure of dispersion.

        Higher values indicate more variation in the distribution, and thus stronger evidence
        for the best config. Returns 0 for an empty distribution.
        """
        counts = self._gemm_config_freq.values()
        num_configs = len(counts)
        if num_configs == 0:
            return 0
        mean_count = sum(counts) / num_configs
        squared_dev = sum((count - mean_count) ** 2 for count in counts)
        return math.sqrt(squared_dev / num_configs)
412
413
414################################################################################
415# Globals
416################################################################################
417
# Gemm config type factory
# Produces a GEMMConfig type specific to a Strategy, e.g.
# GEMM_CONFIG_FACTORY[Strategy.Native].parse_from_strs("4", "4", "2") == NativeGEMMConfig(4, 4, 2)
GEMM_CONFIG_FACTORY = {
    Strategy.Native: NativeGEMMConfig,
    Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
    Strategy.Reshaped: ReshapedGEMMConfig,
}
425
# Mapping from example binary name to Strategy
# Assume 1-to-1 mapping. Keys are bare file names: extract_benchmark_results strips the
# directory part of the test name before the lookup.
EXAMPLE_FILE_2_STRATEGY = {
    "benchmark_cl_gemm_native": Strategy.Native,
    "benchmark_cl_gemm_reshaped_rhs_only": Strategy.ReshapedOnlyRHS,
    "benchmark_cl_gemm_reshaped": Strategy.Reshaped,
}
433
# Gemm example arguments type factory
# Produces a Gemm_Example_Args type specific to a Strategy
# Gemm example arguments consist of:
# GEMMParam + GEMMConfig
# in that order.
# For example, the example args of running a reshaped rhs only example could be:
# 100,100,100,1, 4, 4, 4, 1, 1, 1
# M ,N ,K, B,m0,n0,k0,h0,interleave_rhs,transpose_rhs
# <-GEMMParam-><-------------GEMMConfig-------------->
# Note that the test strategy_name == strategy.name is in place to avoid unwanted enum aliases
GEMM_EXAMPLE_ARGS_FACTORY = {
    # We ignore the data type field from GEMMParam as that is extracted separately
    # (from the --type= option); GEMMParam._fields[:-1] drops the trailing data_type field.
    strategy: namedtuple(
        "{}_Gemm_Example_Args".format(strategy_name),
        GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields,
    )
    for strategy_name, strategy in Strategy.__members__.items()
    if strategy_name == strategy.name
}
453
# File extension used for benchmark result json files
# (without the leading dot; parse_json globs recursively for "*.<extension>")
BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
456
457################################################################################
458# Functions
459################################################################################
460
461
def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
    """Parse the benchmark example command-line string into a dictionary of command-line arguments.

    The first whitespace-separated token (the program name) is discarded. Every remaining
    token is split at its FIRST "=" into name and value, so values may themselves contain
    "=". Leading "-"/"--" are stripped from names. A token without "=" maps to "".

    The data type option is appended to the comma-separated example args as ",--type=...".
    It is split off into its own "type" entry.

    Args:
        commandline: The full command line of the benchmark run.

    Returns:
        Mapping from argument name (without dashes) to its string value.
    """
    # Separate the data type option from the example_args portion of the string
    commandline = commandline.replace(",--type=", " --type=")

    parsed: Dict[str, str] = {}
    # Discard program name
    for arg in commandline.split()[1:]:
        # partition splits only at the first "=" (split("=") broke on values containing "=")
        name, _, val = arg.partition("=")
        # Strip '-'/"--" if it exists
        parsed[name.lstrip("-")] = val
    return parsed
480
481
def extract_benchmark_results(
    json_results: Dict, measurement_method="avg"
) -> Generator[BenchmarkResult, None, None]:
    """Parse the benchmark results and extract the relevant information, namely:
        GEMM param,
        Strategy,
        GEMM config,
        Measurements

    Args:
        json_results: Iterable of parsed benchmark result JSON objects, e.g. as yielded by
            parse_json (note: iterated as a sequence of dicts despite the Dict annotation).
        measurement_method: How the raw timer samples of each measurement are reduced to one
            value: "min" or "avg".

    Yields:
        One BenchmarkResult per JSON object.

    Raises:
        ValueError: If measurement_method is invalid (raised when iteration starts), if a
            result does not contain exactly one test, or if a measurement was not taken by
            the OpenCLTimer instrument.
        KeyError: If the example binary is not in EXAMPLE_FILE_2_STRATEGY.
    """
    # Validate the reduction method once, up front, rather than per measurement
    raw_reducers = {
        "min": min,
        "avg": lambda raw: sum(raw) / len(raw),
    }
    if measurement_method not in raw_reducers:
        raise ValueError("Invalid measurement method: {}".format(measurement_method))
    reduce_raw = raw_reducers[measurement_method]

    for json_res in json_results:
        # Get example test and test data.
        # There should only be 1 test per run (explicit check: asserts vanish under -O)
        example_tests = list(json_res["tests"].items())
        if len(example_tests) != 1:
            raise ValueError(
                "Expected exactly 1 test per benchmark run, got {}".format(
                    len(example_tests)
                )
            )
        example_fn, example_test_data = example_tests[0]

        # Keep only the example binary's file name, which identifies the strategy
        example_fn = example_fn.split(os.path.sep)[-1]
        strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]

        # Get gemm params + gemm configs from example args
        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
        example_args = Gemm_Example_Args_T(*(benchmark_args["example_args"].split(",")))
        # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order).
        # The data type is passed as a separate --type option, hence the -1 on the field count.
        gemm_param_fields_len = len(GEMMParam._fields) - 1
        gemm_param = GEMMParam.parse_from_strs(
            *example_args[:gemm_param_fields_len], data_type=benchmark_args["type"]
        )
        gemm_config_t = GEMM_CONFIG_FACTORY[strategy]
        gemm_config = gemm_config_t.parse_from_strs(*example_args[gemm_param_fields_len:])

        # Get OpenCL_Time_Ms stats. Reshape kernel time (e.g. the extra reshape measurement of
        # reshaped RHS only runs) is kept separate from the GEMM kernel time; the two are summed
        # by Measurement.get_total_ms(). Runs without a reshape measurement keep 0 for it.
        measurement_ms_reshape = 0
        measurement_ms_kernel = 0
        for measurement_instrument, data in example_test_data["measurements"].items():
            # Instrument keys are expected as "<instrument name>/<measurement type>"
            instrument_parts = measurement_instrument.split("/")
            if instrument_parts[0] != "OpenCLTimer":
                raise ValueError(
                    "Unexpected measurement instrument: {}".format(measurement_instrument)
                )
            # Take either the minimum or the average of the raw data as the measurement value
            measurement_val = reduce_raw(data["raw"])

            measurement_type = instrument_parts[1] if len(instrument_parts) > 1 else ""
            # NOTE: if several measurements of the same kind exist, the last one wins
            if "reshape" in measurement_type.split("_"):
                measurement_ms_reshape = measurement_val
            else:
                measurement_ms_kernel = measurement_val

        measurement = Measurement(measurement_ms_reshape, measurement_ms_kernel)

        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
550
551
def parse_json(dir_name):
    """Yield the parsed JSON object (dict) of every benchmark result file under dir_name.

    Files are found recursively by their BENCHMARK_RESULT_JSON_EXTENSION extension.
    """
    result_glob = f"*.{BENCHMARK_RESULT_JSON_EXTENSION}"
    for result_path in Path(dir_name).rglob(result_glob):
        with open(result_path) as result_file:
            yield json.load(result_file)
558
559
def check_out_path(out_path):
    """Return whether the caller may write the output file out_path.

    A path that does not exist yet may always be written. For an existing file the user is
    asked on stdin; only an answer of "y" (case-insensitive) allows overwriting it.
    """
    prompt = "Output JSON {} already exists. Overwrite? [Y/N]: ".format(out_path)
    # input() is only reached when the file already exists (short-circuit)
    if os.path.exists(out_path) and input(prompt).lower() != "y":
        logging.info("Skipping {}".format(out_path))
        return False
    logging.info("Saving JSON file to {}".format(out_path))
    return True
574
575
def dump_json(out_path, dict):
    """Write dict to out_path as JSON (overwriting the file), then log "Saved"."""
    with open(out_path, "w") as out_file:
        json.dump(dict, out_file)
    logging.info("Saved")
580
581
SiCong Lie36b5262019-10-01 19:26:00 +0100582################################################################################
583# Main
584################################################################################
585
586
def main(args):
    """Search the best GEMM configurations from a directory of benchmark results.

    Loads every benchmark result, logs a summary and the per-strategy config votes, and
    optionally saves the results to JSON files.

    Args:
        args: Parsed command-line namespace; reads benchmark_results_dir, tolerance, debug
            and output_dir.
    """
    results_dir = args.benchmark_results_dir
    logging.info("Searching best gemm configurations from {}".format(results_dir))

    parsed_results = extract_benchmark_results(parse_json(results_dir))

    # Add all benchmark results to the recorder
    recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
    for result in parsed_results:
        recorder.add(result)

    # Debug runs get the detailed summary
    levels = GEMMBenchmarkResultRecorder.SummaryLevel
    sum_level = levels.Detailed if args.debug else levels.Short
    logging.info(recorder.summary(sum_level=sum_level))

    # GEMM configuration distributions, one per strategy
    config_dists = recorder.get_config_distributions()

    logging.info("=== Result ===")
    for strategy, config_dist in config_dists.items():
        logging.info("Strategy: {}".format(strategy.name))
        logging.debug("GEMM Config, Votes")
        for config, votes in config_dist.frequency():
            logging.debug("{}, {}".format(config, votes))
        best_msg = "Best GEMM Config: {} with std: {}".format(
            config_dist.best_config(), config_dist.std()
        )
        logging.info(best_msg)

    # Save the recorded results to JSON files in output directory; debug runs export
    # every recorded result instead of only the best configs.
    if args.output_dir is not None:
        recorder.save_to_jsons(args.output_dir, only_best_config=not args.debug)
SiCong Lie36b5262019-10-01 19:26:00 +0100631
632
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
    # Help texts use implicit string concatenation rather than a backslash continuation
    # inside the literal, which leaked source indentation (or dropped a space) into --help.
    parser.add_argument(
        "-b",
        "--benchmark_results",
        dest="benchmark_results_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to benchmark result directory, where benchmark result json files have a "
        "file extension of '{}'".format(BENCHMARK_RESULT_JSON_EXTENSION),
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        dest="output_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to directory that holds output JSON files. One for strategy selection and "
        "one per strategy for GEMM config selection",
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        action="store",
        type=float,
        default=0.01,
        help="For testing if two GEMMConfigs are equivalent in terms of performance. The "
        "tolerance is OpenCL timer in milliseconds. Recommended value: <= 0.1 ms",
    )
    parser.add_argument(
        "-D",
        "--debug",
        dest="debug",
        action="store_true",
        help="Enable script debugging output",
    )
    args = parser.parse_args()
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)
    logging.debug("Arguments: {}".format(args))
    main(args)