COMPMID-3453: Modify GemmTuner.py to output JSON

- Added functionality to parse multiple OpenCL_ms measurements and sum
- Results are now outputted in 4 JSON files as specified in ticket COMPMID-3453

Change-Id: I9241e9eeaee7c4979e877f87736deaee3cfd38e4
Signed-off-by: Eren Kopuz <eren.kopuz@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3492
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/examples/gemm_tuner/GemmTuner.py b/examples/gemm_tuner/GemmTuner.py
index 8bc0d3a..aab2d55 100644
--- a/examples/gemm_tuner/GemmTuner.py
+++ b/examples/gemm_tuner/GemmTuner.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 Arm Limited.
+# Copyright (c) 2019-2020 ARM Limited.
 #
 # SPDX-License-Identifier: MIT
 #
@@ -41,6 +41,8 @@
 Strategy = Enum("Strategy", ["Native", "ReshapedOnlyRHS", "Reshaped"])
 
 # Gemm parameter
+
+
 class GEMMParam(NamedTuple):
     M: int  # Number of lhs matrix rows
     N: int  # Number of rhs matrix columns
@@ -52,7 +54,7 @@
         return GEMMParam(*map(int, args))
 
     def __str__(self):
-        return "-".join(map(str, self))
+        return ",".join(map(str, self))
 
 
 # Gemm configuration for strategy Native
@@ -63,11 +65,11 @@
 
     @staticmethod
     def parse_from_strs(*args):
-        *mnk, = map(int, args)
+        (*mnk,) = map(int, args)
         return NativeGEMMConfig(*mnk)
 
     def __str__(self):
-        return "-".join(map(str, self))
+        return ",".join(map(str, self))
 
 
 # Gemm configuration for strategy Reshaped Only RHS
@@ -75,9 +77,12 @@
     m0: int  # Number of rows processed by the matrix multiplication
     n0: int  # Number of columns processed by the matrix multiplication
     k0: int  # Number of partial accumulations performed by the matrix multiplication
-    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
-    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
-    transpose_rhs: bool  # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+    # Number of horizontal blocks of size (k0xn0) stored on the same output row
+    h0: int
+    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+    interleave_rhs: bool
+    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+    transpose_rhs: bool
 
     @staticmethod
     def parse_from_strs(*args):
@@ -87,7 +92,7 @@
         return ReshapedOnlyRHSGEMMConfig(*mnkh, interleave_rhs, transpose_rhs)
 
     def __str__(self):
-        return "-".join(map(str, self))
+        return ",".join(map(str, self))
 
 
 # Gemm configuration for strategy Reshaped
@@ -95,11 +100,16 @@
     m0: int  # Number of rows processed by the matrix multiplication
     n0: int  # Number of columns processed by the matrix multiplication
     k0: int  # Number of partial accumulations performed by the matrix multiplication
-    v0: int  # Number of vertical blocks of size (m0xk0) stored on the same output row
-    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
-    interleave_lhs: bool  # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
-    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
-    transpose_rhs: bool  # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+    # Number of vertical blocks of size (m0xk0) stored on the same output row
+    v0: int
+    # Number of horizontal blocks of size (k0xn0) stored on the same output row
+    h0: int
+    # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
+    interleave_lhs: bool
+    # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+    interleave_rhs: bool
+    # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+    transpose_rhs: bool
 
     @staticmethod
     def parse_from_strs(*args):
@@ -110,40 +120,67 @@
         return ReshapedGEMMConfig(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs)
 
     def __str__(self):
-        return "-".join(map(str, self))
+        return ",".join(map(str, self))
 
 
 # Measurement we take from the benchmark result.
 class Measurement(NamedTuple):
-    opencl_timer_ms: float
+    opencl_timer_ms_reshape: float
+    opencl_timer_ms_kernel: float
+
+    def get_total_ms(self):
+        return self.opencl_timer_ms_reshape + self.opencl_timer_ms_kernel
 
     def is_close_to(self, other, tol):
-        return math.fabs(self.opencl_timer_ms - other.opencl_timer_ms) < tol
+        return math.fabs(self.get_total_ms() - other.get_total_ms()) < tol
 
     def is_better_than(self, other, tol):
-        return self < other and not self.is_close_to(other)
+        return self.get_total_ms() < other.get_total_ms() and not self.is_close_to(
+            other
+        )
 
     def __add__(self, other):
-        return Measurement(self.opencl_timer_ms + other.opencl_timer_ms)
+        return Measurement(
+            self.opencl_timer_ms_reshape + other.opencl_timer_ms_reshape,
+            self.opencl_timer_ms_kernel + other.opencl_timer_ms_kernel,
+        )
 
     def __sub__(self, other):
-        return Measurement(self.opencl_timer_ms - other.opencl_timer_ms)
+        return Measurement(
+            self.opencl_timer_ms_reshape - other.opencl_timer_ms_reshape,
+            self.opencl_timer_ms_kernel - other.opencl_timer_ms_kernel,
+        )
 
     def __mul__(self, other):
-        return Measurement(self.opencl_timer_ms * other.opencl_timer_ms)
+        return Measurement(
+            self.opencl_timer_ms_reshape * other.opencl_timer_ms_reshape,
+            self.opencl_timer_ms_kernel * other.opencl_timer_ms_kernel,
+        )
 
     def __floordiv__(self, other):
-        return Measurement(self.opencl_timer_ms // other.opencl_timer_ms)
+        return Measurement(
+            self.opencl_timer_ms_reshape // other.opencl_timer_ms_reshape,
+            self.opencl_timer_ms_kernel // other.opencl_timer_ms_kernel,
+        )
 
     def __truediv__(self, other):
-        return Measurement(self.opencl_timer_ms / other.opencl_timer_ms)
+        return Measurement(
+            self.opencl_timer_ms_reshape / other.opencl_timer_ms_reshape,
+            self.opencl_timer_ms_kernel / other.opencl_timer_ms_kernel,
+        )
 
     def __pow__(self, power):
-        return Measurement(self.opencl_timer_ms ** power)
+        return Measurement(
+            self.opencl_timer_ms_reshape ** power, self.opencl_timer_ms_kernel ** power
+        )
+
+    def __str__(self):
+        return ",".join(map(str, self))
 
 
 # GEMMConfig Type
-GEMMConfigT = Union[NativeGEMMConfig, ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
+GEMMConfigT = Union[NativeGEMMConfig,
+                    ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
 
 
 # Representation of the benchmark result from a single experiment
@@ -154,24 +191,6 @@
     measurement: Measurement
 
 
-# Representation of a single row of BenchmarkResult in CSV
-# NOTE: In the CSV representation, we merge all fields of Gemm Config into a single field "GEMMConfig", but keep the
-# fields of GEMMParam and Measurement
-# The example entry including header would look like:
-# M   , N , K  , B, Strategy         , GEMMConfig       , OpenCLTimer_MS
-# 1225, 32, 192, 1, Reshaped         , 4-4-4-3-1-1-1-0  , 0.3309
-BenchmarkResultCSVRow = namedtuple(
-    "BenchmarkResultCSVRow", GEMMParam._fields + ("Strategy", "GEMMConfig") + Measurement._fields
-)
-
-
-def benchmark_result_2_csv_row(result: BenchmarkResult) -> BenchmarkResultCSVRow:
-    """ Convert a BenchmarkResult into its CSV row form """
-    return BenchmarkResultCSVRow(
-        *(result.gemm_param + (result.strategy.name, str(result.gemm_config)) + result.measurement)
-    )
-
-
 class GEMMBenchmarkResultRecorder:
     """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
     """
@@ -210,7 +229,9 @@
             best_gc_set = best_gc_sets.setdefault((gemm_param, strategy), [])
             best_gc_set.append((gemm_config, measurement))
             # Sort the best config set (list)
-            best_gc_set = sorted(best_gc_set, key=lambda gc_and_m: gc_and_m[1])
+            best_gc_set = sorted(
+                best_gc_set, key=lambda gc_and_m: gc_and_m[1].get_total_ms()
+            )
             # Filter out configs that are beyond tolerance to the best GEMMConfig's measurement
             best_gc, best_m = best_gc_set[0]
             best_gc_set_new = [
@@ -228,9 +249,14 @@
         """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
         of BenchmarkResults
         """
-        for (gemm_param, strategy), best_gc_sets in self.get_best_gemm_configs().items():
+        for (
+            (gemm_param, strategy),
+            best_gc_sets,
+        ) in self.get_best_gemm_configs().items():
             for best_gemm_config, best_measurement in best_gc_sets:
-                yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement)
+                yield BenchmarkResult(
+                    gemm_param, strategy, best_gemm_config, best_measurement
+                )
 
     def get_config_distributions(self):
         """ Return GEMMConfigDistribution for each strategy
@@ -244,38 +270,72 @@
 
         return gemm_config_distributions
 
-    def save_to_csvs(self, out_dir, only_best_config=True):
-        """ Save records to an output directory of csv files.
-        The directory is organized such that each strategy gets its own CSV file.
+    def get_best_gemm_strategies(self):
+        """ Get the best Stratey per GEMMParam
+        """
+        all_results: Dict[GEMMParam, List[Tuple[Strategy, Measurement]]] = defaultdict(
+            list
+        )
+
+        best_strategies: Dict[GEMMParam, Strategy] = {}
+
+        for gemm_param, strategy, gemm_config, measurement in self.get_record():
+            all_results[gemm_param].append((strategy, measurement))
+
+        for gemm_param, results_set in all_results.items():
+            # Sort the best results set (list)
+            results_set = sorted(
+                results_set, key=lambda s_and_m: s_and_m[1].get_total_ms()
+            )
+            # Select best Strategy
+            best_s, best_m = results_set[0]
+            best_strategies[gemm_param] = best_s
+
+        return best_strategies
+
+    def save_to_jsons(self, out_dir, only_best_config=True):
+        """ Save records to an output directory of JSON files.
+        The directory is organized such that each strategy gets its own JSON file.
+        The directory also includes a JSON file to define the best strategy per GEMM Param.
         """
         if not os.path.exists(out_dir):
-            logging.info("Output directory {} does not exist. Creating...".format(out_dir))
-            os.mkdir(out_dir)
-        for strategy in self._strategies:
-            out_csv_path = os.path.join(out_dir, strategy.name)
-            if os.path.exists(out_csv_path):
-                overwrite = (
-                    input(
-                        "Output CSV {} already exists. Overwrite? [Y/N]: ".format(out_csv_path)
-                    ).lower()
-                    == "y"
-                )
-                if not overwrite:
-                    logging.info("Skipping {}".format(out_csv_path))
-                    continue
-            logging.info("Saving csv file to {}".format(out_csv_path))
-            record = (
-                self.get_best_gemm_configs_as_sequence() if only_best_config else self.get_record()
+            logging.info(
+                "Output directory {} does not exist. Creating...".format(
+                    out_dir)
             )
-            with open(out_csv_path, "w") as f:
-                csv_writer = csv.DictWriter(f, fieldnames=BenchmarkResultCSVRow._fields)
-                csv_writer.writeheader()
-                csv_writer.writerows(
-                    benchmark_result_2_csv_row(res)._asdict()
-                    for res in record
-                    if res.strategy == strategy
+            os.mkdir(out_dir)
+
+        out_json_path = os.path.join(out_dir, "gemm_type_selection.json")
+        if check_out_path(out_json_path):
+            results = self.get_best_gemm_strategies()
+            results = {str(key): value.name for key, value in results.items()}
+            dump_json(out_json_path, results)
+
+        for strategy in self._strategies:
+            out_json_path = os.path.join(
+                out_dir, ("gemm_config_" + strategy.name.lower() + ".json")
+            )
+            if check_out_path(out_json_path):
+                record = (
+                    self.get_best_gemm_configs_as_sequence()
+                    if only_best_config
+                    else self.get_record()
                 )
-            logging.info("Saved")
+                results = defaultdict(list)
+                for res in record:
+                    if res.strategy == strategy:
+                        results[str(res.gemm_param)].append(
+                            {
+                                "GEMMConfig": str(res.gemm_config),
+                                "OpenCL_Timer_ms_reshape": str(
+                                    res.measurement.opencl_timer_ms_reshape
+                                ),
+                                "OpenCL_Timer_ms_kernel": str(
+                                    res.measurement.opencl_timer_ms_kernel
+                                ),
+                            }
+                        )
+                dump_json(out_json_path, results)
 
     def summary(self, sum_level=SummaryLevel.Short):
         """ Return the summary string of the record
@@ -314,9 +374,9 @@
     def __init__(self):
         """ Initializer
         """
-        self._gemm_config_dist: Dict[GEMMConfig, List[Tuple[GEMMParam, Measurement]]] = defaultdict(
-            list
-        )
+        self._gemm_config_dist: Dict[
+            GEMMConfig, List[Tuple[GEMMParam, Measurement]]
+        ] = defaultdict(list)
         self._gemm_config_freq = Counter()
 
     def add(self, benchmark_result: BenchmarkResult):
@@ -439,30 +499,45 @@
         # Get gemm params + gemm configs from example args
         benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
         Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
-        example_args = Gemm_Example_Args_T(*(benchmark_args["example_args"].split(",")))
+        example_args = Gemm_Example_Args_T(
+            *(benchmark_args["example_args"].split(",")))
         # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
         gemm_param_fields_len = len(GEMMParam._fields)
-        gemm_param = GEMMParam.parse_from_strs(*example_args[:gemm_param_fields_len])
+        gemm_param = GEMMParam.parse_from_strs(
+            *example_args[:gemm_param_fields_len])
         GEMMConfig = GEMM_CONFIG_FACTORY[strategy]
-        gemm_config = GEMMConfig.parse_from_strs(*example_args[gemm_param_fields_len:])
+        gemm_config = GEMMConfig.parse_from_strs(
+            *example_args[gemm_param_fields_len:])
 
         # Get OpenCL_Time_Ms stats
         measurements = list(example_test_data["measurements"].items())
-        # There should only be 1 instrument per run
-        assert len(measurements) == 1
-        measurement_instrument, data = measurements.pop()
-        # Get instrument name and assert that it is the one we expect
-        measurement_instrument_name = measurement_instrument.split("/")[0]
-        assert measurement_instrument_name == "OpenCLTimer"
-        # Take either the minimum or the average of the raw data as the measurement value
-        if measurement_method == "min":
-            measurement_val = min(data["raw"])
-        elif measurement_method == "avg":
-            measurement_val = sum(data["raw"]) / len(data["raw"])
-        else:
-            raise ValueError("Invalid measurement method: {}".format(measurement_method))
+        # For reshaped RHS only we have two measurements (one also for the reshape kernel)
+        # Hence we must parse and sum them
+        measurement_ms_reshape = 0
+        measurement_ms_kernel = 0
+        for single_measurement in measurements:
+            measurement_instrument, data = single_measurement
+            # Get instrument name and assert that it is the one we expect
+            measurement_instrument_name = measurement_instrument.split("/")[0]
+            assert measurement_instrument_name == "OpenCLTimer"
+            # Take either the minimum or the average of the raw data as the measurement value
+            if measurement_method == "min":
+                measurement_val = min(data["raw"])
+            elif measurement_method == "avg":
+                measurement_val = sum(data["raw"]) / len(data["raw"])
+            else:
+                raise ValueError(
+                    "Invalid measurement method: {}".format(measurement_method)
+                )
 
-        measurement = Measurement(measurement_val)
+            measurement_type = measurement_instrument.split("/")[1]
+            if "reshape" in measurement_type.split("_"):
+                measurement_ms_reshape = measurement_val
+            else:
+                measurement_ms_kernel = measurement_val
+
+        measurement = Measurement(
+            measurement_ms_reshape, measurement_ms_kernel)
 
         yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
 
@@ -475,15 +550,42 @@
             yield json.load(res_fp)
 
 
+def check_out_path(out_path):
+    if os.path.exists(out_path):
+        overwrite = (
+            input(
+                "Output JSON {} already exists. Overwrite? [Y/N]: ".format(
+                    out_path)
+            ).lower()
+            == "y"
+        )
+        if not overwrite:
+            logging.info("Skipping {}".format(out_path))
+            return False
+    logging.info("Saving JSON file to {}".format(out_path))
+    return True
+
+
+def dump_json(out_path, dict):
+    with open(out_path, "w") as f:
+        json.dump(dict, f)
+    logging.info("Saved")
+
+
 ################################################################################
 # Main
 ################################################################################
 
 
 def main(args):
-    logging.info("Searching best gemm configurations from {}".format(args.benchmark_results_dir))
+    logging.info(
+        "Searching best gemm configurations from {}".format(
+            args.benchmark_results_dir)
+    )
 
-    benchmark_results = extract_benchmark_results(parse_json(args.benchmark_results_dir))
+    benchmark_results = extract_benchmark_results(
+        parse_json(args.benchmark_results_dir)
+    )
 
     # Add all benchmark results to the recorder
     benchmark_result_recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
@@ -496,7 +598,8 @@
         recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short
 
     # Print overall summary of the recorded results
-    logging.info(benchmark_result_recorder.summary(sum_level=recorder_sum_level))
+    logging.info(benchmark_result_recorder.summary(
+        sum_level=recorder_sum_level))
 
     # Get GEMM configuration distributions for each strategy
     all_config_dists = benchmark_result_recorder.get_config_distributions()
@@ -508,12 +611,16 @@
         for config, freq in config_dist.frequency():
             logging.debug("{}, {}".format(config, freq))
         logging.info(
-            "Best GEMM Config: {} with std: {}".format(config_dist.best_config(), config_dist.std())
+            "Best GEMM Config: {} with std: {}".format(
+                config_dist.best_config(), config_dist.std()
+            )
         )
 
-    # Save the recorded results to csv files in output directory
+    # Save the recorded results to JSON files in output directory
     if args.output_dir is not None:
-        benchmark_result_recorder.save_to_csvs(args.output_dir, only_best_config=(not args.debug))
+        benchmark_result_recorder.save_to_jsons(
+            args.output_dir, only_best_config=(not args.debug)
+        )
 
 
 if __name__ == "__main__":
@@ -538,7 +645,7 @@
         metavar="PATH",
         action="store",
         type=str,
-        help="Path to directory that holds output csv files. One per strategy",
+        help="Path to directory that holds output JSON files. One for strategy selection and one per strategy for GEMM config selection",
     )
     parser.add_argument(
         "-t",
@@ -550,7 +657,11 @@
         milliseconds. Recommended value: <= 0.1 ms",
     )
     parser.add_argument(
-        "-D", "--debug", dest="debug", action="store_true", help="Enable script debugging output"
+        "-D",
+        "--debug",
+        dest="debug",
+        action="store_true",
+        help="Enable script debugging output",
     )
     args = parser.parse_args()
     logging_level = logging.DEBUG if args.debug else logging.INFO