
// Copyright (c) 2022-2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

// THIS FILE IS GENERATED. DO NOT EDIT!
// See scripts/operator_api/generate_api.py

#include "operators.h"
#include "model_runner_impl.h"
#include "ops/op_factory.h"

// Returns tosa_status_error from the enclosing function when `status` is non-zero.
// The argument is parenthesized so that compound expressions (e.g. `rc & mask`)
// are compared against 0 as a whole rather than being split by operator precedence.
#define TOSA_RETURN_ON_ERROR(status)                                                                                   \
    do                                                                                                                 \
    {                                                                                                                  \
        if ((status) != 0)                                                                                             \
        {                                                                                                              \
            return tosa_status_error;                                                                                  \
        }                                                                                                              \
    } while (false)

// Returns from the enclosing function when `status` is not GraphStatus::TOSA_VALID,
// converting the GraphStatus into a tosa_status_t via its underlying integer value.
// `status` is evaluated exactly once: callers pass calls such as runner.run(), and
// re-evaluating the argument on the error path would execute that call a second time.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status)                                                                      \
    do                                                                                                                 \
    {                                                                                                                  \
        const auto tosa_graph_status_ = (status);                                                                      \
        if (tosa_graph_status_ != GraphStatus::TOSA_VALID)                                                             \
        {                                                                                                              \
            auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status_);                       \
            return static_cast<tosa_status_t>(ustatus);                                                                \
        }                                                                                                              \
    } while (false)

namespace {

// Translates a client API datatype into the serialization library's tosa::DType.
// Any value without an explicit case maps to tosa::DType::DType_UNKNOWN.
tosa::DType translate_client_datatype(tosa_datatype_t type)
{
    switch (type)
    {
        case tosa_datatype_bf16_t:
            return tosa::DType::DType_BF16;
        case tosa_datatype_bool_t:
            return tosa::DType::DType_BOOL;
        case tosa_datatype_fp16_t:
            return tosa::DType::DType_FP16;
        case tosa_datatype_fp32_t:
            return tosa::DType::DType_FP32;
        case tosa_datatype_int16_t:
            return tosa::DType::DType_INT16;
        case tosa_datatype_int32_t:
            return tosa::DType::DType_INT32;
        case tosa_datatype_int48_t:
            return tosa::DType::DType_INT48;
        case tosa_datatype_int4_t:
            return tosa::DType::DType_INT4;
        case tosa_datatype_int8_t:
            return tosa::DType::DType_INT8;
        case tosa_datatype_uint16_t:
            return tosa::DType::DType_UINT16;
        case tosa_datatype_uint8_t:
            return tosa::DType::DType_UINT8;
        case tosa_datatype_shape_t:
            return tosa::DType::DType_SHAPE;
        default:
            return tosa::DType::DType_UNKNOWN;
    }
}

// Creates a heap-allocated serialization tensor called `name` whose shape (the first
// `num_dims` entries of `tensor.shape`) and datatype mirror the client descriptor.
// The tensor is constructed with empty data ({}); the client buffer is not copied here.
// The descriptor is only read, hence the const reference.
// NOTE(review): returns a raw owning pointer; presumably ownership is passed on to a
// TosaSerializationBasicBlock by the caller — confirm.
tosa::TosaSerializationTensor* translate_client_tensor(const tosa_tensor_t& tensor, const std::string& name)
{
    std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
    return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
}

// Translates a client tosa_mode_t into a tosa::ResizeMode:
//   tosa_mode_nearest                   -> ResizeMode_NEAREST
//   tosa_mode_bilinear, tosa_mode_max   -> ResizeMode_BILINEAR
//   anything else                       -> ResizeMode_UNKNOWN
// NOTE(review): tosa_mode_max falling through to BILINEAR is preserved as-is, but its
// intent is not visible here — confirm against the tosa_mode_t definition.
tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode)
{
    switch (mode)
    {
        case tosa_mode_nearest:
            return tosa::ResizeMode_NEAREST;
        case tosa_mode_max:
        case tosa_mode_bilinear:
            return tosa::ResizeMode_BILINEAR;
        default:
            return tosa::ResizeMode_UNKNOWN;
    }
}

// Translates a client accumulator-size selector into a tosa::DType
// (INT32, FP16 or FP32); any other value maps to tosa::DType::DType_UNKNOWN.
tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size)
{
    switch (acc_size)
    {
        case tosa_acc_size_int32_t:
            return tosa::DType::DType_INT32;
        case tosa_acc_size_fp16_t:
            return tosa::DType::DType_FP16;
        case tosa_acc_size_fp32_t:
            return tosa::DType::DType_FP32;
        default:
            return tosa::DType::DType_UNKNOWN;
    }
}

}    // namespace

extern "C"
{
    {% for operator in operators: %}
    tosa_status_t tosa_run_{{ operator.name }} (
        {%- for arg in operator.arguments: -%}
            {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
            {% if loop.index < operator.arguments|length %},{% endif %}
        {%- endfor -%},const func_ctx_t& func_ctx
    )
    {
        // Create operator attributes
        {% for att in operator.serialLibAtts: -%}
            {{att.init}}
        {%- endfor -%}

        Tosa{{operator.serializeAttType}}Attribute attr
        {%- if operator.serialLibAtts|length > 0 -%}
        (
            {%- for att in operator.serialLibAtts: -%}
                {%- if att.init == "" -%}
                    client_{{att.name}}
                {%- else -%}
                    {{att.name}}
                {%- endif -%}
                {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
            {%- endfor -%}
        )
        {%- endif -%};

        // Create tensors
        {% for input in operator.inputs: -%}
            tosa::TosaSerializationTensor* {{input}}  = translate_client_tensor(client_{{input}}, "{{input}}");
        {%- endfor -%}
        {% for output in operator.outputs: %}
            tosa::TosaSerializationTensor* {{output}}  = translate_client_tensor(client_{{output}}, "{{output}}");
        {%- endfor %}

        // Create operator
        auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
                                                      {%- if operator.serializeAttType != "None" -%}
                                                        tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
                                                      {%- else -%}
                                                        tosa::Attribute::Attribute_NONE
                                                      {%- endif -%},
                                                      &attr, {
                                                                {%- for input in operator.inputs: -%}
                                                                    {{input}}->GetName()
                                                                    {%- if loop.index < operator.inputs|length -%},{%- endif -%}
                                                                {%- endfor -%}
                                                             },
                                                             {
                                                                {%- for output in operator.outputs: -%}
                                                                    {{output}}->GetName()
                                                                    {%- if loop.index < operator.outputs|length -%},{%- endif -%}
                                                                {%- endfor -%}
                                                             });

        // Create a tosa single-op basic block
        tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
                                                 {
                                                    {%- for input in operator.inputs: -%}
                                                        {{input}},
                                                    {%- endfor -%}
                                                    {%- for output in operator.outputs: -%}
                                                        {{output}}
                                                        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
                                                    {%- endfor -%}
                                                 },
                                                 {
                                                    {%- for input in operator.inputs: -%}
                                                        {{input}}->GetName()
                                                        {%- if loop.index < operator.inputs|length -%},{%- endif -%}
                                                    {%- endfor -%}
                                                 },
                                                 {
                                                    {%- for output in operator.outputs: -%}
                                                        {{output}}->GetName()
                                                        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
                                                    {%- endfor -%}
                                                 });

        // Setup model
        TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
        {% for input in operator.inputs: -%}
            TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
        {%- endfor %}

        // Execute
        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());

        // Extract outputs
        {% for output in operator.outputs: -%}
            TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
        {%- endfor %}

        return tosa_status_valid;
    }
    {% endfor %}

}    // extern "C"