Eren Kopuz | a0bf913 | 2020-06-24 17:29:38 +0100 | [diff] [blame] | 1 | # Copyright (c) 2019-2020 ARM Limited. |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 2 | # |
| 3 | # SPDX-License-Identifier: MIT |
| 4 | # |
| 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy |
| 6 | # of this software and associated documentation files (the "Software"), to |
| 7 | # deal in the Software without restriction, including without limitation the |
| 8 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 9 | # sell copies of the Software, and to permit persons to whom the Software is |
| 10 | # furnished to do so, subject to the following conditions: |
| 11 | # |
| 12 | # The above copyright notice and this permission notice shall be included in all |
| 13 | # copies or substantial portions of the Software. |
| 14 | # |
| 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 21 | # SOFTWARE. |
| 22 | |
| 23 | #!/usr/bin/python3 |
| 24 | |
| 25 | import argparse |
| 26 | import csv |
| 27 | import json |
| 28 | import logging |
| 29 | import math |
| 30 | import os |
| 31 | from collections import Counter, defaultdict, deque, namedtuple |
| 32 | from enum import Enum |
| 33 | from pathlib import Path |
SiCong Li | 75041a1 | 2019-11-05 10:43:06 +0000 | [diff] [blame] | 34 | from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 35 | |
| 36 | ################################################################################ |
| 37 | # Types |
| 38 | ################################################################################ |
| 39 | |
# Gemm strategy
class Strategy(Enum):
    """ Enumeration of the GEMM strategies handled by this script.
    """

    Native = 1
    ReshapedOnlyRHS = 2
    Reshaped = 3
| 42 | |
# Gemm parameter
class GEMMParam(NamedTuple):
    """ Shape and data type describing one GEMM problem.
    """

    M: int  # Number of lhs matrix rows
    N: int  # Number of rhs matrix columns
    K: int  # Number of lhs matrix columns/rhs matrix rows
    B: int  # Batch size
    data_type: str  # Data type

    @staticmethod
    def parse_from_strs(*M_N_K_B, data_type):
        """ Build a GEMMParam from positional values convertible to int (M, N, K, B)
        and a keyword-only data type, which is stored as a string.
        """
        dims = [int(dim) for dim in M_N_K_B]
        return GEMMParam(*dims, str(data_type))

    def __str__(self):
        """ Return all fields joined by commas.
        """
        return ",".join(str(field) for field in self)
SiCong Li | 75041a1 | 2019-11-05 10:43:06 +0000 | [diff] [blame] | 59 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 60 | |
# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
    """ Configuration values for the Native strategy.
    """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication

    @staticmethod
    def parse_from_strs(*args):
        """ Build a NativeGEMMConfig from values convertible to int (m0, n0, k0).
        """
        values = [int(arg) for arg in args]
        return NativeGEMMConfig(*values)

    def __str__(self):
        """ Return all fields joined by commas.
        """
        return ",".join(str(field) for field in self)
SiCong Li | 75041a1 | 2019-11-05 10:43:06 +0000 | [diff] [blame] | 74 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 75 | |
# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
    """ Configuration values for the Reshaped Only RHS strategy.
    """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool

    @staticmethod
    def parse_from_strs(*args):
        """ Build a config from values convertible to int.

        The last two values are flags; a flag is True only when its value is exactly 1.
        """
        values = [int(arg) for arg in args]
        *mnkh, interleave_flag, transpose_flag = values
        return ReshapedOnlyRHSGEMMConfig(
            *mnkh, interleave_flag == 1, transpose_flag == 1
        )

    def __str__(self):
        """ Return all fields joined by commas.
        """
        return ",".join(str(field) for field in self)
SiCong Li | 75041a1 | 2019-11-05 10:43:06 +0000 | [diff] [blame] | 97 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 98 | |
# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
    """ Configuration values for the Reshaped strategy.
    """

    m0: int  # Number of rows processed by the matrix multiplication
    n0: int  # Number of columns processed by the matrix multiplication
    k0: int  # Number of partial accumulations performed by the matrix multiplication
    # Number of vertical blocks of size (m0xk0) stored on the same output row
    v0: int
    # Number of horizontal blocks of size (k0xn0) stored on the same output row
    h0: int
    # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
    interleave_lhs: bool
    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
    interleave_rhs: bool
    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
    transpose_rhs: bool

    @staticmethod
    def parse_from_strs(*args):
        """ Build a config from values convertible to int.

        The last three values are flags; a flag is True only when its value is exactly 1.
        """
        values = [int(arg) for arg in args]
        *mnkvh, lhs_flag, rhs_flag, transpose_flag = values
        return ReshapedGEMMConfig(
            *mnkvh, lhs_flag == 1, rhs_flag == 1, transpose_flag == 1
        )

    def __str__(self):
        """ Return all fields joined by commas.
        """
        return ",".join(str(field) for field in self)
SiCong Li | 75041a1 | 2019-11-05 10:43:06 +0000 | [diff] [blame] | 125 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 126 | |
# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
    """ Pair of OpenCL timer values (reshape and kernel) with element-wise arithmetic.

    Note that the arithmetic operators below replace the tuple operators
    (concatenation / repetition) with element-wise operations between two Measurements.
    """

    opencl_timer_ms_reshape: float
    opencl_timer_ms_kernel: float

    def get_total_ms(self):
        """ Return the sum of the reshape and kernel timings.
        """
        return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel

    def is_close_to(self, other, tol):
        """ Return True if the totals of self and other differ by strictly less than tol.
        """
        return math.fabs(self.get_total_ms() - other.get_total_ms()) < tol

    def is_better_than(self, other, tol):
        """ Return True if self has a smaller total than other and the two are not within tol.
        """
        # tol must be forwarded: is_close_to() has no default for it.
        return self.get_total_ms() < other.get_total_ms() and not self.is_close_to(
            other, tol
        )

    def __add__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
        )

    def __sub__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
        )

    def __mul__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
        )

    def __floordiv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
        )

    def __truediv__(self, other):
        return Measurement(
            self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
            self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
        )

    def __pow__(self, power):
        """ Raise both timings to the scalar power.
        """
        return Measurement(
            self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
        )

    def __str__(self):
        """ Return both timings joined by a comma.
        """
        return ",".join(map(str, self))
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 180 | |
| 181 | |
# GEMMConfig Type: any one of the strategy-specific configuration tuples
GEMMConfigT = Union[
    NativeGEMMConfig,
    ReshapedOnlyRHSGEMMConfig,
    ReshapedGEMMConfig,
]
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 185 | |
| 186 | |
# Representation of the benchmark result from a single experiment
class BenchmarkResult(NamedTuple):
    """ One benchmark experiment: what was run, how it was configured, and what was measured.
    """

    gemm_param: GEMMParam  # GEMM problem that was benchmarked
    strategy: Strategy  # Strategy used for the run
    gemm_config: GEMMConfigT  # Strategy-specific configuration used for the run
    measurement: Measurement  # Timings extracted from the run
| 193 | |
| 194 | |
class GEMMBenchmarkResultRecorder:
    """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
    """

    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])

    def __init__(self, tol=0.01):
        """ Initializer

        tol: tolerance passed to Measurement.is_close_to() when deciding which configs
             are reported together with the best one.
        """
        self._benchmark_result_record: List[BenchmarkResult] = []
        # Strategies recorded
        self._strategies = set()
        self._tol = tol

    def add(self, benchmark_result: "BenchmarkResult"):
        """ Add a benchmark result to the record.
        """
        _, strategy, _, _ = benchmark_result
        # Update strategies encountered
        self._strategies.add(strategy)

        self._benchmark_result_record.append(benchmark_result)

    def get_record(self) -> Generator["BenchmarkResult", None, None]:
        """ Return an iterator that iterates over the record.
        """
        yield from self._benchmark_result_record

    def get_best_gemm_configs(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy

        Returns a defaultdict(list) mapping (GEMMParam, Strategy) to a list of
        (GEMMConfig, Measurement) pairs sorted by total time. The first pair is the best
        one; the remaining pairs are those whose measurement is close to the best one
        according to Measurement.is_close_to() and the recorder's tolerance.
        """
        # Group every recorded (config, measurement) by (param, strategy) first, so that
        # each group is sorted and filtered once instead of once per added record.
        grouped: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for gemm_param, strategy, gemm_config, measurement in self.get_record():
            grouped[(gemm_param, strategy)].append((gemm_config, measurement))

        best_gc_sets: Dict[
            Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfigT, Measurement]]
        ] = defaultdict(list)
        for key, candidates in grouped.items():
            # Stable sort: candidates with equal totals keep their recording order
            candidates.sort(key=lambda gc_and_m: gc_and_m[1].get_total_ms())
            best_gc, best_m = candidates[0]
            # Filter out configs that are beyond tolerance to the best GEMMConfig's measurement.
            # The best config itself is always kept.
            best_gc_sets[key] = [(best_gc, best_m)] + [
                (gemm_config, measurement)
                for gemm_config, measurement in candidates[1:]
                if measurement.is_close_to(best_m, self._tol)
            ]

        return best_gc_sets

    def get_best_gemm_configs_as_sequence(self):
        """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
        of BenchmarkResults
        """
        for (
            (gemm_param, strategy),
            best_gc_sets,
        ) in self.get_best_gemm_configs().items():
            for best_gemm_config, best_measurement in best_gc_sets:
                yield BenchmarkResult(
                    gemm_param, strategy, best_gemm_config, best_measurement
                )

    def get_config_distributions(self):
        """ Return GEMMConfigDistribution for each strategy
        """
        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
            GEMMConfigDistribution
        )
        for benchmark_result in self.get_best_gemm_configs_as_sequence():
            _, strategy, _, _ = benchmark_result
            gemm_config_distributions[strategy].add(benchmark_result)

        return gemm_config_distributions

    def get_best_gemm_strategies(self):
        """ Get the best Strategy per GEMMParam

        The best strategy is the one owning the record with the smallest total time;
        on ties the record that was added first wins.
        """
        # GEMMParam -> (strategy, total time) of the fastest record seen so far
        fastest: Dict[GEMMParam, Tuple[Strategy, float]] = {}
        for gemm_param, strategy, _, measurement in self.get_record():
            total_ms = measurement.get_total_ms()
            # Strict comparison keeps the earliest record when totals are equal
            if gemm_param not in fastest or total_ms < fastest[gemm_param][1]:
                fastest[gemm_param] = (strategy, total_ms)

        best_strategies: Dict[GEMMParam, Strategy] = {
            gemm_param: strategy for gemm_param, (strategy, _) in fastest.items()
        }
        return best_strategies

    def save_to_jsons(self, out_dir, only_best_config=True):
        """ Save records to an output directory of JSON files.
        The directory is organized such that each strategy gets its own JSON file.
        The directory also includes a JSON file to define the best strategy per GEMM Param.

        If only_best_config is True only the best configs are written, otherwise every
        recorded result is written. Existing files are only overwritten after confirmation
        (see check_out_path).
        """
        if not os.path.exists(out_dir):
            logging.info(
                "Output directory {} does not exist. Creating...".format(out_dir)
            )
            # makedirs (unlike mkdir) also creates missing parent directories
            os.makedirs(out_dir, exist_ok=True)

        out_json_path = os.path.join(out_dir, "gemm_type_selection.json")
        if check_out_path(out_json_path):
            results = self.get_best_gemm_strategies()
            results = {str(key): value.name for key, value in results.items()}
            dump_json(out_json_path, results)

        for strategy in self._strategies:
            out_json_path = os.path.join(
                out_dir, ("gemm_config_" + strategy.name.lower() + ".json")
            )
            if not check_out_path(out_json_path):
                continue
            record = (
                self.get_best_gemm_configs_as_sequence()
                if only_best_config
                else self.get_record()
            )
            results = defaultdict(list)
            for res in record:
                if res.strategy != strategy:
                    continue
                results[str(res.gemm_param)].append(
                    {
                        "GEMMConfig": str(res.gemm_config),
                        "OpenCL_Timer_ms_reshape": str(
                            res.measurement.opencl_timer_ms_reshape
                        ),
                        "OpenCL_Timer_ms_kernel": str(
                            res.measurement.opencl_timer_ms_kernel
                        ),
                    }
                )
            dump_json(out_json_path, results)

    def summary(self, sum_level=SummaryLevel.Short):
        """ Return the summary string of the record

        With SummaryLevel.Detailed the GEMM parameters of each strategy are listed as well.
        """
        num_raw_records = sum(1 for _ in self.get_record())
        gemm_params_per_strategy = defaultdict(list)
        for gemm_param, strategy in self.get_best_gemm_configs():
            gemm_params_per_strategy[strategy].append(gemm_param)
        # Sorted so that the summary does not depend on set iteration order
        strategy_names = ", ".join(sorted(s.name for s in self._strategies))
        global_summary = f"""
=== {self.__class__.__name__} Summary ===
[Global]
Strategies recorded: {strategy_names}
Total number of results recorded: {num_raw_records}

[Per strategy]
        """
        strategy_summaries = []
        for strategy, gemm_params in gemm_params_per_strategy.items():
            summary = f"""
Strategy {strategy.name}:
GEMM parameters:
    Number of: {len(gemm_params)}
            """
            if sum_level == self.__class__.SummaryLevel.Detailed:
                summary += f"""
    Content: {gemm_params}
                """
            strategy_summaries.append(summary)
        return global_summary + "".join(strategy_summaries)
| 369 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 370 | |
class GEMMConfigDistribution:
    """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
    """

    def __init__(self):
        """ Initializer
        """
        # GEMM config -> every (GEMM param, measurement) recorded for that config
        self._gemm_config_dist: Dict[
            GEMMConfigT, List[Tuple[GEMMParam, Measurement]]
        ] = defaultdict(list)
        # GEMM config -> number of results recorded for that config
        self._gemm_config_freq = Counter()

    def add(self, benchmark_result: "BenchmarkResult"):
        """ Add a benchmark result to the distribution
        """
        gemm_param, _, gemm_config, measurement = benchmark_result
        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
        self._gemm_config_freq[gemm_config] += 1

    def distribution(self):
        """ Get the mapping from gemm config to its recorded (gemm param, measurement) pairs
        """
        return self._gemm_config_dist

    def frequency(self):
        """ Get the frequency of each (best) gemm config recorded

        Returns a list of (gemm config, count) pairs, most frequent first.
        """
        return self._gemm_config_freq.most_common()

    def best_config(self):
        """ Get the overall best config, as voted by all benchmark results.

        Returns the result of Counter.most_common(1): a list holding a single
        (gemm config, count) pair, or an empty list if nothing has been added.
        """
        return self._gemm_config_freq.most_common(1)

    def std(self):
        """ Get the standard deviation as a measure of dispersion of the distribution. We should aim for higher values
        as they indicate there is high variation in the distribution. Thus the evidence of the best config is stronger.

        Returns 0 when nothing has been added.
        """
        freqs = list(self._gemm_config_freq.values())
        if not freqs:
            return 0
        mean_freq = sum(freqs) / len(freqs)
        return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))
| 412 | |
| 413 | |
| 414 | ################################################################################ |
| 415 | # Globals |
| 416 | ################################################################################ |
| 417 | |
# Gemm config type factory
# Produces a GEMMConfig type specific to a Strategy
GEMM_CONFIG_FACTORY = {
    Strategy.Native: NativeGEMMConfig,
    Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
    Strategy.Reshaped: ReshapedGEMMConfig,
}

# Mapping from example binary name to Strategy
# Assume 1-to-1 mapping
EXAMPLE_FILE_2_STRATEGY = {
    "benchmark_cl_gemm_native": Strategy.Native,
    "benchmark_cl_gemm_reshaped_rhs_only": Strategy.ReshapedOnlyRHS,
    "benchmark_cl_gemm_reshaped": Strategy.Reshaped,
}


def _make_gemm_example_args_type(strategy):
    """ Create the namedtuple type holding the example arguments of one strategy.
    """
    # We ignore the data type field from GEMMParam as that is extracted separately
    field_names = GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields
    return namedtuple(f"{strategy.name}_Gemm_Example_Args", field_names)


# Gemm example arguments type factory
# Produces a Gemm_Example_Args type specific to a Strategy
# Gemm example arguments consist of:
#   GEMMParam + GEMMConfig
#   in that order.
# For example, the example args of running a reshaped rhs only example could be:
#   100,100,100,1, 4, 4, 4, 1,             1,            1
#   M  ,N  ,K,  B,m0,n0,k0,h0,interleave_rhs,transpose_rhs
#   <-GEMMParam-><-------------GEMMConfig-------------->
# Note that the test strategy_name == strategy.name is in place to avoid unwanted enum aliases
GEMM_EXAMPLE_ARGS_FACTORY = {
    strategy: _make_gemm_example_args_type(strategy)
    for strategy_name, strategy in Strategy.__members__.items()
    if strategy_name == strategy.name
}

# File extension used for benchmark result json files
BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
| 456 | |
| 457 | ################################################################################ |
| 458 | # Functions |
| 459 | ################################################################################ |
| 460 | |
| 461 | |
def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
    """ Parse the benchmark example command-line string into a dictionary of command-line arguments

    The first whitespace-separated token (the program name) is discarded. Every other
    token is split at its first "=" into name and value; leading "-" characters are
    stripped from the name. A token without "=" maps to an empty string.
    """
    # Separate the data type option from the example_args portion of the string
    commandline = commandline.replace(",--type=", " --type=")

    parsed_args = {}
    # Discard program name
    for arg in commandline.split()[1:]:
        # partition splits at the first "=" only, so values containing "=" stay intact
        # and tokens without "=" do not raise
        name, _, val = arg.partition("=")
        # Strip '-'/"--" if it exists
        parsed_args[name.lstrip("-")] = val
    return parsed_args
| 480 | |
| 481 | |
def extract_benchmark_results(
    json_results: Dict, measurement_method="avg"
) -> Generator[BenchmarkResult, None, None]:
    """ Parse the benchmark result and extract relevant information, namely:
        GEMM param,
        Strategy,
        GEMM config,
        Measurements

    json_results is iterated; each item is expected to provide the keys "tests" and
    "CommandLine". measurement_method selects how the raw samples of one instrument are
    reduced to a single value: "min" or "avg".

    Raises ValueError for an unknown measurement_method or a result that does not have
    the expected layout. Being a generator, nothing is checked until iteration starts.
    """
    # Reduce the raw data of one instrument to a single value
    reducers = {
        "min": min,
        "avg": lambda raw: sum(raw) / len(raw),
    }
    if measurement_method not in reducers:
        raise ValueError("Invalid measurement method: {}".format(measurement_method))
    reduce_raw = reducers[measurement_method]

    for json_res in json_results:
        # Get example test and test data.
        # There should only be 1 test per run
        example_tests = list(json_res["tests"].items())
        if len(example_tests) != 1:
            raise ValueError(
                "Expected exactly 1 test per run, got {}".format(len(example_tests))
            )
        example_fn, example_test_data = example_tests[0]

        # Process example file name
        example_fn = os.path.basename(example_fn)

        # Get strategy
        strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]

        # Get gemm params + gemm configs from example args
        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
        example_args = Gemm_Example_Args_T(
            *(benchmark_args["example_args"].split(","))
        )
        # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
        # However data type option is parsed separately from end of options, hence -1 is applied to fields length
        gemm_param_fields_len = len(GEMMParam._fields) - 1
        gemm_param = GEMMParam.parse_from_strs(
            *example_args[:gemm_param_fields_len], data_type=benchmark_args["type"]
        )
        gemm_config_type = GEMM_CONFIG_FACTORY[strategy]
        gemm_config = gemm_config_type.parse_from_strs(
            *example_args[gemm_param_fields_len:]
        )

        # Get OpenCL_Time_Ms stats
        # For reshaped RHS only we have two measurements (one also for the reshape kernel);
        # they are kept apart here and summed by Measurement.get_total_ms()
        # NOTE(review): each bucket keeps only the last matching entry. If a run reports
        # several reshape (or several non-reshape) entries, the earlier ones are
        # overwritten rather than accumulated -- confirm this is intended.
        measurement_ms_reshape = 0
        measurement_ms_kernel = 0
        for measurement_instrument, data in example_test_data["measurements"].items():
            # Instrument keys look like "<instrument name>/<measurement type>..."
            instrument_parts = measurement_instrument.split("/")
            if len(instrument_parts) < 2 or instrument_parts[0] != "OpenCLTimer":
                raise ValueError(
                    "Unexpected measurement instrument: {}".format(
                        measurement_instrument
                    )
                )
            # Take either the minimum or the average of the raw data as the measurement value
            measurement_val = reduce_raw(data["raw"])

            measurement_type = instrument_parts[1]
            if "reshape" in measurement_type.split("_"):
                measurement_ms_reshape = measurement_val
            else:
                measurement_ms_kernel = measurement_val

        measurement = Measurement(measurement_ms_reshape, measurement_ms_kernel)

        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
| 550 | |
| 551 | |
def parse_json(dir_name):
    """ Glob all benchmark result json files and parse them into json objects (dicts).

    dir_name is searched recursively for files with the benchmark result extension.
    The files are visited in sorted path order so that repeated runs process the
    results in the same order.
    """
    pattern = "*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)
    for res_fn in sorted(Path(dir_name).rglob(pattern)):
        # Read as UTF-8 explicitly instead of relying on the locale's default encoding
        with open(res_fn, encoding="utf-8") as res_fp:
            yield json.load(res_fp)
| 558 | |
| 559 | |
def check_out_path(out_path):
    """ Tell whether out_path may be written.

    If out_path already exists the user is asked on the console whether to overwrite it;
    any answer other than "y"/"Y" skips the file and returns False. Otherwise True is returned.
    """
    if os.path.exists(out_path):
        prompt = "Output JSON {} already exists. Overwrite? [Y/N]: ".format(out_path)
        answer = input(prompt)
        if answer.lower() != "y":
            logging.info("Skipping {}".format(out_path))
            return False
    logging.info("Saving JSON file to {}".format(out_path))
    return True
| 574 | |
| 575 | |
def dump_json(out_path, dict):
    """ Serialise dict as JSON into the file at out_path (overwriting it) and log completion.
    """
    with open(out_path, "w") as out_fp:
        json.dump(dict, out_fp)
    logging.info("Saved")
| 580 | |
| 581 | |
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 582 | ################################################################################ |
| 583 | # Main |
| 584 | ################################################################################ |
| 585 | |
| 586 | |
def main(args):
    """ Run the tuner: record all benchmark results and report the best configs.

    args is the parsed command line namespace; the attributes read here are
    benchmark_results_dir, tolerance, debug and output_dir.

    All results parsed from the benchmark result directory are fed into a
    GEMMBenchmarkResultRecorder, a summary and the per-strategy GEMM config
    distributions are logged, and, if an output directory was given, the
    recorder saves its results there as JSON files.
    """
    results_dir = args.benchmark_results_dir
    logging.info(f"Searching best gemm configurations from {results_dir}")

    parsed_results = extract_benchmark_results(parse_json(results_dir))

    # Feed every benchmark result into the recorder
    recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
    for result in parsed_results:
        recorder.add(result)

    # Debug runs get the detailed summary, normal runs the short one
    levels = GEMMBenchmarkResultRecorder.SummaryLevel
    summary_level = levels.Detailed if args.debug else levels.Short
    logging.info(recorder.summary(sum_level=summary_level))

    # Per-strategy distributions of GEMM configurations
    dists_by_strategy = recorder.get_config_distributions()

    logging.info("=== Result ===")
    for strategy, dist in dists_by_strategy.items():
        logging.info(f"Strategy: {strategy.name}")
        logging.debug("GEMM Config, Votes")
        for gemm_config, votes in dist.frequency():
            logging.debug(f"{gemm_config}, {votes}")
        logging.info(
            f"Best GEMM Config: {dist.best_config()} with std: {dist.std()}"
        )

    # Without an output directory nothing is written to disk
    if args.output_dir is None:
        return

    # Outside debug mode only the best config is saved
    recorder.save_to_jsons(args.output_dir, only_best_config=(not args.debug))
SiCong Li | e36b526 | 2019-10-01 19:26:00 +0100 | [diff] [blame] | 631 | |
| 632 | |
if __name__ == "__main__":
    # Command line entry point: parse the arguments, configure logging
    # (DEBUG level when --debug is given, INFO otherwise) and run main().
    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
    parser.add_argument(
        "-b",
        "--benchmark_results",
        dest="benchmark_results_dir",
        metavar="PATH",
        action="store",
        type=str,
        # Adjacent string literals are used instead of a backslash
        # continuation inside the literal, so that source indentation can
        # never leak into (or words get fused in) the help text.
        help=(
            "Path to benchmark result directory, where benchmark result json "
            "files have a file extension of '{}'".format(
                BENCHMARK_RESULT_JSON_EXTENSION
            )
        ),
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        dest="output_dir",
        metavar="PATH",
        action="store",
        type=str,
        help="Path to directory that holds output JSON files. One for strategy selection and one per strategy for GEMM config selection",
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        action="store",
        type=float,
        default=0.01,
        help=(
            "For testing if two GEMMConfigs are equivalent in terms of "
            "performance. The tolerance is OpenCL timer in milliseconds. "
            "Recommended value: <= 0.1 ms"
        ),
    )
    parser.add_argument(
        "-D",
        "--debug",
        dest="debug",
        action="store_true",
        help="Enable script debugging output",
    )
    args = parser.parse_args()
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level)
    logging.debug("Arguments: {}".format(args))
    main(args)