/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "Framework.h"

#include "support/ToolchainSupport.h"

#ifdef ARM_COMPUTE_CL
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#endif /* ARM_COMPUTE_CL */

#include <chrono>
#include <iostream>
#include <sstream>
#include <type_traits>

namespace arm_compute
{
namespace test
{
namespace framework
{
// Registers every instrument this build can offer, keyed by its InstrumentType.
// The stored value is what get_profiler() later invokes to create an instance
// for the profiler of each test.
Framework::Framework()
{
    // The wall-clock timer is registered unconditionally.
    _available_instruments.emplace(InstrumentType::WALL_CLOCK_TIMER, Instrument::make_instrument<WallClockTimer>);
#ifdef PMU_ENABLED
    // PMU-based counters are only registered when the build defines PMU_ENABLED.
    _available_instruments.emplace(InstrumentType::PMU_CYCLE_COUNTER, Instrument::make_instrument<CycleCounter>);
    _available_instruments.emplace(InstrumentType::PMU_INSTRUCTION_COUNTER, Instrument::make_instrument<InstructionCounter>);
#endif /* PMU_ENABLED */
}

std::set<InstrumentType> Framework::available_instruments() const
{
    std::set<InstrumentType> types;

    for(const auto &instrument : _available_instruments)
    {
        types.emplace(instrument.first);
    }

    return types;
}

// Builds a histogram of the recorded test results by status. Statuses that did
// not occur are absent from the returned map.
std::map<TestResult::Status, int> Framework::count_test_results() const
{
    std::map<TestResult::Status, int> histogram;
    for(const auto &entry : _test_results)
    {
        const auto status = entry.second.status;
        histogram[status] += 1;
    }
    return histogram;
}

// Returns the process-wide framework instance, constructed on first use.
// Initialisation of a function-local static is thread-safe since C++11.
Framework &Framework::get()
{
    static Framework singleton{};
    return singleton;
}

void Framework::init(const std::vector<InstrumentType> &instruments, int num_iterations, DatasetMode mode, const std::string &name_filter, int64_t id_filter, LogLevel log_level)
{
    _test_name_filter = std::regex{ name_filter };
    _test_id_filter   = id_filter;
    _num_iterations   = num_iterations;
    _dataset_mode     = mode;
    _log_level        = log_level;

    _instruments = InstrumentType::NONE;

    for(const auto &instrument : instruments)
    {
        _instruments |= instrument;
    }
}

// Returns the currently open suites, outermost first (push_suite appends and
// pop_suite removes from the back), joined with '/'.
std::string Framework::current_suite_name() const
{
    const auto first = _test_suite_name.cbegin();
    const auto last  = _test_suite_name.cend();
    return join(first, last, "/");
}

// Opens a nested suite; the name is taken by value and moved into storage.
void Framework::push_suite(std::string name)
{
    _test_suite_name.push_back(std::move(name));
}

void Framework::pop_suite()
{
    _test_suite_name.pop_back();
}

// Appends one line of context that print_test_info() will emit.
void Framework::add_test_info(std::string info)
{
    _test_info.push_back(std::move(info));
}

void Framework::clear_test_info()
{
    _test_info.clear();
}

// True when at least one context line has been added.
bool Framework::has_test_info() const
{
    return _test_info.size() > 0;
}

// Writes a "CONTEXT:" header followed by each context line, indented by four spaces.
void Framework::print_test_info(std::ostream &os) const
{
    os << "CONTEXT:\n";
    for(const auto &line : _test_info)
    {
        os << "    " << line << "\n";
    }
}

// Prints the per-test header when a printer is attached and the log level is
// at least TESTS.
void Framework::log_test_start(const std::string &test_name)
{
    if(_printer == nullptr || !(_log_level >= LogLevel::TESTS))
    {
        return;
    }

    _printer->print_test_header(test_name);
}

// Hook called by run_test() for disabled tests. It currently produces no
// output, and the test name is unused.
void Framework::log_test_skipped(const std::string & /*test_name*/)
{
}

// Reports the end of a test through the attached printer, if any.
// Measurements are printed when the log level is at least MEASUREMENTS. The
// footer (matching the header from log_test_start()) is printed when the log
// level is at least TESTS. The test's result must already be recorded;
// _test_results.at() throws otherwise.
void Framework::log_test_end(const std::string &test_name)
{
    if(_printer == nullptr)
    {
        return;
    }

    const bool show_measurements = _log_level >= LogLevel::MEASUREMENTS;
    const bool show_footer       = _log_level >= LogLevel::TESTS;

    if(show_measurements)
    {
        const auto &recorded = _test_results.at(test_name);
        _printer->print_measurements(recorded.measurements);
    }

    if(show_footer)
    {
        _printer->print_test_footer();
    }
}

// Reports a failed expectation. The message is written to stderr if the
// configured log level admits `level`. If a test is currently running, that
// test is marked FAILED.
void Framework::log_failed_expectation(const std::string &msg, LogLevel level)
{
    const bool should_print = _log_level >= level;
    if(should_print)
    {
        std::cerr << "ERROR: " << msg << "\n";
    }

    // Outside run_test() there is no active result to update.
    if(_current_test_result)
    {
        _current_test_result->status = TestResult::Status::FAILED;
    }
}

int Framework::num_iterations() const
{
    return _num_iterations;
}

void Framework::set_num_iterations(int num_iterations)
{
    _num_iterations = num_iterations;
}

void Framework::set_throw_errors(bool throw_errors)
{
    _throw_errors = throw_errors;
}

bool Framework::throw_errors() const
{
    return _throw_errors;
}

// Decides whether a test passes all configured filters:
//  - its dataset mode overlaps the selected mode,
//  - its id matches the id filter (a negative filter accepts every id),
//  - its name contains a match for the name regex.
bool Framework::is_selected(const TestInfo &info) const
{
    const bool mode_enabled = !((info.mode & _dataset_mode) == DatasetMode::DISABLED);
    const bool id_matches   = _test_id_filter <= -1 || _test_id_filter == info.id;

    // regex_search is evaluated last and only if the cheaper checks pass.
    return mode_enabled && id_matches && std::regex_search(info.name, _test_name_filter);
}

// Runs a single test case and records its result under the test's name.
//
// Disabled tests are recorded as DISABLED without being executed. Otherwise the
// fixture is created, set up, run _num_iterations times under the profiler
// (synchronising the CL queue after each run when OpenCL is available), and
// torn down. Errors are classified as:
//  - FAILED: TestError and (with ARM_COMPUTE_CL) ::cl::Error;
//  - CRASHED: any other exception, including one thrown while creating the fixture.
// A failure of a test marked EXPECTED_FAILURE is recorded as EXPECTED_FAILURE.
// When throw_errors() is enabled, the error is rethrown after being logged, and
// no result is recorded for the test.
void Framework::run_test(TestCaseFactory &test_factory)
{
    const std::string test_case_name = test_factory.name();

    if(test_factory.status() == TestCaseFactory::Status::DISABLED)
    {
        log_test_skipped(test_case_name);
        set_test_result(test_case_name, TestResult(TestResult::Status::DISABLED));
        return;
    }

    log_test_start(test_case_name);

    Profiler   profiler = get_profiler();
    TestResult result(TestResult::Status::SUCCESS);

    // Published so that log_failed_expectation() can mark this test as FAILED.
    _current_test_result = &result;

    // Clears _current_test_result on every exit path. When an error is rethrown
    // (throw_errors() enabled), the local 'result' is destroyed during unwinding.
    // Leaving the pointer set would let a later log_failed_expectation() write
    // through a dangling pointer.
    struct ResetCurrentTestResult
    {
        TestResult *&current;
        ~ResetCurrentTestResult()
        {
            current = nullptr;
        }
    } reset_current_test_result{ _current_test_result };

    try
    {
        std::unique_ptr<TestCase> test_case = test_factory.make();

        try
        {
            test_case->do_setup();

            for(int i = 0; i < _num_iterations; ++i)
            {
                profiler.start();
                test_case->do_run();
#ifdef ARM_COMPUTE_CL
                // Wait for queued CL work so the profiler measures the actual execution.
                if(opencl_is_available())
                {
                    CLScheduler::get().sync();
                }
#endif /* ARM_COMPUTE_CL */
                profiler.stop();
            }

            test_case->do_teardown();
        }
        catch(const TestError &error)
        {
            if(_log_level >= error.level())
            {
                std::cerr << "FATAL ERROR: " << error.what() << "\n";
            }

            result.status = TestResult::Status::FAILED;

            if(_throw_errors)
            {
                throw;
            }
        }
#ifdef ARM_COMPUTE_CL
        catch(const ::cl::Error &error)
        {
            if(_log_level >= LogLevel::ERRORS)
            {
                std::cerr << "FATAL CL ERROR: " << error.what() << " with code " << error.err() << "\n";
            }

            result.status = TestResult::Status::FAILED;

            if(_throw_errors)
            {
                throw;
            }
        }
#endif /* ARM_COMPUTE_CL */
        catch(const std::exception &error)
        {
            if(_log_level >= LogLevel::ERRORS)
            {
                std::cerr << "FATAL ERROR: Received unhandled error: '" << error.what() << "'\n";
            }

            result.status = TestResult::Status::CRASHED;

            if(_throw_errors)
            {
                throw;
            }
        }
        catch(...)
        {
            if(_log_level >= LogLevel::ERRORS)
            {
                std::cerr << "FATAL ERROR: Received unhandled exception\n";
            }

            result.status = TestResult::Status::CRASHED;

            if(_throw_errors)
            {
                throw;
            }
        }
    }
    // The handlers below catch failures of test_factory.make(). They are also
    // reached when one of the handlers above rethrows because throw_errors() is
    // enabled.
    catch(const std::exception &error)
    {
        if(_log_level >= LogLevel::ERRORS)
        {
            std::cerr << "FATAL ERROR: Received unhandled error during fixture creation: '" << error.what() << "'\n";
        }

        // Previously missing: without this, a fixture that failed to construct was
        // recorded as SUCCESS. This is consistent with the catch(...) handler below.
        result.status = TestResult::Status::CRASHED;

        if(_throw_errors)
        {
            throw;
        }
    }
    catch(...)
    {
        if(_log_level >= LogLevel::ERRORS)
        {
            std::cerr << "FATAL ERROR: Received unhandled exception during fixture creation\n";
        }

        result.status = TestResult::Status::CRASHED;

        if(_throw_errors)
        {
            throw;
        }
    }

    // Stop routing expectation failures to this test before its result is stored.
    _current_test_result = nullptr;

    if(test_factory.status() == TestCaseFactory::Status::EXPECTED_FAILURE && result.status == TestResult::Status::FAILED)
    {
        result.status = TestResult::Status::EXPECTED_FAILURE;
    }

    result.measurements = profiler.measurements();

    set_test_result(test_case_name, result);
    log_test_end(test_case_name);
}

// Runs every registered test that passes the configured filters and prints a
// summary line (unless the log level is NONE).
//
// Test ids are positional: every registered factory consumes an id, whether or
// not it is selected, so ids agree with test_infos(). Returns true when every
// recorded result is SUCCESS, EXPECTED_FAILURE or DISABLED.
bool Framework::run()
{
    // Clear old test results
    _test_results.clear();
    _runtime = std::chrono::seconds{ 0 };

    if(_printer != nullptr && _log_level >= LogLevel::TESTS)
    {
        _printer->print_run_header();
    }

    const auto start = std::chrono::high_resolution_clock::now();

    int id = 0;

    for(auto &test_factory : _test_factories)
    {
        const std::string test_case_name = test_factory->name();
        const TestInfo    test_info{ id, test_case_name, test_factory->mode(), test_factory->status() };

        if(is_selected(test_info))
        {
            run_test(*test_factory);
        }

        ++id;
    }

    const auto end = std::chrono::high_resolution_clock::now();

    if(_printer != nullptr && _log_level >= LogLevel::TESTS)
    {
        _printer->print_run_footer();
    }

    _runtime = std::chrono::duration_cast<std::chrono::seconds>(end - start);

    auto test_results = count_test_results();

    if(_log_level > LogLevel::NONE)
    {
        std::cout << "Executed " << _test_results.size() << " test(s) ("
                  << test_results[TestResult::Status::SUCCESS] << " passed, "
                  << test_results[TestResult::Status::EXPECTED_FAILURE] << " expected failures, "
                  << test_results[TestResult::Status::FAILED] << " failed, "
                  << test_results[TestResult::Status::CRASHED] << " crashed, "
                  << test_results[TestResult::Status::DISABLED] << " disabled) in " << _runtime.count() << " second(s)\n";
    }

    // run_test() records disabled tests (status DISABLED) without executing them.
    // They must not make the whole run report failure, so they are counted
    // alongside passes and expected failures. Any other status fails the run.
    const int num_successful_tests = test_results[TestResult::Status::SUCCESS]
                                     + test_results[TestResult::Status::EXPECTED_FAILURE]
                                     + test_results[TestResult::Status::DISABLED];

    return (static_cast<unsigned int>(num_successful_tests) == _test_results.size());
}

// Stores the result of a test under its name. As with emplace, an existing
// entry with the same name is left unchanged.
void Framework::set_test_result(std::string test_case_name, TestResult result)
{
    _test_results.insert(std::make_pair(std::move(test_case_name), std::move(result)));
}

// Replays every recorded result through `printer`. The log level is ignored,
// and headers, measurements and footers are always emitted.
void Framework::print_test_results(Printer &printer) const
{
    printer.print_run_header();

    for(const auto &entry : _test_results)
    {
        const auto &name    = entry.first;
        const auto &outcome = entry.second;

        printer.print_test_header(name);
        printer.print_measurements(outcome.measurements);
        printer.print_test_footer();
    }

    printer.print_run_footer();
}

// Builds a fresh profiler containing a new instance of every registered
// instrument whose type is enabled in the _instruments mask.
Profiler Framework::get_profiler() const
{
    Profiler profiler;

    for(const auto &entry : _available_instruments)
    {
        const bool requested = (entry.first & _instruments) != InstrumentType::NONE;
        if(requested)
        {
            profiler.add(entry.second());
        }
    }

    return profiler;
}

// Attaches the printer used for live reporting. Passing nullptr disables it.
// The framework stores the raw pointer and does not take ownership.
void Framework::set_printer(Printer *printer)
{
    this->_printer = printer;
}

std::vector<Framework::TestInfo> Framework::test_infos() const
{
    std::vector<TestInfo> ids;

    int id = 0;

    for(const auto &factory : _test_factories)
    {
        TestInfo test_info{ id, factory->name(), factory->mode(), factory->status() };

        if(is_selected(test_info))
        {
            ids.emplace_back(std::move(test_info));
        }

        ++id;
    }

    return ids;
}
} // namespace framework
} // namespace test
} // namespace arm_compute
