SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 1 | # Copyright (c) 2019 ARM Limited. |
| 2 | # |
| 3 | # SPDX-License-Identifier: MIT |
| 4 | # |
| 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy |
| 6 | # of this software and associated documentation files (the "Software"), to |
| 7 | # deal in the Software without restriction, including without limitation the |
| 8 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 9 | # sell copies of the Software, and to permit persons to whom the Software is |
| 10 | # furnished to do so, subject to the following conditions: |
| 11 | # |
| 12 | # The above copyright notice and this permission notice shall be included in all |
| 13 | # copies or substantial portions of the Software. |
| 14 | # |
| 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 21 | # SOFTWARE. |
| 22 | |
| 23 | #!/usr/bin/python3 |
| 24 | |
| 25 | import argparse |
| 26 | import csv |
| 27 | import json |
| 28 | import logging |
| 29 | import math |
| 30 | import os |
| 31 | from collections import Counter, defaultdict, deque, namedtuple |
| 32 | from enum import Enum |
| 33 | from pathlib import Path |
SiCong Li | 75041a1 | 2019-11-05 10:43:06 +0000 | [diff] [blame] | 34 | from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 35 | |
| 36 | ################################################################################ |
| 37 | # Types |
| 38 | ################################################################################ |
| 39 | |
# Gemm strategy.
# Members are created by the Enum functional API and numbered 1, 2, 3 in declaration order.
Strategy = Enum("Strategy", "Native ReshapedOnlyRHS Reshaped")
| 42 | |
# Gemm parameter
class GEMMParam(NamedTuple):
    """ Size of one GEMM problem: an (M x K) lhs matrix times a (K x N) rhs matrix, with batch size B. """

    M: int  # Number of lhs matrix rows
    N: int  # Number of rhs matrix columns
    K: int  # Number of lhs matrix columns/rhs matrix rows
    B: int  # Batch size

    @staticmethod
    def parse_from_strs(*args):
        """ Convert each argument to int and build a GEMMParam from them, in M, N, K, B order. """
        values = [int(arg) for arg in args]
        return GEMMParam(*values)

    def __str__(self):
        """ Dash-separated form of the fields, e.g. "1225-32-192-1". """
        return "-".join(str(dim) for dim in self)
| 56 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 57 | |
# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
    """ Block sizes used by the Native GEMM strategy. """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication

    @staticmethod
    def parse_from_strs(*args):
        """ Convert every argument to int and build a NativeGEMMConfig from them, in m0, n0, k0 order. """
        return NativeGEMMConfig(*(int(arg) for arg in args))

    def __str__(self):
        """ Dash-separated form of the fields, e.g. "4-4-8". """
        return "-".join(str(block) for block in self)
| 71 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 72 | |
# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    transpose_rhs: bool  # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)

    @staticmethod
    def parse_from_strs(*args):
        """ Build a config from m0, n0, k0, h0, interleave_rhs, transpose_rhs given as strings (or ints).

        The last two values are flags: 1 becomes True, any other integer becomes False.
        """
        *mnkh, interleave_rhs, transpose_rhs = map(int, args)
        interleave_rhs = interleave_rhs == 1
        transpose_rhs = transpose_rhs == 1
        return ReshapedOnlyRHSGEMMConfig(*mnkh, interleave_rhs, transpose_rhs)

    def __str__(self):
        """ Return the dash-separated form, e.g. "4-4-4-1-1-0".

        Boolean fields are written as 1/0 (not True/False). This matches the example-args encoding, so
        parse_from_strs(*str(config).split("-")) reproduces the config.
        """
        return "-".join(str(int(field)) if isinstance(field, bool) else str(field) for field in self)
| 91 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 92 | |
# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    v0: int  # Number of vertical blocks of size (m0xk0) stored on the same output row
    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
    interleave_lhs: bool  # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    transpose_rhs: bool  # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)

    @staticmethod
    def parse_from_strs(*args):
        """ Build a config from m0, n0, k0, v0, h0, interleave_lhs, interleave_rhs, transpose_rhs given as strings
        (or ints).

        The last three values are flags: 1 becomes True, any other integer becomes False.
        """
        *mnkvh, interleave_lhs, interleave_rhs, transpose_rhs = map(int, args)
        interleave_lhs = interleave_lhs == 1
        interleave_rhs = interleave_rhs == 1
        transpose_rhs = transpose_rhs == 1
        return ReshapedGEMMConfig(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs)

    def __str__(self):
        """ Return the dash-separated form, e.g. "4-4-4-3-1-1-1-0".

        Boolean fields are written as 1/0 (not True/False). This matches the example-args encoding, so
        parse_from_strs(*str(config).split("-")) reproduces the config.
        """
        return "-".join(str(int(field)) if isinstance(field, bool) else str(field) for field in self)
| 114 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 115 | |
# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
    """ A single measurement taken from a benchmark result.

    The binary arithmetic operators combine the opencl_timer_ms values of two Measurements and return a new
    Measurement; ** raises the value to a plain number.
    """

    opencl_timer_ms: float  # OpenCL timer value, in milliseconds

    def is_close_to(self, other, tol):
        """ Return True if the two timer values differ by strictly less than tol (milliseconds). """
        return math.fabs(self.opencl_timer_ms - other.opencl_timer_ms) < tol

    def is_better_than(self, other, tol):
        """ Return True if self is smaller (faster) than other and not within tol of it. """
        # tol must be forwarded: is_close_to has no default tolerance
        return self < other and not self.is_close_to(other, tol)

    def __add__(self, other):
        return Measurement(self.opencl_timer_ms + other.opencl_timer_ms)

    def __sub__(self, other):
        return Measurement(self.opencl_timer_ms - other.opencl_timer_ms)

    def __mul__(self, other):
        return Measurement(self.opencl_timer_ms * other.opencl_timer_ms)

    def __floordiv__(self, other):
        return Measurement(self.opencl_timer_ms // other.opencl_timer_ms)

    def __truediv__(self, other):
        return Measurement(self.opencl_timer_ms / other.opencl_timer_ms)

    def __pow__(self, power):
        return Measurement(self.opencl_timer_ms ** power)
| 143 | |
| 144 | |
# GEMMConfig Type: any one of the per-strategy GEMM configuration types
GEMMConfigT = Union[
    NativeGEMMConfig,
    ReshapedOnlyRHSGEMMConfig,
    ReshapedGEMMConfig,
]
| 147 | |
| 148 | |
class BenchmarkResult(NamedTuple):
    """ Representation of the benchmark result from a single experiment: one GEMM problem, run with one strategy
    and one GEMM configuration, together with the measurement taken.
    """

    gemm_param: GEMMParam  # GEMM problem size
    strategy: Strategy  # Strategy of the benchmarked example
    gemm_config: GEMMConfigT  # GEMM configuration used for the run
    measurement: Measurement  # Measurement taken from the benchmark result
| 155 | |
| 156 | |
# Representation of a single row of BenchmarkResult in CSV.
# NOTE: In the CSV representation, all fields of the GEMM config are merged into the single field "GEMMConfig"
# (its string form), while the fields of GEMMParam and Measurement are kept as separate columns.
# An example entry, including the header, would look like:
#   M   , N , K  , B, Strategy , GEMMConfig      , opencl_timer_ms
#   1225, 32, 192, 1, Reshaped , 4-4-4-3-1-1-1-0 , 0.3309
BenchmarkResultCSVRow = namedtuple(
    "BenchmarkResultCSVRow",
    [*GEMMParam._fields, "Strategy", "GEMMConfig", *Measurement._fields],
)
| 166 | |
| 167 | |
def benchmark_result_2_csv_row(result: BenchmarkResult) -> BenchmarkResultCSVRow:
    """ Flatten a BenchmarkResult into a BenchmarkResultCSVRow.

    The GEMMParam and Measurement fields are copied as they are, the strategy is replaced by its name and the GEMM
    config by its string form.
    """
    columns = (
        *result.gemm_param,
        result.strategy.name,
        str(result.gemm_config),
        *result.measurement,
    )
    return BenchmarkResultCSVRow(*columns)
| 173 | |
| 174 | |
class GEMMBenchmarkResultRecorder:
    """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
    """

    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])

    def __init__(self, tol=0.01):
        """ Initializer

        :param tol: Tolerance used when selecting the best GEMM configs. Any config whose measurement is strictly
                    within tol of the best measurement (see Measurement.is_close_to) is kept as equally good.
        """
        self._benchmark_result_record: List[BenchmarkResult] = []
        # Strategies recorded
        self._strategies = set()
        self._tol = tol

    def add(self, benchmark_result: BenchmarkResult):
        """ Add a benchmark result to the record.
        """
        _, strategy, _, _ = benchmark_result
        # Update strategies encountered
        self._strategies.add(strategy)

        self._benchmark_result_record.append(benchmark_result)

    def get_record(self) -> Generator[BenchmarkResult, None, None]:
        """ Return an iterator that iterates over the record.
        """
        yield from self._benchmark_result_record

    def get_best_gemm_configs(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy.

        :return: A dict mapping (GEMMParam, Strategy) to a list of (GEMMConfig, Measurement) pairs. Each list
                 starts with the best (smallest) measurement. It is followed, in ascending order of measurement, by
                 every other config whose measurement is within the recorder's tolerance of the best one. Equal
                 measurements keep their recording order.
        """
        # Group every recorded (GEMMConfig, Measurement) by (GEMMParam, Strategy) first, so that each group is
        # sorted and filtered once rather than after every single insertion
        grouped: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            grouped[(gemm_param, strategy)].append((gemm_config, measurement))

        best_gc_sets: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for key, gc_and_ms in grouped.items():
            # Stable sort by measurement: ties keep their recording order
            gc_and_ms.sort(key=lambda gc_and_m: gc_and_m[1])
            best_gc, best_m = gc_and_ms[0]
            # Keep the best config, plus the configs within tolerance of the best GEMMConfig's measurement
            best_gc_sets[key] = [(best_gc, best_m)] + [
                (gemm_config, measurement)
                for gemm_config, measurement in gc_and_ms[1:]
                if measurement.is_close_to(best_m, self._tol)
            ]

        return best_gc_sets

    def get_best_gemm_configs_as_sequence(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
        of BenchmarkResults
        """
        for (gemm_param, strategy), best_gc_set in self.get_best_gemm_configs().items():
            for best_gemm_config, best_measurement in best_gc_set:
                yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement)

    def get_config_distributions(self):
        """ Return GEMMConfigDistribution for each strategy, built from the best GEMM configs only
        """
        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
            GEMMConfigDistribution
        )
        for benchmark_result in self.get_best_gemm_configs_as_sequence():
            _, strategy, _, _ = benchmark_result
            gemm_config_distributions[strategy].add(benchmark_result)

        return gemm_config_distributions

    def save_to_csvs(self, out_dir, only_best_config=True):
        """ Save records to an output directory of csv files.
        The directory is organized such that each strategy gets its own CSV file, named after the strategy.
        The user is asked before an existing file is overwritten.

        :param out_dir: Output directory. It is created, including missing parent directories, if it does not exist
        :param only_best_config: Save only the best GEMMConfig set per GEMMParam per Strategy (True) or every
                                 recorded result (False)
        """
        if not os.path.exists(out_dir):
            logging.info("Output directory {} does not exist. Creating...".format(out_dir))
            # makedirs rather than mkdir, so that missing parent directories are created as well
            os.makedirs(out_dir)
        # Compute the record once for all strategies rather than once per strategy
        record = (
            list(self.get_best_gemm_configs_as_sequence())
            if only_best_config
            else self._benchmark_result_record
        )
        for strategy in self._strategies:
            out_csv_path = os.path.join(out_dir, strategy.name)
            if os.path.exists(out_csv_path):
                overwrite = (
                    input(
                        "Output CSV {} already exists. Overwrite? [Y/N]: ".format(out_csv_path)
                    ).lower()
                    == "y"
                )
                if not overwrite:
                    logging.info("Skipping {}".format(out_csv_path))
                    continue
            logging.info("Saving csv file to {}".format(out_csv_path))
            # newline="" as required by the csv module; otherwise each row ends in "\r\r\n" on Windows
            with open(out_csv_path, "w", newline="") as f:
                csv_writer = csv.DictWriter(f, fieldnames=BenchmarkResultCSVRow._fields)
                csv_writer.writeheader()
                csv_writer.writerows(
                    benchmark_result_2_csv_row(res)._asdict()
                    for res in record
                    if res.strategy == strategy
                )
            logging.info("Saved")

    def summary(self, sum_level=SummaryLevel.Short):
        """ Return the summary string of the record

        :param sum_level: SummaryLevel.Detailed additionally lists the GEMMParams recorded for each strategy
        """
        num_raw_records = len(self._benchmark_result_record)
        gemm_params_per_strategy = defaultdict(list)
        for gemm_param, strategy in self.get_best_gemm_configs().keys():
            gemm_params_per_strategy[strategy].append(gemm_param)

        # Built line by line, not from indented triple-quoted strings, so that the output does not depend on the
        # indentation of this source file
        lines = [
            "",
            f"=== {self.__class__.__name__} Summary ===",
            "[Global]",
            f"Strategies recorded: {', '.join(s.name for s in self._strategies)}",
            f"Total number of results recorded: {num_raw_records}",
            "",
            "[Per strategy]",
        ]
        for strategy, gemm_params in gemm_params_per_strategy.items():
            lines.append("")
            lines.append(f"Strategy {strategy.name}:")
            lines.append("GEMM parameters:")
            lines.append(f"    Number of: {len(gemm_params)}")
            if sum_level == self.__class__.SummaryLevel.Detailed:
                lines.append(f"    Content: {gemm_params}")
        return "\n".join(lines) + "\n"
| 308 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 309 | |
class GEMMConfigDistribution:
    """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.

    Every benchmark result added counts as one vote for its GEMMConfig.
    """

    def __init__(self):
        """ Initializer
        """
        # GEMMConfig -> list of (GEMMParam, Measurement) pairs added for that config
        self._gemm_config_dist: Dict[GEMMConfigT, List[Tuple[GEMMParam, Measurement]]] = defaultdict(
            list
        )
        # GEMMConfig -> number of benchmark results added for that config (its votes)
        self._gemm_config_freq = Counter()

    def add(self, benchmark_result: BenchmarkResult):
        """ Add a benchmark result to the distribution, counting it as one vote for its GEMMConfig
        """
        gemm_param, _, gemm_config, measurement = benchmark_result
        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
        self._gemm_config_freq[gemm_config] += 1

    def distribution(self):
        """ Get the mapping from each GEMMConfig to the (GEMMParam, Measurement) pairs added for it
        """
        return self._gemm_config_dist

    def frequency(self):
        """ Get the frequency of each (best) gemm config recorded, as (GEMMConfig, count) pairs ordered from the
        most to the least frequent
        """
        return self._gemm_config_freq.most_common()

    def best_config(self):
        """ Get the overall best config, as voted by all benchmark results.

        :return: A list holding at most one (GEMMConfig, count) pair; empty if nothing has been added
        """
        return self._gemm_config_freq.most_common(1)

    def std(self):
        """ Get the standard deviation as a measure of dispersion of the distribution. We should aim for higher values
        as they indicate there is high variation in the distribution. Thus the evidence of the best config is stronger.

        :return: The population standard deviation of the per-config vote counts, or 0 if nothing has been added
        """
        freqs = self._gemm_config_freq.values()
        if len(freqs) == 0:
            return 0
        mean_freq = sum(freqs) / len(freqs)
        # Population standard deviation: divide by the number of configs, not by (number - 1)
        return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))
| 351 | |
| 352 | |
| 353 | ################################################################################ |
| 354 | # Globals |
| 355 | ################################################################################ |
| 356 | |
# Gemm config type factory:
# maps each Strategy to the GEMMConfig type that holds that strategy's configuration
GEMM_CONFIG_FACTORY = dict(
    [
        (Strategy.Native, NativeGEMMConfig),
        (Strategy.ReshapedOnlyRHS, ReshapedOnlyRHSGEMMConfig),
        (Strategy.Reshaped, ReshapedGEMMConfig),
    ]
)

# Mapping from example binary name to the Strategy it benchmarks.
# Assumed to be a 1-to-1 mapping
EXAMPLE_FILE_2_STRATEGY = dict(
    [
        ("benchmark_cl_gemm_native", Strategy.Native),
        ("benchmark_cl_gemm_reshaped_rhs_only", Strategy.ReshapedOnlyRHS),
        ("benchmark_cl_gemm_reshaped", Strategy.Reshaped),
    ]
)
| 372 | |
# Gemm example arguments type factory.
# Produces a Gemm_Example_Args namedtuple type specific to a Strategy.
# Gemm example arguments consist of:
#   GEMMParam + GEMMConfig
# in that order.
# For example, the example args of running a reshaped rhs only example could be:
#   100,100,100,1, 4, 4, 4, 1,             1,            1
#   M  ,N  ,K  ,B,m0,n0,k0,h0,interleave_rhs,transpose_rhs
#   <-GEMMParam-><-------------GEMMConfig---------------->
# Strategy.__members__ also lists enum aliases under their alias name. Keeping only the entries whose name equals
# the member's own name gives each canonical Strategy exactly one entry.
GEMM_EXAMPLE_ARGS_FACTORY = {
    member: namedtuple(
        f"{name}_Gemm_Example_Args",
        GEMMParam._fields + GEMM_CONFIG_FACTORY[member]._fields,
    )
    for name, member in Strategy.__members__.items()
    if member.name == name
}

# File extension used for benchmark result json files
BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
| 394 | |
| 395 | ################################################################################ |
| 396 | # Functions |
| 397 | ################################################################################ |
| 398 | |
| 399 | |
def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
    """ Parse the benchmark example command-line string into a dictionary of command-line arguments

    The first whitespace-separated token (the program name) is discarded. Every remaining token is split on its
    first "=" into (name, value), and leading "-" characters are stripped from the name. A value may itself contain
    "=", and a token without "=" maps to an empty string. A repeated name keeps its last value.

    Example:
        "prog --example_args=1,2,3 -i=5" -> {"example_args": "1,2,3", "i": "5"}
    """
    parsed = {}
    # Discard program name
    for arg in commandline.split()[1:]:
        # Split on the first "=" only, so that values containing "=" stay intact
        name, _, val = arg.partition("=")
        # Strip '-'/"--" if it exists
        parsed[name.lstrip("-")] = val
    return parsed
| 415 | |
| 416 | |
def extract_benchmark_results(
    json_results: Dict, measurement_method="avg"
) -> Generator[BenchmarkResult, None, None]:
    """ Parse the benchmark result and extract relevant information, namely:
        GEMM param,
        Strategy,
        GEMM config,
        Measurements

    :param json_results: The parsed benchmark result json objects (dicts) to extract from, e.g. as yielded by
                         parse_json. Each one is expected to hold exactly 1 test with exactly 1 OpenCLTimer
                         instrument.
    :param measurement_method: How the raw OpenCLTimer data of a test is reduced to a single value:
                               "min" (minimum) or "avg" (arithmetic mean)
    :return: A generator yielding one BenchmarkResult per json object
    :raises ValueError: If measurement_method is invalid, or a test has no raw OpenCLTimer data
    """
    # Validate the method once, up front, rather than for every result
    if measurement_method not in ("min", "avg"):
        raise ValueError("Invalid measurement method: {}".format(measurement_method))

    # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
    gemm_param_fields_len = len(GEMMParam._fields)
    for json_res in json_results:
        # Get example test and test data.
        # There should only be 1 test per run
        example_tests = list(json_res["tests"].items())
        assert len(example_tests) == 1
        example_fn, example_test_data = example_tests[0]

        # Process example file name: keep only the base name of the example binary
        example_fn = os.path.basename(example_fn)

        # Get strategy
        strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]

        # Get gemm params + gemm configs from example args
        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
        example_args = Gemm_Example_Args_T(*(benchmark_args["example_args"].split(",")))
        gemm_param = GEMMParam.parse_from_strs(*example_args[:gemm_param_fields_len])
        gemm_config_t = GEMM_CONFIG_FACTORY[strategy]
        gemm_config = gemm_config_t.parse_from_strs(*example_args[gemm_param_fields_len:])

        # Get OpenCL_Time_Ms stats
        measurements = list(example_test_data["measurements"].items())
        # There should only be 1 instrument per run
        assert len(measurements) == 1
        measurement_instrument, data = measurements.pop()
        # Get instrument name and assert that it is the one we expect
        measurement_instrument_name = measurement_instrument.split("/")[0]
        assert measurement_instrument_name == "OpenCLTimer"

        raw = data["raw"]
        # Fail clearly on missing data; otherwise "avg" would raise ZeroDivisionError and "min" an unrelated error
        if not raw:
            raise ValueError("No raw OpenCLTimer measurements for example {}".format(example_fn))
        # Take either the minimum or the average of the raw data as the measurement value
        if measurement_method == "min":
            measurement_val = min(raw)
        else:
            measurement_val = sum(raw) / len(raw)

        measurement = Measurement(measurement_val)

        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
| 468 | |
| 469 | |
def parse_json(dir_name):
    """ Glob all benchmark result json files and parse them into json objects (dicts).

    Files named "*.<BENCHMARK_RESULT_JSON_EXTENSION>" are searched recursively under dir_name and yielded one at
    a time. They are processed in sorted path order, so that repeated runs see the results in the same order.
    Downstream, ties between equally fast configs are broken by recording order.
    """
    result_paths = sorted(Path(dir_name).rglob("*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)))
    for res_fn in result_paths:
        # JSON text is UTF-8; don't depend on the platform's default encoding
        with open(res_fn, encoding="utf-8") as res_fp:
            yield json.load(res_fp)
| 476 | |
| 477 | |
| 478 | ################################################################################ |
| 479 | # Main |
| 480 | ################################################################################ |
| 481 | |
| 482 | |
def main(args):
    """ Find the best GEMM configurations in the benchmark results under args.benchmark_results_dir.

    Reads args.benchmark_results_dir, args.tolerance, args.debug and args.output_dir. Logs a summary of the
    record and, for each strategy, the votes per GEMM config and the best config. When args.output_dir is set,
    the record is also saved there as CSV files: only the best configs normally, everything in debug mode.
    """
    logging.info("Searching best gemm configurations from {}".format(args.benchmark_results_dir))

    results = extract_benchmark_results(parse_json(args.benchmark_results_dir))

    # Feed every benchmark result into the recorder
    recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
    for result in results:
        recorder.add(result)

    levels = GEMMBenchmarkResultRecorder.SummaryLevel
    sum_level = levels.Detailed if args.debug else levels.Short

    # Overall summary of the recorded results
    logging.info(recorder.summary(sum_level=sum_level))

    # GEMM configuration distribution of every strategy
    config_dists = recorder.get_config_distributions()

    logging.info("=== Result ===")
    for strategy, config_dist in config_dists.items():
        logging.info("Strategy: {}".format(strategy.name))
        logging.debug("GEMM Config, Votes")
        for config, votes in config_dist.frequency():
            logging.debug("{}, {}".format(config, votes))
        best, spread = config_dist.best_config(), config_dist.std()
        logging.info("Best GEMM Config: {} with std: {}".format(best, spread))

    # Optionally save the record as csv files in the output directory
    if args.output_dir is None:
        return
    recorder.save_to_csvs(args.output_dir, only_best_config=not args.debug)
| 517 | |
| 518 | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
    parser.add_argument(
        "-b",
        "--benchmark_results",
        dest="benchmark_results_dir",
        metavar="PATH",
        action="store",
        type=str,
        # Adjacent string literals instead of a backslash continuation inside the literal. A continuation would
        # splice the next source line's indentation (or, at column 0, nothing at all) into the help text.
        help=(
            "Path to benchmark result directory, where benchmark result json files have a file "
            "extension of '{}'"
        ).format(BENCHMARK_RESULT_JSON_EXTENSION),
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        dest="output_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to directory that holds output csv files. One per strategy",
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        action="store",
        type=float,
        default=0.01,
        help=(
            "For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in "
            "milliseconds. Recommended value: <= 0.1 ms"
        ),
    )
    parser.add_argument(
        "-D", "--debug", dest="debug", action="store_true", help="Enable script debugging output"
    )
    args = parser.parse_args()
    # Debug mode also enables the per-config vote listing logged at DEBUG level by main()
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)
    logging.debug("Arguments: {}".format(args))
    main(args)