| |
| // Copyright (c) 2022-2023, ARM Limited. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // THIS FILE IS GENERATED. DO NOT EDIT! |
| // See scripts/operator_api/generate_api.py |
| |
| #include "operators.h" |
| #include "model_runner_impl.h" |
| #include "ops/op_factory.h" |
| |
// Return tosa_status_error from the enclosing function when `status` is
// non-zero. `status` is evaluated exactly once and is parenthesised so that
// compound arguments such as `flags & mask` are compared as a whole.
#define TOSA_RETURN_ON_ERROR(status)  \
    do                                \
    {                                 \
        if ((status) != 0)            \
        {                             \
            return tosa_status_error; \
        }                             \
    } while (false)
| |
// Return from the enclosing function when `status` is not
// GraphStatus::TOSA_VALID, converting it to tosa_status_t through its
// underlying integer value (assumes the two enums share numbering).
// `status` is evaluated exactly once, so a call such as runner.run() is not
// re-executed just to produce the return value.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status)                                                       \
    do                                                                                                  \
    {                                                                                                   \
        const auto tosa_graph_status_value_ = (status);                                                 \
        if (tosa_graph_status_value_ != GraphStatus::TOSA_VALID)                                        \
        {                                                                                               \
            auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status_value_); \
            return static_cast<tosa_status_t>(ustatus);                                                 \
        }                                                                                               \
    } while (false)
| |
| namespace { |
| |
| tosa::DType translate_client_datatype(tosa_datatype_t type) |
| { |
| switch (type) |
| { |
| case tosa_datatype_bf16_t: |
| return tosa::DType::DType_BF16; |
| case tosa_datatype_bool_t: |
| return tosa::DType::DType_BOOL; |
| case tosa_datatype_fp16_t: |
| return tosa::DType::DType_FP16; |
| case tosa_datatype_fp32_t: |
| return tosa::DType::DType_FP32; |
| case tosa_datatype_int16_t: |
| return tosa::DType::DType_INT16; |
| case tosa_datatype_int32_t: |
| return tosa::DType::DType_INT32; |
| case tosa_datatype_int48_t: |
| return tosa::DType::DType_INT48; |
| case tosa_datatype_int4_t: |
| return tosa::DType::DType_INT4; |
| case tosa_datatype_int8_t: |
| return tosa::DType::DType_INT8; |
| case tosa_datatype_uint16_t: |
| return tosa::DType::DType_UINT16; |
| case tosa_datatype_uint8_t: |
| return tosa::DType::DType_UINT8; |
| case tosa_datatype_shape_t: |
| return tosa::DType::DType_SHAPE; |
| default: |
| return tosa::DType::DType_UNKNOWN; |
| } |
| }; |
| |
| tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) |
| { |
| std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims); |
| return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); |
| } |
| |
| tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { |
| switch(mode) { |
| case tosa_mode_nearest: |
| return tosa::ResizeMode_NEAREST; |
| case tosa_mode_max: |
| case tosa_mode_bilinear: |
| return tosa::ResizeMode_BILINEAR; |
| default: |
| return tosa::ResizeMode_UNKNOWN; |
| } |
| } |
| |
| tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) { |
| switch(acc_size) { |
| case tosa_acc_size_int32_t: |
| return tosa::DType::DType_INT32; |
| case tosa_acc_size_fp16_t: |
| return tosa::DType::DType_FP16; |
| case tosa_acc_size_fp32_t: |
| return tosa::DType::DType_FP32; |
| default: |
| return tosa::DType::DType_UNKNOWN; |
| } |
| } |
| |
| } // namespace |
| |
// C-linkage entry points: one tosa_run_<operator>() function per supported
// operator. Each builds a single-operator TOSA basic block from the client
// tensors and arguments, runs it on the reference model and copies the results
// back into the client output tensors.
// NOTE(review): every generated function takes `const func_ctx_t&` (a C++
// reference), so despite extern "C" it is not callable from C code.
extern "C"
{
{% for operator in operators: %}
// Run a single TOSA operator on the reference model.
// tosa_tensor_t arguments describe a tensor (shape, num_dims, data_type) and
// its client buffer (data, size); other arguments are passed as const, and
// some of them supply operator attribute values. func_ctx provides the
// reference-model config and debug settings.
// Returns tosa_status_valid on success, the failing GraphStatus cast to
// tosa_status_t if initialization or execution fails, or tosa_status_error if
// an input cannot be set or an output cannot be read back.
// NOTE(review): an operator with no arguments would generate a parameter list
// that starts with a stray comma before func_ctx.
tosa_status_t tosa_run_{{ operator.name }} (
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
{%- endfor -%},const func_ctx_t& func_ctx
)
{
// Create operator attributes
// Attributes with an init statement use the local that statement declares;
// the rest are taken directly from the matching client_ argument.
{% for att in operator.serialLibAtts: -%}
{{att.init}}
{%- endfor -%}

Tosa{{operator.serializeAttType}}Attribute attr
{%- if operator.serialLibAtts|length > 0 -%}
(
{%- for att in operator.serialLibAtts: -%}
{%- if att.init == "" -%}
client_{{att.name}}
{%- else -%}
{{att.name}}
{%- endif -%}
{% if loop.index < operator.serialLibAtts|length %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};

// Create tensors
// Only shape and datatype are recorded here; the client data buffers are
// handed to the runner through setInput()/getOutput() below.
{% for input in operator.inputs: -%}
tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
{%- endfor -%}
{% for output in operator.outputs: %}
tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
{%- endfor %}

// Create operator
// NOTE(review): attr is a stack local passed by address; this presumes the
// operator constructor copies it — confirm in the serialization library.
auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
{%- if operator.serializeAttType != "None" -%}
tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
{%- else -%}
tosa::Attribute::Attribute_NONE
{%- endif -%},
&attr, {
{%- for input in operator.inputs: -%}
{{input}}->GetName()
{%- if loop.index < operator.inputs|length -%},{%- endif -%}
{%- endfor -%}
},
{
{%- for output in operator.outputs: -%}
{{output}}->GetName()
{%- if loop.index < operator.outputs|length -%},{%- endif -%}
{%- endfor -%}
});

// Create a tosa single-op basic block
// The block is given op and every tensor (inputs, then outputs).
// NOTE(review): nothing in this function deletes op or the tensors, so the
// block presumably takes ownership of them — confirm.
tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
{
{%- for input in operator.inputs: -%}
{{input}},
{%- endfor -%}
{%- for output in operator.outputs: -%}
{{output}}
{%- if loop.index < operator.outputs|length -%},{%- endif -%}
{%- endfor -%}
},
{
{%- for input in operator.inputs: -%}
{{input}}->GetName()
{%- if loop.index < operator.inputs|length -%},{%- endif -%}
{%- endfor -%}
},
{
{%- for output in operator.outputs: -%}
{{output}}->GetName()
{%- if loop.index < operator.outputs|length -%},{%- endif -%}
{%- endfor -%}
});

// Setup model
// A non-valid GraphStatus or a non-zero setInput()/getOutput() result returns
// early through the TOSA_RETURN_ON_* macros.
TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
{%- endfor %}

// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());

// Extract outputs
{% for output in operator.outputs: -%}
TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
{%- endfor %}

return tosa_status_valid;
}
{% endfor %}

} // extern "C"