# Copyright (c) 2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import copy
import logging

# Configure the root logger with logging's defaults when this module is imported.
logging.basicConfig()
# Logger used throughout this module; it is registered under the fixed name
# "tosa_verif_build_tests" rather than this module's own name.
logger = logging.getLogger("tosa_verif_build_tests")


class Test:
    """A single generated test plus the keys used to select and group it.

    The permutation key (see setKey) and the group key (see setGroupKey) both
    start out as None and are only populated when those methods are called.
    """

    def __init__(
        self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
    ):
        self.opName = opName
        self.testStr = testStr
        self.dtype = dtype
        self.error = error
        self.shapeList = shapeList
        self.argsDict = argsDict
        # Name used to look the test up in TOSA_OP_LIST (for example
        # "conv2d_1x1"); falls back to opName when none is supplied
        if testOpName is None:
            testOpName = opName
        self.testOpName = testOpName

        self.key = None
        self.groupKey = None
        self.mark = False

    def __str__(self):
        return self.testStr

    def __lt__(self, other):
        # Order tests by their test string so that lists of tests can be sorted
        return self.testStr < str(other)

    def getArg(self, param):
        """Return the value of the named parameter for this test.

        "rank" and "dtype" are always derived from the shape list and dtype
        (a dtype list becomes a tuple so that it is hashable). "shape" is
        derived from the first shape unless argsDict supplies its own.
        Everything else is read from argsDict and flattened into a single
        line string; None is returned for unknown parameters.
        """
        if param == "rank":
            return len(self.shapeList[0])
        if param == "dtype":
            return tuple(self.dtype) if isinstance(self.dtype, list) else self.dtype
        if param == "shape" and "shape" not in self.argsDict:
            return str(self.shapeList[0])

        if param not in self.argsDict:
            return None
        # Join multi-line representations with commas to keep the value on one line
        text = str(self.argsDict[param])
        return ",".join(text.splitlines())

    def setKey(self, keyParams):
        """Set and return the main (permutation) key for this test.

        Error tests are keyed by their error; all other tests are keyed by
        the tuple of their values for keyParams.
        """
        if self.error is not None:
            self.key = self.error
        else:
            self.key = tuple(self.getArg(name) for name in keyParams)
        return self.key

    def getKey(self):
        return self.key

    def setGroupKey(self, groupParams):
        """Set and return the group key for this test.

        The key is built from every argument that is NOT one of groupParams
        (group parameters are things like the test set number), so tests
        that only differ in their group parameters share the same key.
        """
        candidates = sorted(["shape", "dtype", *self.argsDict])
        self.groupKey = tuple(
            self.getArg(name) for name in candidates if name not in groupParams
        )
        return self.groupKey

    def getGroupKey(self):
        return self.groupKey

    def inGroup(self, groupKey):
        return groupKey == self.groupKey

    def setMark(self):
        # Flag the test as important
        self.mark = True

    def getMark(self):
        return self.mark

    def isError(self):
        return self.error is not None


def _get_selection_info_from_op(op, selectionCriteria, item, default):
    # Look up op["selection"][selectionCriteria][item], returning the default
    # as soon as any level of that path is missing
    if "selection" not in op:
        return default

    selection = op["selection"]
    if selectionCriteria not in selection:
        return default

    criteria = selection[selectionCriteria]
    if item not in criteria:
        return default

    return criteria[item]


def _get_tests_by_group(tests):
    # Bucket the tests by their group key, remembering the order in which
    # each group key was first seen
    group_order = []
    tests_by_key = {}

    for candidate in tests:
        group_key = candidate.getGroupKey()
        if group_key not in tests_by_key:
            # First test of a new group
            group_order.append(group_key)
            tests_by_key[group_key] = []
        tests_by_key[group_key].append(candidate)

    # Hand back the list of group keys along with the dictionary that maps
    # each group key to the list of tests belonging to it
    return group_order, tests_by_key


def _get_specific_op_info(opName, opSelectionInfo, testOpName):
    # Fetch the op's section of the selection config, preferring an entry for
    # opName and falling back to one for testOpName
    for candidate in (opName, testOpName):
        if candidate in opSelectionInfo:
            return opSelectionInfo[candidate]

    # Neither name has an entry, so report it and use an empty section
    logger.info(f"No op entry found for {opName} in test selection config")
    return {}


class TestOpList:
    """All the tests for one op grouped by permutations.

    Tests are registered with add(), which files each one under a permutation
    key (built from the "permutes" selection arguments, or the error for error
    tests) and gives it a group key. select() then chooses a subset of the
    registered tests for each permutation according to the op's selection
    configuration, while all() returns every registered test.
    """

    def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
        """Read the selection settings for the op and set up empty records.

        opName: op name; every added test must have an opName starting with it.
        opSelectionInfo: selection config containing an entry per op name.
        selectionCriteria: name of the section of the op's "selection" entry
            that the settings are read from.
        testOpName: alternative name tried in the config when opName has no
            entry of its own.
        """
        self.opName = opName
        self.testOpName = testOpName
        op = _get_specific_op_info(opName, opSelectionInfo, testOpName)

        # See verif/conformance/README.md for more information on
        # these selection arguments
        self.permuteArgs = _get_selection_info_from_op(
            op, selectionCriteria, "permutes", ["rank", "dtype"]
        )
        self.paramArgs = _get_selection_info_from_op(
            op, selectionCriteria, "full_params", []
        )
        self.specificArgs = _get_selection_info_from_op(
            op, selectionCriteria, "specifics", {}
        )
        self.groupArgs = _get_selection_info_from_op(
            op, selectionCriteria, "groups", ["s"]
        )
        self.maximumPerPermute = _get_selection_info_from_op(
            op, selectionCriteria, "maximum", None
        )
        self.numErrorIfs = _get_selection_info_from_op(
            op, selectionCriteria, "num_errorifs", 1
        )
        self.selectAll = _get_selection_info_from_op(
            op, selectionCriteria, "all", False
        )

        # "maximum" is None when it is not configured and None cannot be
        # ordered against an int, so only compare when a maximum is present
        if (
            self.paramArgs
            and self.maximumPerPermute is not None
            and self.maximumPerPermute > 1
        ):
            logger.warning(f"Unsupported - selection params AND maximum for {opName}")

        # Every unique test added, in the order they were added
        self.tests = []
        self.testStrings = set()
        self.shapes = set()

        # Permutation keys seen so far and the per-permutation records
        self.permutes = set()
        self.testsPerPermute = {}
        self.paramsPerPermute = {}
        self.specificsPerPermute = {}

        self.selectionDone = False

    def __len__(self):
        """Return the number of unique tests added."""
        return len(self.tests)

    def add(self, test):
        """Add a test to this op and set up its permutation and group keys.

        A test whose string matches one that was already added is skipped.
        """
        assert test.opName.startswith(self.opName)
        if str(test) in self.testStrings:
            logger.debug(f"Skipping duplicate test: {str(test)}")
            return

        self.tests.append(test)
        self.testStrings.add(str(test))

        self.shapes.add(test.getArg("shape"))

        # Work out the permutation key for this test
        permute = test.setKey(self.permuteArgs)
        # Set up the group key for the test (for pulling out groups during selection)
        test.setGroupKey(self.groupArgs)

        if permute not in self.permutes:
            # New permutation
            self.permutes.add(permute)
            # Set up area to record the selected tests
            self.testsPerPermute[permute] = []
            if self.paramArgs:
                # Set up area to record the unique test params found
                self.paramsPerPermute[permute] = {
                    param: set() for param in self.paramArgs
                }
            # Each permutation gets its own copy of the specific test args as
            # matched values are removed from it during selection
            self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)

    def _init_select(self):
        """Work out which of the added tests are selected for each permutation.

        Can only be performed once as it alters the per-permutation
        information set up by add().
        """
        assert not self.selectionDone

        # Count of non-specific tests added to each permute (not error)
        countPerPermute = dict.fromkeys(self.permutes, 0)

        # Go through each test looking for permutes, unique params & specifics
        for test in self.tests:
            permute = test.getKey()
            append = False
            possible_append = False

            if test.isError():
                # Error test, choose up to number of tests
                if len(self.testsPerPermute[permute]) < self.numErrorIfs:
                    append = True
            elif self.selectAll:
                append = True
            else:
                # See if this is a specific test to add
                for param, values in self.specificsPerPermute[permute].items():
                    arg = test.getArg(param)
                    if arg in values:
                        # Found a match, remove it, so we don't look for it later
                        values.remove(arg)
                        # Mark the test as special (and so shouldn't be removed)
                        test.setMark()
                        append = True

                if self.paramArgs:
                    # See if this test contains any new params we should keep
                    # Perform this check even if we have already selected the test
                    # so we can record the params found
                    for param in self.paramArgs:
                        arg = test.getArg(param)
                        if arg not in self.paramsPerPermute[permute][param]:
                            # We have found a new value for this arg, record it
                            self.paramsPerPermute[permute][param].add(arg)
                            possible_append = True
                else:
                    # No params set, so possible test to add up to maximum
                    possible_append = True

                if (not append and possible_append) and (
                    self.maximumPerPermute is None
                    or countPerPermute[permute] < self.maximumPerPermute
                ):
                    # Not selected but could be added and we have space left if
                    # a maximum is set.
                    append = True
                    countPerPermute[permute] += 1

            # Check for grouping with chosen tests
            if not append:
                # We will keep any tests together that form a group
                key = test.getGroupKey()
                # Deliberately check every chosen test (no early exit) so that
                # a mark on any member of the group is passed on to this test
                for t in self.testsPerPermute[permute]:
                    if t.getGroupKey() == key:
                        if t.getMark():
                            test.setMark()
                        append = True

            if append:
                self.testsPerPermute[permute].append(test)

        self.selectionDone = True

    def select(self, rng=None):
        """Return the selected tests, permutation by permutation.

        The selection is worked out on the first call only; if rng is given
        then, the list of tests is shuffled in place with rng.shuffle before
        the selection is made.
        """
        if not self.selectionDone:
            if rng:
                rng.shuffle(self.tests)

            self._init_select()

        # Now create the full list of selected tests per permute
        selection = []

        for tests in self.testsPerPermute.values():
            selection.extend(tests)

        return selection

    def all(self):
        """Return the un-selected list of tests - i.e. all of them."""
        return self.tests


class TestList:
    """Every test, held in one TestOpList per operator name."""

    def __init__(self, opSelectionInfo, selectionCriteria="default"):
        # Maps each test's opName to the TestOpList holding its tests
        self.opLists = {}
        self.opSelectionInfo = opSelectionInfo
        self.selectionCriteria = selectionCriteria

    def __len__(self):
        # Total number of tests across all of the op lists
        return sum(len(opList) for opList in self.opLists.values())

    def add(self, test):
        # File the test under its op, creating the op's list on first use
        name = test.opName
        opList = self.opLists.get(name)
        if opList is None:
            opList = TestOpList(
                name,
                self.opSelectionInfo,
                self.selectionCriteria,
                test.testOpName,
            )
            self.opLists[name] = opList
        opList.add(test)

    def _get_tests(self, selectMode, rng):
        # Gather either the selected tests or all of the tests from every op
        # list; only a selection is returned sorted
        gathered = []
        for opList in self.opLists.values():
            gathered += opList.select(rng=rng) if selectMode else opList.all()

        return sorted(gathered) if selectMode else gathered

    def select(self, rng=None):
        return self._get_tests(True, rng)

    def all(self):
        return self._get_tests(False, None)
