# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import re
import sys

from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
from serializer.tosa_serializer import DTypeNames

# Command-line option that takes a "Min,Max" floating-point value range.
# It is a shared constant because parseArgs() both registers the option and
# pre-processes argv for it (its values may start with a hyphen, e.g. "-2.0,2.0").
OPTION_FP_VALUES_RANGE = "--fp-values-range"


def str_to_list(in_s, is_float=False):
    """Convert a comma-separated string of numbers to a Python list.

    Used as an argparse ``type`` converter, e.g. "1,64" -> [1, 64].

    Args:
        in_s: comma-separated numbers, e.g. "1,4,4,8" or "-2.0,2.0".
        is_float: when True, each element is parsed with ``float``;
            otherwise with ``int``.

    Returns:
        A list of ints (default) or floats (``is_float=True``), in input order.

    Raises:
        ValueError: if any element cannot be parsed by the chosen converter
            (argparse turns this into a usage error).
    """
    # Pick the converter once rather than re-deciding for every element.
    convert = float if is_float else int
    return [convert(item) for item in in_s.split(",")]


def auto_int(x):
    """Parse an integer literal, inferring its base from the prefix.

    Follows Python's ``int(text, 0)`` rules: plain decimal ("42") and the
    prefixed forms "0x..." (hex), "0o..." (octal) and "0b..." (binary) are
    accepted.
    """
    infer_base = 0
    return int(x, base=infer_base)


def _join_fp_values_range_arg(argv):
    """Return a copy of ``argv`` with the fp-values-range value glued to its option.

    argparse treats an argument that begins with "-" as an option flag unless
    it looks like a plain negative number, so a value such as "-2.0,2.0"
    (negative and comma-separated) would be rejected. Rewriting "ARG VAL" as
    "ARG=VAL" lets argparse accept it.
    Example: --fp-values-range -2.0,2.0 -> --fp-values-range=-2.0,2.0
    """
    new_argv = []
    idx = 0
    while idx < len(argv):
        arg = argv[idx]
        if arg == OPTION_FP_VALUES_RANGE and idx + 1 < len(argv):
            val = argv[idx + 1]
            if val.startswith("-"):
                arg = f"{arg}={val}"
                # The value has been merged into ``arg``; skip it.
                idx += 1
        new_argv.append(arg)
        idx += 1
    return new_argv


def parseArgs(argv):
    """Parse the command line arguments.

    Args:
        argv: list of argument strings; when None, ``sys.argv[1:]`` is used.

    Returns:
        argparse.Namespace holding the parsed test-generation options.
    """
    if argv is None:
        argv = sys.argv[1:]

    if OPTION_FP_VALUES_RANGE in argv:
        # Argparse fix for hyphen (minus values) in argument values
        argv = _join_fp_values_range_arg(argv)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
    )

    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=42,
        type=int,
        help="Random seed for test generation",
    )

    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter operator test names by this expression",
    )

    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
    )

    # Constraints on tests
    # Converters are passed directly (not wrapped in lambdas) so argparse
    # error messages name them, e.g. "invalid str_to_list value".
    parser.add_argument(
        "--tensor-dim-range",
        dest="tensor_shape_range",
        default="1,64",
        type=str_to_list,
        help="Min,Max range of tensor shapes",
    )

    parser.add_argument(
        OPTION_FP_VALUES_RANGE,
        dest="tensor_fp_value_range",
        default="0.0,1.0",
        type=lambda x: str_to_list(x, is_float=True),
        help="Min,Max range of floating point tensor values",
    )

    parser.add_argument(
        "--max-batch-size",
        dest="max_batch_size",
        default=1,
        type=int,
        help="Maximum batch size for NHWC tests",
    )

    parser.add_argument(
        "--max-conv-padding",
        dest="max_conv_padding",
        default=1,
        type=int,
        help="Maximum padding for Conv tests",
    )

    parser.add_argument(
        "--max-conv-dilation",
        dest="max_conv_dilation",
        default=2,
        type=int,
        help="Maximum dilation for Conv tests",
    )

    parser.add_argument(
        "--max-conv-stride",
        dest="max_conv_stride",
        default=2,
        type=int,
        help="Maximum stride for Conv tests",
    )

    parser.add_argument(
        "--max-pooling-padding",
        dest="max_pooling_padding",
        default=1,
        type=int,
        help="Maximum padding for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-stride",
        dest="max_pooling_stride",
        default=2,
        type=int,
        help="Maximum stride for pooling tests",
    )

    parser.add_argument(
        "--max-pooling-kernel",
        dest="max_pooling_kernel",
        default=3,
        type=int,
        help="Maximum kernel for pooling tests",
    )

    parser.add_argument(
        "--num-rand-permutations",
        dest="num_rand_permutations",
        default=6,
        type=int,
        help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
    )

    parser.add_argument(
        "--max-resize-output-dim",
        dest="max_resize_output_dim",
        default=1000,
        type=int,
        help="Upper limit on width and height output dimensions for `resize` op. Default: 1000",
    )

    # Targeting a specific shape/rank/dtype
    parser.add_argument(
        "--target-shape",
        dest="target_shapes",
        action="append",
        default=[],
        type=str_to_list,
        help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
    )

    parser.add_argument(
        "--target-rank",
        dest="target_ranks",
        action="append",
        default=None,
        type=auto_int,
        help="Create tests with a particular input tensor rank",
    )

    parser.add_argument(
        "--target-dtype",
        dest="target_dtypes",
        action="append",
        default=None,
        type=dtype_str_to_val,
        help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
    )

    parser.add_argument(
        "--num-const-inputs-concat",
        dest="num_const_inputs_concat",
        default=0,
        choices=[0, 1, 2, 3],
        type=int,
        help="Allow constant input tensors for concat operator",
    )

    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="positive",
        type=str,
        help="type of tests produced, positive, negative, or both",
    )

    parser.add_argument(
        "--allow-pooling-and-conv-oversizes",
        dest="oversize",
        action="store_true",
        help="allow oversize padding, stride and kernel tests",
    )

    parser.add_argument(
        "--zero-point",
        dest="zeropoint",
        default=None,
        type=int,
        help="set a particular zero point for all valid positive tests",
    )

    parser.add_argument(
        "--dump-const-tensors",
        dest="dump_consts",
        action="store_true",
        help="output const tensors as numpy files for inspection",
    )

    return parser.parse_args(argv)


def main(argv=None):
    """Generate and serialize the TOSA tests selected on the command line.

    For each requested test type ("positive", "negative", or both for
    ``--test-type both``), collects the tests of every operator in
    ``TosaTestGen.TOSA_OP_LIST`` whose name matches ``--filter`` (``re.match``,
    i.e. anchored at the start of the name). Within a test type, tests with a
    repeated name are skipped; the rest are serialized.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (optional).
    """
    args = parseArgs(argv)

    ttg = TosaTestGen(args)

    if args.test_type == "both":
        testTypes = ["positive", "negative"]
    else:
        testTypes = [args.test_type]

    # Compile the operator-name filter once instead of once per operator.
    op_filter = re.compile(args.filter + ".*")

    results = []
    for test_type in testTypes:
        testList = []
        for op in ttg.TOSA_OP_LIST:
            if op_filter.match(op):
                testList.extend(
                    ttg.genOpTestList(
                        op,
                        shapeFilter=args.target_shapes,
                        rankFilter=args.target_ranks,
                        dtypeFilter=args.target_dtypes,
                        testType=test_type,
                    )
                )

        print(f"{len(testList)} matching {test_type} tests")

        # A set keeps the duplicate check O(1); a list made this loop
        # quadratic in the number of generated tests.
        seenTestStrs = set()
        for opName, testStr, dtype, error, shapeList, testArgs in testList:
            # Check for and skip duplicate tests
            if testStr in seenTestStrs:
                print(f"Skipping duplicate test: {testStr}")
                continue
            seenTestStrs.add(testStr)

            results.append(
                ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
            )

    print(f"Done creating {len(results)} tests")


if __name__ == "__main__":
    # Use sys.exit: the bare exit() builtin is injected by the site module for
    # interactive use and is unavailable when Python runs with -S.
    sys.exit(main())
