blob: fe177eb87a4a7278a2b36ec70772a6229b2e67e2 [file] [log] [blame]
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
from mlia.backend.manager import get_available_backends
from mlia.core.common import AdviceCategory
from mlia.core.typing import OutputFormat
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
# Sparsity used for pruning when no explicit pruning target is provided.
DEFAULT_PRUNING_TARGET = 0.5
# Number of clusters used for clustering when no explicit clustering target
# is provided.
DEFAULT_CLUSTERING_TARGET = 32
def add_check_category_options(parser: argparse.ArgumentParser) -> None:
    """Register the boolean flags selecting which checks to perform."""
    # Each entry is (option string, help text); both options are plain
    # switches registered in this order.
    check_flags = (
        ("--performance", "Perform performance checks."),
        ("--compatibility", "Perform compatibility checks. (default)"),
    )
    for flag, description in check_flags:
        parser.add_argument(flag, action="store_true", help=description)
def add_target_options(
    parser: argparse.ArgumentParser,
    supported_advice: Sequence[AdviceCategory] | None = None,
    required: bool = True,
) -> None:
    """Add the target profile option to the parser.

    The help text of the option lists the built-in target profiles. When
    ``supported_advice`` is given, only the profiles that support at least
    one of those advice categories are listed.
    """

    def supports_advice(name: str, advice: Sequence[AdviceCategory]) -> bool:
        """
        Tell whether the target owning the profile supports the advice.

        The profile is attributed to the first registered target whose name
        is a prefix of the profile name, e.g. "ethos-u55...", to avoid
        loading all target profiles. It is accepted when that target
        supports any of the required advice.
        """
        for target_name, target_info in target_registry.items.items():
            if not name.startswith(target_name):
                continue
            return any(target_info.is_supported(category) for category in advice)
        return False

    profile_names = builtin_profile_names()
    if supported_advice:
        profile_names = [
            name for name in profile_names if supports_advice(name, supported_advice)
        ]

    listed_profiles = ", ".join(profile_names)
    help_text = (
        "Built-in target profile or path to the custom target profile. "
        f"Built-in target profiles are {listed_profiles}. "
        "Target profile that will set the target options "
        "such as target, mac value, memory mode, etc. "
        "For the values associated with each target profile "
        "please refer to the documentation. "
    )

    group = parser.add_argument_group("target options")
    group.add_argument(
        "-t",
        "--target-profile",
        required=required,
        help=help_text,
    )
def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
    """Add the options selecting and tuning the optimizations to apply."""
    group = parser.add_argument_group("optimization options")

    # Switches enabling each optimization.
    for flag, description in (
        ("--pruning", "Apply pruning optimization."),
        ("--clustering", "Apply clustering optimization."),
    ):
        group.add_argument(flag, action="store_true", help=description)

    # Optional targets; the defaults are only mentioned in the help text and
    # are not set on the arguments themselves.
    pruning_help = (
        "Sparsity to be reached during optimization "
        f"(default: {DEFAULT_PRUNING_TARGET})"
    )
    group.add_argument("--pruning-target", type=float, help=pruning_help)

    clustering_help = (
        "Number of clusters to reach during optimization "
        f"(default: {DEFAULT_CLUSTERING_TARGET})"
    )
    group.add_argument("--clustering-target", type=int, help=clustering_help)
def add_model_options(parser: argparse.ArgumentParser) -> None:
    """Register the positional ``model`` argument."""
    model_help = "TensorFlow Lite model or Keras model"
    parser.add_argument("model", help=model_help)
def add_output_options(parser: argparse.ArgumentParser) -> None:
    """Register the options controlling the format of the output."""
    group = parser.add_argument_group("output options")
    json_help = "Print the output in JSON format."
    group.add_argument("--json", action="store_true", help=json_help)
def add_debug_options(parser: argparse.ArgumentParser) -> None:
    """Register the ``-d``/``--debug`` switch in its own option group."""
    group = parser.add_argument_group("debug options")
    switch_settings = {
        "default": False,
        "action": "store_true",
        "help": "Produce verbose output",
    }
    group.add_argument("-d", "--debug", **switch_settings)
def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
    """Register the positional Keras ``model`` argument in its own group."""
    keras_group = parser.add_argument_group("Keras model options")
    keras_group.add_argument(
        "model",
        help="Keras model",
    )
def add_backend_install_options(parser: argparse.ArgumentParser) -> None:
    """Add the options of the backend installation command."""

    def valid_directory(param: str) -> Path:
        """Convert the string to a path, reporting an error if not a directory."""
        candidate = Path(param)
        if not candidate.is_dir():
            # Report the problem through the parser itself.
            parser.error(f"Invalid directory path {param}")
        return candidate

    parser.add_argument(
        "--path", type=valid_directory, help="Path to the installed backend"
    )

    # Boolean switches as (option string, help) pairs, registered in this
    # order. The EULA switch is hidden from the help output.
    switches = (
        ("--i-agree-to-the-contained-eula", argparse.SUPPRESS),
        ("--force", "Force reinstalling backend in the specified path"),
        (
            "--noninteractive",
            "Non interactive mode with automatic confirmation of every action",
        ),
    )
    for option, description in switches:
        parser.add_argument(
            option,
            default=False,
            action="store_true",
            help=description,
        )

    parser.add_argument("name", help="Name of the backend to install")
def add_backend_uninstall_options(parser: argparse.ArgumentParser) -> None:
    """Register the positional name of the backend to uninstall."""
    name_help = "Name of the installed backend"
    parser.add_argument("name", help=name_help)
def add_backend_options(
    parser: argparse.ArgumentParser, backends_to_skip: list[str] | None = None
) -> None:
    """Add the repeatable backend selection option.

    The accepted choices are the available backends minus those listed in
    ``backends_to_skip``.
    """

    def only_one_corstone_checker() -> Callable:
        """
        Build the converter used for each backend value.

        The converter returns the value unchanged and raises an exception
        when more than one Corstone backend is passed.
        """
        corstone_count = 0

        def check(backend: str) -> str:
            """Count Corstone backends and raise an exception if more than one."""
            nonlocal corstone_count
            if is_corstone_backend(backend):
                corstone_count += 1
            # NOTE: the counter belongs to this checker instance, so it is
            # shared by every value converted through it.
            if corstone_count > 1:
                raise argparse.ArgumentTypeError(
                    "There must be only one Corstone backend in the argument list."
                )
            return backend

        return check

    selectable_backends = get_available_backends()
    if backends_to_skip:
        # Drop the backends the caller does not want to offer.
        selectable_backends = [
            backend
            for backend in selectable_backends
            if backend not in backends_to_skip
        ]

    backend_group = parser.add_argument_group("backend options")
    backend_group.add_argument(
        "-b",
        "--backend",
        help="Backends to use for evaluation.",
        action="append",
        choices=selectable_backends,
        type=only_one_corstone_checker(),
    )
def add_output_directory(parser: argparse.ArgumentParser) -> None:
    """Register the ``--output-dir`` option, converted to a ``Path``."""
    description = (
        "Path to the directory where MLIA will create "
        "output directory 'mlia-output' "
        "for storing artifacts, e.g. logs, target profiles and model files. "
        "If not specified then 'mlia-output' directory will be created "
        "in the current working directory."
    )
    parser.add_argument("--output-dir", type=Path, help=description)
def parse_optimization_parameters(
    pruning: bool = False,
    clustering: bool = False,
    pruning_target: float | None = None,
    clustering_target: int | None = None,
    layers_to_optimize: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Parse provided optimization parameters.

    Build one settings dictionary per optimization to apply. Pruning is
    applied when it is requested or when no optimization is requested at
    all; clustering is applied only when it is requested.

    :param pruning: whether pruning was requested
    :param clustering: whether clustering was requested
    :param pruning_target: sparsity to reach, or None to use
        DEFAULT_PRUNING_TARGET
    :param clustering_target: number of clusters to reach, or None to use
        DEFAULT_CLUSTERING_TARGET
    :param layers_to_optimize: layer names passed through unchanged to every
        produced entry
    :return: list of dictionaries with the keys "optimization_type",
        "optimization_target" (converted to float) and "layers_to_optimize"
    :raises argparse.ArgumentError: if a clustering target is provided
        without clustering being requested
    """
    # Compare against None instead of relying on truthiness so that an
    # explicitly provided target of 0 is treated as provided rather than
    # being silently replaced by the default.
    if clustering_target is not None and not clustering:
        raise argparse.ArgumentError(
            None,
            "To enable clustering optimization you need to include the "
            "`--clustering` flag in your command.",
        )

    if pruning_target is None:
        pruning_target = DEFAULT_PRUNING_TARGET

    if clustering_target is None:
        clustering_target = DEFAULT_CLUSTERING_TARGET

    optimizations: list[tuple[str, float]] = []

    # Pruning is the fallback when no optimization was requested at all.
    if (pruning is False and clustering is False) or pruning:
        optimizations.append(("pruning", pruning_target))

    if clustering:
        optimizations.append(("clustering", clustering_target))

    return [
        {
            "optimization_type": opt_type,
            "optimization_target": float(opt_target),
            "layers_to_optimize": layers_to_optimize,
        }
        for opt_type, opt_target in optimizations
    ]
def get_target_profile_opts(target_args: dict | None) -> list[str]:
    """Return the CLI form of the target options that differ from defaults.

    The defaults are obtained by parsing an empty command line with a parser
    that only has the target options. Entries of ``target_args`` unknown to
    that parser, or equal to its default, are left out. List values produce
    one option/value pair per item.
    """
    if not target_args:
        return []

    reference_parser = argparse.ArgumentParser()
    add_target_options(reference_parser, required=False)
    defaults = vars(reference_parser.parse_args([]))

    # Map every destination name to an option string; when an option has
    # several spellings the one registered last is kept.
    registered = (
        reference_parser._option_string_actions  # pylint: disable=protected-access
    )
    option_for_dest = {}
    for option_string, action in registered.items():
        option_for_dest[action.dest] = option_string

    opts: list[str] = []
    for dest, value in target_args.items():
        if dest in defaults and defaults[dest] != value:
            option = option_for_dest[dest]
            values = value if isinstance(value, list) else [value]
            for single_value in values:
                opts.extend((str(option), str(single_value)))
    return opts
def get_output_format(args: argparse.Namespace) -> OutputFormat:
    """Select the output format from the parsed CLI arguments.

    JSON is chosen when the namespace has a truthy ``json`` attribute,
    plain text otherwise.
    """
    json_requested = "json" in args and args.json
    if json_requested:
        return "json"
    return "plain_text"