blob: 1de103a4ad21699413c47aa9d4c836a66a552b2e [file] [log] [blame]
// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// THIS FILE IS GENERATED. DO NOT EDIT!
// See scripts/operator_api/generate_api.py
#include "operators.h"
#include "model_runner_impl.h"
#include "ops/op_factory.h"
// Returns tosa_status_error from the enclosing function (which must return tosa_status_t) when
// `status` is non-zero; 0 is treated as success. The argument is evaluated exactly once and is
// parenthesized so that compound expressions are compared against 0 as a whole.
#define TOSA_RETURN_ON_ERROR(status) \
    do \
    { \
        if ((status) != 0) \
        { \
            return tosa_status_error; \
        } \
    } while (false)
// Returns from the enclosing function when `status` (a GraphStatus) is not GraphStatus::TOSA_VALID,
// forwarding its underlying integer value as a tosa_status_t.
// NOTE(review): this assumes GraphStatus and tosa_status_t share numeric values — confirm.
// The argument is evaluated exactly once, so passing a call such as runner.run() or
// runner.initialize(block) does not re-invoke it on the error path.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \
    do \
    { \
        const auto tosa_graph_status_ = (status); \
        if (tosa_graph_status_ != GraphStatus::TOSA_VALID) \
        { \
            const auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status_); \
            return static_cast<tosa_status_t>(ustatus); \
        } \
    } while (false)
namespace {

// Maps a client-facing tosa_datatype_t onto the serialization library's tosa::DType.
// Values without an explicit mapping yield tosa::DType::DType_UNKNOWN.
tosa::DType translate_client_datatype(tosa_datatype_t type)
{
    switch (type)
    {
        case tosa_datatype_bf16_t:
            return tosa::DType::DType_BF16;
        case tosa_datatype_bool_t:
            return tosa::DType::DType_BOOL;
        case tosa_datatype_fp16_t:
            return tosa::DType::DType_FP16;
        case tosa_datatype_fp32_t:
            return tosa::DType::DType_FP32;
        case tosa_datatype_int16_t:
            return tosa::DType::DType_INT16;
        case tosa_datatype_int32_t:
            return tosa::DType::DType_INT32;
        case tosa_datatype_int48_t:
            return tosa::DType::DType_INT48;
        case tosa_datatype_int4_t:
            return tosa::DType::DType_INT4;
        case tosa_datatype_int8_t:
            return tosa::DType::DType_INT8;
        case tosa_datatype_uint16_t:
            return tosa::DType::DType_UINT16;
        case tosa_datatype_uint8_t:
            return tosa::DType::DType_UINT8;
        case tosa_datatype_shape_t:
            return tosa::DType::DType_SHAPE;
        default:
            return tosa::DType::DType_UNKNOWN;
    }
}

// Creates a heap-allocated serialization tensor descriptor called `name`. The descriptor carries
// the client tensor's shape (the first num_dims entries of tensor.shape) and its translated data
// type. The client tensor is only read, never modified.
// The trailing `{}` constructor argument presumably leaves the descriptor without initial data —
// confirm against TosaSerializationTensor.
// The caller owns the returned pointer (presumably by handing it to a
// TosaSerializationBasicBlock — confirm).
tosa::TosaSerializationTensor* translate_client_tensor(const tosa_tensor_t& tensor, const std::string& name)
{
    std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
    return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
}

// Maps a client tosa_mode_t onto tosa::ResizeMode; unmapped values yield ResizeMode_UNKNOWN.
// NOTE(review): tosa_mode_max shares the BILINEAR mapping with tosa_mode_bilinear. It is not clear
// from this file whether that is intentional (e.g. an alias or an enum sentinel) — confirm against
// the tosa_mode_t definition.
tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode)
{
    switch (mode)
    {
        case tosa_mode_nearest:
            return tosa::ResizeMode_NEAREST;
        case tosa_mode_max:
        case tosa_mode_bilinear:
            return tosa::ResizeMode_BILINEAR;
        default:
            return tosa::ResizeMode_UNKNOWN;
    }
}

// Maps a client accumulator-size selector onto the tosa::DType used for accumulation.
// Values without an explicit mapping yield tosa::DType::DType_UNKNOWN.
tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size)
{
    switch (acc_size)
    {
        case tosa_acc_size_int32_t:
            return tosa::DType::DType_INT32;
        case tosa_acc_size_fp16_t:
            return tosa::DType::DType_FP16;
        case tosa_acc_size_fp32_t:
            return tosa::DType::DType_FP32;
        default:
            return tosa::DType::DType_UNKNOWN;
    }
}

} // namespace
extern "C"
{
// One C-linkage entry point, tosa_run_<operator>, is generated below for every supported operator.
// Each one wraps a single TOSA operator in a one-operator graph and executes it on the reference model.
{% for operator in operators: %}
// Runs a single TOSA operator through the reference model:
//   1. builds the operator attribute and tensor descriptors from the client_* arguments,
//   2. wraps them in a one-operator basic block and initializes a ModelRunnerImpl with it,
//   3. binds each input buffer, runs the graph, and reads each output into its client buffer.
// Returns tosa_status_valid on success; tosa_status_error when setInput/getOutput reports a
// non-zero status; otherwise the non-TOSA_VALID GraphStatus from initialization or execution,
// cast to tosa_status_t.
tosa_status_t tosa_run_{{ operator.name }} (
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
{%- endfor -%},const func_ctx_t& func_ctx
)
{
// Create operator attributes, either directly from the client_* attribute arguments or from
// locally initialized (converted) values
{% for att in operator.serialLibAtts: -%}
{{att.init}}
{%- endfor -%}
Tosa{{operator.serializeAttType}}Attribute attr
{%- if operator.serialLibAtts|length > 0 -%}
(
{%- for att in operator.serialLibAtts: -%}
{%- if att.init == "" -%}
client_{{att.name}}
{%- else -%}
{{att.name}}
{%- endif -%}
{% if loop.index < operator.serialLibAtts|length %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};
// Create tensors: one serialization descriptor (shape and data type) per input and output.
// Client data buffers are bound to the runner separately, after initialization.
{% for input in operator.inputs: -%}
tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
{%- endfor -%}
{% for output in operator.outputs: %}
tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
{%- endfor %}
// Create operator
// NOTE(review): attr is a stack local passed by address; this relies on the operator copying the
// attribute rather than retaining the pointer — confirm against TosaSerializationOperator.
auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
{%- if operator.serializeAttType != "None" -%}
tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
{%- else -%}
tosa::Attribute::Attribute_NONE
{%- endif -%},
&attr, {
{%- for input in operator.inputs: -%}
{{input}}->GetName()
{%- if loop.index < operator.inputs|length -%},{%- endif -%}
{%- endfor -%}
},
{
{%- for output in operator.outputs: -%}
{{output}}->GetName()
{%- if loop.index < operator.outputs|length -%},{%- endif -%}
{%- endfor -%}
});
// Create a tosa single-op basic block holding op and every tensor; graph inputs and outputs are
// listed by tensor name.
// NOTE(review): op and the tensors are allocated with new and never deleted in this function,
// including on the early-return paths below. Presumably the basic block takes ownership and frees
// them — confirm.
tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
{
{%- for input in operator.inputs: -%}
{{input}},
{%- endfor -%}
{%- for output in operator.outputs: -%}
{{output}}
{%- if loop.index < operator.outputs|length -%},{%- endif -%}
{%- endfor -%}
},
{
{%- for input in operator.inputs: -%}
{{input}}->GetName()
{%- if loop.index < operator.inputs|length -%},{%- endif -%}
{%- endfor -%}
},
{
{%- for output in operator.outputs: -%}
{{output}}->GetName()
{%- if loop.index < operator.inputs|length -%},{%- endif -%}
{%- endfor -%}
});
// Setup model: build a runner from the caller's func_config/func_debug, initialize it with the
// block, then bind each client input buffer (data, size) to its tensor by name
TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
{%- endfor %}
// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
// Extract outputs: read each output tensor, by name, into the caller-provided client buffer
// (data, size)
{% for output in operator.outputs: -%}
TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
{%- endfor %}
return tosa_status_valid;
}
{% endfor %}
} // extern "C"