COMPMID-2567: Create a python script to parse the CLGEMM benchmarks and
return the optimal configuration

* Fix GEMM Reshaped example by adding the reshape lhs kernel.
* Extend runner shell script to save result files with a defined
file extension
* Extend runner shell script to print out progress and time
* Add python script
* Update README

Change-Id: I484ec8945aded4341743bc1024820523392b8ce3
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2122
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/examples/gemm_tuner/GemmTuner.py b/examples/gemm_tuner/GemmTuner.py
new file mode 100644
index 0000000..8093ad0
--- /dev/null
+++ b/examples/gemm_tuner/GemmTuner.py
@@ -0,0 +1,522 @@
+# Copyright (c) 2019 ARM Limited.
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+#!/usr/bin/python3
+
+import argparse
+import csv
+import json
+import logging
+import math
+import os
+from collections import Counter, defaultdict, deque, namedtuple
+from enum import Enum
+from pathlib import Path
+from typing import Deque, Dict, Generator, List, NamedTuple, Tuple, Union
+
+################################################################################
+# Types
+################################################################################
+
+# Gemm strategy
+Strategy = Enum("Strategy", ["Native", "ReshapedOnlyRHS", "Reshaped"])
+
+# Gemm parameter
+class GEMMParam(NamedTuple):
+    M: int  # Number of lhs matrix rows
+    N: int  # Number of rhs matrix columns
+    K: int  # Number of lhs matrix columns/rhs matrix rows
+    B: int  # Batch size
+
+    @staticmethod
+    def parse_from_strs(*args):
+        return GEMMParam(*map(int, args))
+
+
+# Gemm configuration for strategy Native
+class NativeGEMMConfig(NamedTuple):
+    m0: int  # Number of rows processed by the matrix multiplication
+    n0: int  # Number of columns processed by the matrix multiplication
+    k0: int  # Number of partial accumulations performed by the matrix multiplication
+
+    @staticmethod
+    def parse_from_strs(*args):
+        *mnk, = map(int, args)
+        return NativeGEMMConfig(*mnk)
+
+
+# Gemm configuration for strategy Reshaped Only RHS
+class ReshapedOnlyRHSGEMMConfig(NamedTuple):
+    m0: int  # Number of rows processed by the matrix multiplication
+    n0: int  # Number of columns processed by the matrix multiplication
+    k0: int  # Number of partial accumulations performed by the matrix multiplication
+    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
+    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+    transpose_rhs: bool  # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+
+    @staticmethod
+    def parse_from_strs(*args):
+        *mnkh, interleave_rhs, transpose_rhs = map(int, args)
+        interleave_rhs = interleave_rhs == 1
+        transpose_rhs = transpose_rhs == 1
+        return ReshapedOnlyRHSGEMMConfig(*mnkh, interleave_rhs, transpose_rhs)
+
+
+# Gemm configuration for strategy Reshaped
+class ReshapedGEMMConfig(NamedTuple):
+    m0: int  # Number of rows processed by the matrix multiplication
+    n0: int  # Number of columns processed by the matrix multiplication
+    k0: int  # Number of partial accumulations performed by the matrix multiplication
+    v0: int  # Number of vertical blocks of size (m0xk0) stored on the same output row
+    h0: int  # Number of horizontal blocks of size (k0xn0) stored on the same output row
+    interleave_lhs: bool  # Interleave lhs matrix (1) / Do not interleave lhs matrix (0)
+    interleave_rhs: bool  # Interleave rhs matrix (1) / Do not interleave rhs matrix (0)
+    transpose_rhs: bool  # Transpose rhs matrix but not lhs matrix (1) / Do not transpose rhs matrix but do transpose lhs matrix (0)
+
+    @staticmethod
+    def parse_from_strs(*args):
+        *mnkvh, interleave_lhs, interleave_rhs, transpose_rhs = map(int, args)
+        interleave_lhs = interleave_lhs == 1
+        interleave_rhs = interleave_rhs == 1
+        transpose_rhs = transpose_rhs == 1
+        return ReshapedGEMMConfig(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs)
+
+
+# Measurement we take from the benchmark result.
+class Measurement(NamedTuple):
+    opencl_timer_ms: float
+
+    def is_better_than(self, other):
+        return self < other
+
+    def __add__(self, other):
+        return Measurement(self.opencl_timer_ms + other.opencl_timer_ms)
+
+    def __sub__(self, other):
+        return Measurement(self.opencl_timer_ms - other.opencl_timer_ms)
+
+    def __mul__(self, other):
+        return Measurement(self.opencl_timer_ms * other.opencl_timer_ms)
+
+    def __floordiv__(self, other):
+        return Measurement(self.opencl_timer_ms // other.opencl_timer_ms)
+
+    def __truediv__(self, other):
+        return Measurement(self.opencl_timer_ms / other.opencl_timer_ms)
+
+    def __pow__(self, power):
+        return Measurement(self.opencl_timer_ms ** power)
+
+
+# GEMMConfig Type
+GEMMConfigT = Union[NativeGEMMConfig, ReshapedOnlyRHSGEMMConfig, ReshapedGEMMConfig]
+
+
+# Representation of the benchmark result from a single experiment
+class BenchmarkResult(NamedTuple):
+    gemm_param: GEMMParam
+    strategy: Strategy
+    gemm_config: GEMMConfigT
+    measurement: Measurement
+
+
+# Representation of a single row of BenchmarkResult in CSV
+# NOTE: In the CSV representation, we merge all fields of Gemm Config into a single field "GEMMConfig", but keep the
+# fields of GEMMParam and Measurement
+# The example entry including header would look like:
+# M   , N , K  , B, Strategy         , GEMMConfig       , OpenCLTimer_MS
+# 1225, 32, 192, 1, Reshaped         , 4-4-4-3-1-1-1-0  , 0.3309
+BenchmarkResultCSVRow = namedtuple(
+    "BenchmarkResultCSVRow", GEMMParam._fields + ("Strategy", "GEMMConfig") + Measurement._fields
+)
+
+
+def benchmark_result_2_csv_row(result: BenchmarkResult) -> BenchmarkResultCSVRow:
+    """ Convert a BenchmarkResult into its CSV row form """
+    return BenchmarkResultCSVRow(
+        *(result.gemm_param + (result.strategy.name, str(result.gemm_config)) + result.measurement)
+    )
+
+
+class GEMMBenchmarkResultRecorder:
+    """ A recorder that records and organises GEMM Benchmark results, and produces various reports on the record.
+    """
+
+    SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])
+
+    def __init__(self):
+        """ Initializer
+        """
+        # Record that holds all recorded benchmark results.
+        # Indexed by (GEMMParam, Strategy) and each such pair maps to a deque of (GEMMConfig, Measurements),
+        # with the best one always at the front (index 0) of the deque
+        self._benchmark_result_record: Dict[
+            Tuple[GEMMParam, Strategy], Deque[Tuple[GEMMConfig, Measurements]]
+        ] = {}
+        # Strategies recorded
+        self._strategies = set()
+
+    def add(self, benchmark_result: BenchmarkResult):
+        """ Add a benchmark result to the record.
+        Keep the best gemm config at the front of the deque.
+        """
+        gemm_param, strategy, gemm_config, measurement = benchmark_result
+        # Update strategies encoutnered
+        self._strategies.add(strategy)
+        # Update the best configuration of the given gemm param
+        configs_with_measurements = self._benchmark_result_record.setdefault(
+            (gemm_param, strategy), deque([])
+        )
+        if len(configs_with_measurements) == 0:
+            configs_with_measurements.append((gemm_config, measurement))
+        else:
+            best_config, best_measurement = configs_with_measurements[0]
+            if measurement.is_better_than(best_measurement):
+                configs_with_measurements.appendleft((gemm_config, measurement))
+            else:
+                configs_with_measurements.append((gemm_config, measurement))
+
+    def get_config_distributions(self):
+        """ Return GEMMConfigDistribution for each strategy
+        """
+        gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
+            GEMMConfigDistribution
+        )
+        for benchmark_result in self.get_record(only_best_config=True):
+            gemm_param, strategy, gemm_config, measurement = benchmark_result
+            gemm_config_distributions[strategy].add(benchmark_result)
+        return gemm_config_distributions
+
+    def save_to_csvs(self, out_dir, only_best_config=True):
+        """ Save records to an output directory of csv files.
+        The directory is organized such that each strategy gets its own CSV file.
+        """
+        if not os.path.exists(out_dir):
+            logging.info("Output directory {} does not exist. Creating...".format(out_dir))
+            os.mkdir(out_dir)
+        for strategy in self._strategies:
+            out_csv_path = os.path.join(out_dir, strategy.name)
+            if os.path.exists(out_csv_path):
+                overwrite = (
+                    input(
+                        "Output CSV {} already exists. Overwrite? [Y/N]: ".format(out_csv_path)
+                    ).lower()
+                    == "y"
+                )
+                if not overwrite:
+                    logging.info("Skipping {}".format(out_csv_path))
+                    continue
+            logging.info("Saving csv file to {}".format(out_csv_path))
+            with open(out_csv_path, "w") as f:
+                csv_writer = csv.DictWriter(f, fieldnames=BenchmarkResultCSVRow._fields)
+                csv_writer.writeheader()
+                csv_writer.writerows(
+                    benchmark_result_2_csv_row(res)._asdict()
+                    for res in self.get_record(only_best_config)
+                    if res.strategy == strategy
+                )
+            logging.info("Saved")
+
+    def summary(self, sum_level=SummaryLevel.Short):
+        """ Return the summary string of the record
+        """
+        num_raw_records = sum(1 for _ in self.get_record(only_best_config=False))
+        gemm_params_per_strategy = defaultdict(list)
+        for gemm_param, strategy, _, _ in self.get_record(only_best_config=True):
+            gemm_params_per_strategy[strategy].append(gemm_param)
+        global_summary = f"""
+=== {self.__class__.__name__} Summary ===
+[Global]
+Strategies recorded: {", ".join(map(lambda s: s.name, self._strategies))}
+Total number of results recorded: {num_raw_records}
+
+[Per strategy]
+        """
+        strategy_summaries = []
+        for strategy in gemm_params_per_strategy:
+            summary = f"""
+Strategy {strategy.name}:
+GEMM parameters:
+    Number of: {len(gemm_params_per_strategy[strategy])}
+            """
+            if sum_level == self.__class__.SummaryLevel.Detailed:
+                summary += f"""
+    Content: {gemm_params_per_strategy[strategy]}
+                """
+            strategy_summaries.append(summary)
+        return global_summary + "".join(strategy_summaries)
+
+    def get_record(self, only_best_config=True) -> Generator[BenchmarkResult, None, None]:
+        """ Return an iterator that iterates over the record.
+        """
+        for (
+            (gemm_param, strategy),
+            configs_with_measurements,
+        ) in self._benchmark_result_record.items():
+            if only_best_config:
+                best_gemm_config, best_measurement = configs_with_measurements[0]
+                yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement)
+            else:
+                for gemm_config, measurement in configs_with_measurements:
+                    yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
+
+
+class GEMMConfigDistribution:
+    """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
+    """
+
+    def __init__(self):
+        """ Initializer
+        """
+        self._gemm_config_dist: Dict[GEMMConfig, List[Tuple[GEMMParam, Measurement]]] = defaultdict(
+            list
+        )
+        self._gemm_config_freq = Counter()
+
+    def add(self, benchmark_result: BenchmarkResult):
+        """ Add a benchmark result to the distribution
+        """
+        gemm_param, _, gemm_config, measurement = benchmark_result
+        self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
+        self._gemm_config_freq[gemm_config] += 1
+
+    def get_measurement(self, gemm_config, measure=min):
+        """ Get measurement of a gemm_config
+        """
+        return measure(list(zip(*self._gemm_config_dist[gemm_config]))[1])
+
+    def distribution(self):
+        return self._gemm_config_dist
+
+    def frequency(self):
+        """ Get the frequency of each (best) gemm config recorded
+        """
+        return self._gemm_config_freq.copy()
+
+    def best_config(self):
+        """ Get the overall best config, as voted by all benchmark results.
+        """
+        return self._gemm_config_freq.most_common(1)
+
+    def std(self):
+        """ Get the standard deviation as a measure of dispersion of the distribution. We should aim for higher values
+        as they indicate there is high variation in the distribution. Thus the evidence of the best config is stronger.
+        """
+        freqs = self._gemm_config_freq.values()
+        if len(freqs) == 0:
+            return 0
+        mean_freq = sum(freqs) / len(freqs)
+        return math.sqrt(sum((freq - mean_freq) ** 2 for freq in freqs) / len(freqs))
+
+
+################################################################################
+# Globals
+################################################################################
+
+# Gemm config type factory
+# Produces a GEMMConfig type specific to a Strategy
+GEMM_CONFIG_FACTORY = {
+    Strategy.Native: NativeGEMMConfig,
+    Strategy.ReshapedOnlyRHS: ReshapedOnlyRHSGEMMConfig,
+    Strategy.Reshaped: ReshapedGEMMConfig,
+}
+
+# Mapping from example binary name to Strategy
+# Assume 1-to-1 mapping
+EXAMPLE_FILE_2_STRATEGY = {
+    "benchmark_cl_gemm_native": Strategy.Native,
+    "benchmark_cl_gemm_reshaped_rhs_only": Strategy.ReshapedOnlyRHS,
+    "benchmark_cl_gemm_reshaped": Strategy.Reshaped,
+}
+
+# Gemm example arguments type factory
+# Produces a Gemm_Example_Args type specific to a Strategy
+# Gemm example arguments consist of:
+#           GEMMParam + GEMMConfig
+#   in that order.
+# For example, the example args of running a reshaped rhs only example could be:
+#   100,100,100,1, 4, 4, 4, 1,             1,            1
+#   M  ,N  ,K,  B,m0,n0,k0,h0,interleave_rhs,transpose_rhs
+#   <-GEMMParam-><-------------GEMMConfig-------------->
+# Note that the test strategy_name == strategy.name is in place to avoid unwanted enum aliases
+GEMM_EXAMPLE_ARGS_FACTORY = {
+    strategy: namedtuple(
+        "{}_Gemm_Example_Args".format(strategy_name),
+        GEMMParam._fields + GEMM_CONFIG_FACTORY[strategy]._fields,
+    )
+    for strategy_name, strategy in Strategy.__members__.items()
+    if strategy_name == strategy.name
+}
+
+# File extension used for benchmark result json files
+BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
+
+################################################################################
+# Functions
+################################################################################
+
+
+def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
+    """ Parse the benchmark example command-line string into a dictionary of command-line agruments
+    """
+    args = commandline.split()
+    # Discard program name
+    args = args[1:]
+    # Split into a list of (argument name, argument value)
+    args = map(lambda arg: arg.split("="), args)
+
+    def transform(_name):
+        # Strip '-'/"--" if it exists
+        _name = _name.lstrip("-")
+        return _name
+
+    return {transform(name): val for name, val in args}
+
+
+def extract_benchmark_results(json_results: Dict) -> Generator[BenchmarkResult, None, None]:
+    """ Parse the benchmark result and extract relevant information, namely:
+        GEMM param,
+        Strategy,
+        GEMM config,
+        Measurements
+    """
+    for json_res in json_results:
+        # Get example test and test data.
+        # There should only be 1 test per run
+        example_tests = list(json_res["tests"].items())
+        assert len(example_tests) == 1
+        example_fn, example_test_data = example_tests[0]
+
+        # Process example file name
+        example_fn = example_fn.split(os.path.sep)[-1]
+
+        # Get strategy
+        strategy = EXAMPLE_FILE_2_STRATEGY[example_fn]
+
+        # Get gemm params + gemm configs from example args
+        benchmark_args = parse_benchmark_commandline(json_res["CommandLine"])
+        Gemm_Example_Args_T = GEMM_EXAMPLE_ARGS_FACTORY[strategy]
+        example_args = Gemm_Example_Args_T(*(benchmark_args["example_args"].split(",")))
+        # Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
+        gemm_param_fields_len = len(GEMMParam._fields)
+        gemm_param = GEMMParam.parse_from_strs(*example_args[:gemm_param_fields_len])
+        GEMMConfig = GEMM_CONFIG_FACTORY[strategy]
+        gemm_config = GEMMConfig.parse_from_strs(*example_args[gemm_param_fields_len:])
+
+        # Get OpenCL_Time_Ms stats
+        measurements = list(example_test_data["measurements"].items())
+        # There should only be 1 instrument per run
+        assert len(measurements) == 1
+        measurement_instrument, data = measurements.pop()
+        # Get instrument name and assert that it is the one we expect
+        measurement_instrument_name = measurement_instrument.split("/")[0]
+        assert measurement_instrument_name == "OpenCLTimer"
+        # Take the MINIMUM of the raw data as the measurement value
+        measurement_val = min(data["raw"])
+        measurement = Measurement(measurement_val)
+
+        yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
+
+
+def parse_json(dir_name):
+    """ Glob all benchmark result json files and parse them into json objects (dicts).
+    """
+    for res_fn in Path(dir_name).rglob("*.{}".format(BENCHMARK_RESULT_JSON_EXTENSION)):
+        with open(res_fn) as res_fp:
+            yield json.load(res_fp)
+
+
+################################################################################
+# Main
+################################################################################
+
+
+def main(args):
+    logging.info("Searching best gemm configurations from {}".format(args.benchmark_results_dir))
+
+    benchmark_results = extract_benchmark_results(parse_json(args.benchmark_results_dir))
+
+    # Add all benchmark results to the recorder
+    benchmark_result_recorder = GEMMBenchmarkResultRecorder()
+    for benchmark_result in benchmark_results:
+        benchmark_result_recorder.add(benchmark_result)
+
+    if args.debug:
+        recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Detailed
+    else:
+        recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short
+
+    # Print overall summary of the recorded results
+    print(benchmark_result_recorder.summary(sum_level=recorder_sum_level))
+
+    # Get GEMM configuration distributions for each strategy
+    all_config_dists = benchmark_result_recorder.get_config_distributions()
+
+    print("=== Result ===")
+    for strategy, config_dist in all_config_dists.items():
+        print("Strategy: {}".format(strategy.name))
+        print("GEMM Config votes: ")
+        print("GEMM Config: Best measurement, Vote")
+        for config, freq in config_dist.frequency().items():
+            print(config, end=": ")
+            print(config_dist.get_measurement(config), freq, sep=",")
+        print(
+            "Best GEMM Config: {} with std: {}".format(config_dist.best_config(), config_dist.std())
+        )
+
+    # Save the recorded results to csv files in output directory
+    if args.output_dir is not None:
+        benchmark_result_recorder.save_to_csvs(args.output_dir, only_best_config=(not args.debug))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="CL GEMM Tuner")
+    parser.add_argument(
+        "-b",
+        "--benchmark_results",
+        dest="benchmark_results_dir",
+        metavar="PATH",
+        action="store",
+        type=str,
+        help="Path to benchmark result directory, where benchmark result json files have a file \
+                                extension of '{}'".format(
+            BENCHMARK_RESULT_JSON_EXTENSION
+        ),
+        required=True,
+    )
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        dest="output_dir",
+        metavar="PATH",
+        action="store",
+        type=str,
+        help="Path to directory that holds output csv files. One per strategy",
+    )
+    parser.add_argument(
+        "-D", "--debug", dest="debug", action="store_true", help="Enable script debugging output"
+    )
+    args = parser.parse_args()
+    logging_level = logging.DEBUG if args.debug else logging.INFO
+    logging.basicConfig(level=logging_level)
+    logging.debug("Arguments: {}".format(args))
+    main(args)
diff --git a/examples/gemm_tuner/README.md b/examples/gemm_tuner/README.md
index 789dc2b..3d5b637 100644
--- a/examples/gemm_tuner/README.md
+++ b/examples/gemm_tuner/README.md
@@ -2,10 +2,27 @@
 
 ## Pre-requisite
 (Preferably) bash shell
-benchmark examples
+Built benchmark examples
+python >= 3.6
 
 ## Usage
-Run gemm examples of a selected strategy, over all pre-defined tunable configurations, on a set of gemm shapes provided
-by the user. Save the benchmark results to json files in an output directory.
+The tuning consists of 2 steps:
 
-[$SHELL] ./benchmark_gemm_examples.sh -e \<example_binary_dir\> -s \<strategy\> -g \<gemm_shape_file\> -c \<gemm_config_file\> [-o \<out_dir\>, [-i \<iteration\>]]
+1. Run benchmarks: Run the runner shell script (benchmark_gemm_examples.sh) on
+your target device. Note that all the built benchmark examples have to be
+present on your target device prior to running. The script will run the selected
+strategy, over all pre-defined tunable configurations, on a set of gemm shapes
+provided by the user, and then save the benchmark results to json files in an
+output directory.
+
+[$SHELL] ./benchmark_gemm_examples.sh -s \<strategy\> -e \<example_binary_dir\>
+    -g \<gemm_shape_file\> -c \<gemm_config_file\> [-o \<out_dir\>]
+
+2. Run analyser: Run the python script (GemmTuner.py) on your host device.
+You'll need to transfer all the benchmark result json files generated from the
+previous step to your host machine beforehand. Note that this requires
+python >= 3.6. The script will output the best configuration, along with some
+analysis statistics for each strategy, and optionally save the parsed benchmark
+results into csv files (one for each strategy) for further analysis.
+
+python GemmTuner.py -b \<benchmark_results_dir\> [-o \<out_dir\>]
diff --git a/examples/gemm_tuner/benchmark_gemm_examples.sh b/examples/gemm_tuner/benchmark_gemm_examples.sh
index fd5f71d..95bb367 100755
--- a/examples/gemm_tuner/benchmark_gemm_examples.sh
+++ b/examples/gemm_tuner/benchmark_gemm_examples.sh
@@ -299,9 +299,10 @@
 }
 
 #######################################
-# Run all tunable configurations and all input configurations
+# Run a single example with all tunable gemm configurations on all gemm parameters
 # Globals:
 #   OUT_DIR
+#   OUT_EXTENSION
 #   EXAMPLE_BIN_DIR
 #   NUM_ITERATION
 #   GEMM_CONFIGS_FILE
@@ -315,20 +316,73 @@
   local example_bin=$1
   echo "Running all configs for ${example_bin}" 1>&2
   local example_args
-  local test_id=1
+  local expr_count=1
+  # Total number of experiment runs scheduled for this session
+  local total_num_experiment
+  local num_params
+  local num_configs
+  num_params=$( wc -l ${GEMM_SHAPES_FILE} | cut -d " " -f 1)
+  num_configs=$( wc -l ${GEMM_CONFIGS_FILE} | cut -d " " -f 1 )
+  (( total_num_experiment=${num_params} * ${num_configs} ))
+  # Time elapsed since the beginning in seconds
+  local time_elapsed_s
+  # Time estimated to finish in seconds
+  local time_est_s
+  echo "Running a total number of ${total_num_experiment} experiments" 1>&2
+
   while read gemm_shape
   do
     while read gemm_config
     do
+      echo "Running..." 1>&2
       example_args="${gemm_shape},${gemm_config}"
-      ${EXAMPLE_BIN_DIR}/${example_bin} --example_args=${example_args} --iterations=${NUM_ITERATION} --json-file=${OUT_DIR}/${test_id} --instruments=OPENCL_TIMER_MS
-      (( test_id++ ))
+      # Run experiment
+      ${EXAMPLE_BIN_DIR}/${example_bin} --example_args=${example_args} --iterations=${NUM_ITERATION} --json-file=${OUT_DIR}/${expr_count}.${OUT_EXTENSION} --instruments=OPENCL_TIMER_MS
+      # Print progress
+      print_progress ${expr_count} ${total_num_experiment}
+      # Print time statistics
+      time_elapsed_s=$SECONDS
+      echo "Time elapsed since beginning: $(( $time_elapsed_s / 60 ))m $(( $time_elapsed_s % 60 ))s" 1>&2
+      (( time_est_s=(${total_num_experiment} - ${expr_count}) * ${time_elapsed_s} / ${expr_count} ))
+      echo "Time estimated to finish: $(( $time_est_s / 60 ))m $(( $time_est_s % 60 ))s" 1>&2
+      (( expr_count++ ))
+      echo "Done." 1>&2
     done < "${GEMM_CONFIGS_FILE}"
   done < "${GEMM_SHAPES_FILE}"
   echo "Finished running all configs for ${example_bin}" 1>&2
   echo "All results saved to ${OUT_DIR}" 1>&2
 }
 
+#######################################
+# Print the progress of the current session
+# Globals:
+#   None
+# Arguments:
+#   current   Current number of items
+#   total     Total number of items
+# Returns:
+#   None
+#######################################
+function print_progress() {
+  local current
+  local total
+  current=$1
+  total=$2
+  # Width of progress bar
+  local width
+  width=20
+  (( current_width= $width * current / total ))
+  echo -n -e "Progress [" 1>&2
+  for i in $(seq 1 ${width}); do
+    if [[ $i -le ${current_width} ]]; then
+      echo -n "#" 1>&2
+    else
+      echo -n " " 1>&2
+    fi
+  done
+  echo  "] $current / $total Experiments" 1>&2
+}
+
 # Functions }}}
 
 # Main: Main script {{{
@@ -341,6 +395,8 @@
 STRATEGY_OPTION=""
 # Path to output directory
 OUT_DIR=${DEFAULT_OUT_DIR}
+# Output benchmark result file extension
+OUT_EXTENSION="gemmtuner_benchmark"
 # Toggle help
 HELP=false
 
@@ -394,6 +450,8 @@
 mkdir ${OUT_DIR}
 
 # Run selected strategy with all configurations
+# Restart the built-in timer
+SECONDS=0
 [ "${STRATEGY_OPTION}" == "native" ] && run $EXAMPLE_BIN_NATIVE
 [ "${STRATEGY_OPTION}" == "reshaped_rhs_only" ] && run $EXAMPLE_BIN_RESHAPED_RHS_ONLY
 [ "${STRATEGY_OPTION}" == "reshaped" ] && run $EXAMPLE_BIN_RESHAPED
diff --git a/examples/gemm_tuner/cl_gemm_reshaped.cpp b/examples/gemm_tuner/cl_gemm_reshaped.cpp
index 6445592..e579ed7 100644
--- a/examples/gemm_tuner/cl_gemm_reshaped.cpp
+++ b/examples/gemm_tuner/cl_gemm_reshaped.cpp
@@ -27,6 +27,7 @@
 
 #include "CommonGemmExampleOptions.h"
 #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
+#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/KernelDescriptors.h"
 #include "arm_compute/core/Types.h"
@@ -165,6 +166,8 @@
 }
 
 } // namespace
+// Create function for CLGEMMReshapeLHSMatrixKernel
+using CLGEMMReshapeLHSMatrix = test::CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
 // Create function for CLGEMMMatrixMultiplyReshapedKernel
 using CLGEMMMatrixMultiplyReshaped = test::CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
 
@@ -249,6 +252,9 @@
         // Initialise rhs_reshaped tensor info
         auto_init_if_empty(*rhs_reshaped.info(), rhs.info()->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*rhs.info(), rhs_info)));
 
+        // Configure reshape lhs function
+        reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
+
         // Configure function
         gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
 
@@ -265,6 +271,7 @@
     void do_run() override
     {
         // Execute the function
+        reshape_lhs.run();
         gemm.run();
 
         // Make sure all the OpenCL jobs are done executing:
@@ -283,6 +290,7 @@
     CLTensor                     bias{};
     CLTensor                     dst{};
     CLTuner                      tuner{};
+    CLGEMMReshapeLHSMatrix       reshape_lhs{};
     CLGEMMMatrixMultiplyReshaped gemm{};
 };