blob: 37a0af629871126e562812c4ff7ef9dec4c7b2ed [file] [log] [blame]
// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// THIS FILE IS GENERATED. DO NOT EDIT!
// See scripts/operator_api/generate_api.py
18
19#include "operators.h"
20#include "model_runner_impl.h"
21#include "ops/op_factory.h"
22
// Returns tosa_status_error from the enclosing function when `status` is
// non-zero (zero is treated as success). The argument is parenthesized so that
// expressions containing low-precedence operators (e.g. `a & mask`) are
// compared as a whole, and it is evaluated exactly once.
#define TOSA_RETURN_ON_ERROR(status)                                                                                   \
    do                                                                                                                 \
    {                                                                                                                  \
        if ((status) != 0)                                                                                             \
        {                                                                                                              \
            return tosa_status_error;                                                                                  \
        }                                                                                                              \
    } while (false)
31
// Returns from the enclosing function when `status` (a GraphStatus) is not
// GraphStatus::TOSA_VALID. The GraphStatus value is converted to tosa_status_t
// through its underlying integer value (assumes the two enums share numeric
// values — TODO confirm). The argument is evaluated exactly once, so call
// sites such as runner.run() are not re-executed on the error path.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status)                                                                      \
    do                                                                                                                 \
    {                                                                                                                  \
        const auto tosa_macro_graph_status = (status);                                                                 \
        if (tosa_macro_graph_status != GraphStatus::TOSA_VALID)                                                        \
        {                                                                                                              \
            const auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_macro_graph_status);            \
            return static_cast<tosa_status_t>(ustatus);                                                                \
        }                                                                                                              \
    } while (false)
41
42namespace {
43
44tosa::DType translate_client_datatype(tosa_datatype_t type)
45{
46 switch (type)
47 {
48 case tosa_datatype_fp16_t:
49 return tosa::DType::DType_FP16;
50 case tosa_datatype_fp32_t:
51 return tosa::DType::DType_FP32;
Grant Watsoneb741062023-06-23 16:52:12 +010052 case tosa_datatype_bool_t:
53 return tosa::DType::DType_BOOL;
Grant Watson64285a12022-11-16 15:32:39 +000054 default:
55 return tosa::DType::DType_UNKNOWN;
56 }
57};
58
59tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
60{
61 std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
62 return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
63}
64
65tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
66 switch(mode) {
67 case tosa_mode_nearest:
68 return tosa::ResizeMode_NEAREST;
Jerry Ge9e94af82022-10-27 09:57:00 -070069 case tosa_mode_max:
Grant Watson64285a12022-11-16 15:32:39 +000070 case tosa_mode_bilinear:
71 return tosa::ResizeMode_BILINEAR;
72 default:
Jerry Ge9e94af82022-10-27 09:57:00 -070073 return tosa::ResizeMode_UNKNOWN;
Grant Watson64285a12022-11-16 15:32:39 +000074 }
75}
76
77} // namespace
78
extern "C"
{
{# Emits "<t0>->GetName(), <t1>->GetName(), ..." for a list of tensor names. #}
{% macro tensor_names(tensors) -%}
    {%- for t in tensors %}{{ t }}->GetName(){% if not loop.last %}, {% endif %}{% endfor -%}
{%- endmacro %}
{% for operator in operators %}
    {#
      One exported C entry point per operator: wrap the client tensors and the
      attribute values in a single-operator serialization basic block, run it
      with TosaReference::ModelRunnerImpl, and fetch each output through
      runner.getOutput.
    #}
    tosa_status_t tosa_run_{{ operator.name }}(
        {%- for arg in operator.arguments %}
        {% if arg.type != "tosa_tensor_t" %}const {% endif %}{{ arg.type }} client_{{ arg.name }}{{ arg.shape }}{% if not loop.last %},{% endif %}
        {%- endfor %})
    {
        // Operator attributes
        {% for arg in operator.serializeArgs %}
        {% if arg.SV == "V" %}const std::vector<{{ arg.dType }}> {{ arg.name }}{{ arg.init }};{% else %}const {{ arg.dType }} {{ arg.name }}{{ arg.init }};{% endif %}
        {% endfor %}
        Tosa{{ operator.serializeAttType }}Attribute attr
        {%- if operator.serializeArgs | length > 0 %}({{ operator.serializeArgs | map(attribute="name") | join(", ") }}){% endif %};

        // Tensors describing the client buffers
        {% for tensor in operator.inputs %}
        tosa::TosaSerializationTensor* {{ tensor }} = translate_client_tensor(client_{{ tensor }}, "{{ tensor }}");
        {% endfor %}
        {% for tensor in operator.outputs %}
        tosa::TosaSerializationTensor* {{ tensor }} = translate_client_tensor(client_{{ tensor }}, "{{ tensor }}");
        {% endfor %}

        // The single operator, wired from the inputs to the outputs
        auto op = new tosa::TosaSerializationOperator(
            tosa::Op::Op_{{ operator.name | upper }},
            {% if operator.serializeAttType == "None" %}tosa::Attribute::Attribute_NONE{% else %}tosa::Attribute::Attribute_{{ operator.serializeAttType }}Attribute{% endif %},
            &attr,
            {
                {{- tensor_names(operator.inputs) -}}
            },
            {
                {{- tensor_names(operator.outputs) -}}
            });

        // Single-op basic block holding the operator and all of its tensors
        tosa::TosaSerializationBasicBlock block("{{ operator.name }}", "main", { op },
            {
                {%- for tensor in operator.inputs %}{{ tensor }}, {% endfor %}
                {{- operator.outputs | join(", ") -}}
            },
            {
                {{- tensor_names(operator.inputs) -}}
            },
            {
                {{- tensor_names(operator.outputs) -}}
            });

        // Initialize the runner, feed the inputs and execute
        TosaReference::ModelRunnerImpl runner;
        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
        {% for tensor in operator.inputs %}
        TOSA_RETURN_ON_ERROR(runner.setInput({{ tensor }}->GetName(), client_{{ tensor }}.data, client_{{ tensor }}.size));
        {% endfor %}
        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());

        // Fetch the results
        {% for tensor in operator.outputs %}
        TOSA_RETURN_ON_ERROR(runner.getOutput({{ tensor }}->GetName(), client_{{ tensor }}.data, client_{{ tensor }}.size));
        {% endfor %}

        return tosa_status_valid;
    }
{% endfor %}
} // extern "C"