blob: 1de103a4ad21699413c47aa9d4c836a66a552b2e [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001
// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// THIS FILE IS GENERATED. DO NOT EDIT!
// See scripts/operator_api/generate_api.py
18
19#include "operators.h"
20#include "model_runner_impl.h"
21#include "ops/op_factory.h"
22
// Return tosa_status_error from the enclosing function when `status` is non-zero.
// The argument is parenthesised so that compound expressions (e.g. `a & b`) are
// compared against 0 as a whole, and it is evaluated exactly once.
#define TOSA_RETURN_ON_ERROR(status)                                                                                   \
    do                                                                                                                 \
    {                                                                                                                  \
        if ((status) != 0)                                                                                             \
        {                                                                                                              \
            return tosa_status_error;                                                                                  \
        }                                                                                                              \
    } while (false)
31
// Return from the enclosing function when a GraphStatus-valued expression is not
// GraphStatus::TOSA_VALID, converting it to tosa_status_t via its underlying
// integer value. `status` is evaluated exactly once, so calls passed in (such as
// runner.initialize(...) or runner.run()) are never re-executed on the error path.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status)                                                                      \
    do                                                                                                                 \
    {                                                                                                                  \
        const auto tosa_graph_status_ = (status);                                                                      \
        if (tosa_graph_status_ != GraphStatus::TOSA_VALID)                                                             \
        {                                                                                                              \
            auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status_);                       \
            return static_cast<tosa_status_t>(ustatus);                                                                \
        }                                                                                                              \
    } while (false)
41
42namespace {
43
44tosa::DType translate_client_datatype(tosa_datatype_t type)
45{
46 switch (type)
47 {
Grant Watsone70d9312023-08-28 16:34:28 +010048 case tosa_datatype_bf16_t:
49 return tosa::DType::DType_BF16;
50 case tosa_datatype_bool_t:
51 return tosa::DType::DType_BOOL;
Grant Watson64285a12022-11-16 15:32:39 +000052 case tosa_datatype_fp16_t:
53 return tosa::DType::DType_FP16;
54 case tosa_datatype_fp32_t:
55 return tosa::DType::DType_FP32;
Grant Watsone70d9312023-08-28 16:34:28 +010056 case tosa_datatype_int16_t:
57 return tosa::DType::DType_INT16;
58 case tosa_datatype_int32_t:
59 return tosa::DType::DType_INT32;
60 case tosa_datatype_int48_t:
61 return tosa::DType::DType_INT48;
62 case tosa_datatype_int4_t:
63 return tosa::DType::DType_INT4;
64 case tosa_datatype_int8_t:
65 return tosa::DType::DType_INT8;
66 case tosa_datatype_uint16_t:
67 return tosa::DType::DType_UINT16;
68 case tosa_datatype_uint8_t:
69 return tosa::DType::DType_UINT8;
Grant Watsoneff70382023-09-12 10:46:36 +010070 case tosa_datatype_shape_t:
71 return tosa::DType::DType_SHAPE;
Grant Watson64285a12022-11-16 15:32:39 +000072 default:
73 return tosa::DType::DType_UNKNOWN;
74 }
75};
76
77tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
78{
79 std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
80 return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
81}
82
83tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
84 switch(mode) {
85 case tosa_mode_nearest:
86 return tosa::ResizeMode_NEAREST;
Jerry Ge9e94af82022-10-27 09:57:00 -070087 case tosa_mode_max:
Grant Watson64285a12022-11-16 15:32:39 +000088 case tosa_mode_bilinear:
89 return tosa::ResizeMode_BILINEAR;
90 default:
Jerry Ge9e94af82022-10-27 09:57:00 -070091 return tosa::ResizeMode_UNKNOWN;
Grant Watson64285a12022-11-16 15:32:39 +000092 }
93}
94
Dmitrii Agibovc8fdccf2023-09-21 11:05:58 +010095tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) {
96 switch(acc_size) {
97 case tosa_acc_size_int32_t:
98 return tosa::DType::DType_INT32;
99 case tosa_acc_size_fp16_t:
100 return tosa::DType::DType_FP16;
101 case tosa_acc_size_fp32_t:
102 return tosa::DType::DType_FP32;
103 default:
104 return tosa::DType::DType_UNKNOWN;
105 }
106}
107
Grant Watson64285a12022-11-16 15:32:39 +0000108} // namespace
109
// C-linkage entry points. The template loop below emits one tosa_run_<name>()
// function per entry in `operators`; each builds a single-operator graph from its
// client arguments, runs it with TosaReference::ModelRunnerImpl and reads every
// output back through runner.getOutput() into the matching client tensor.
110extern "C"
111{
112 {% for operator in operators: %}
// Non-tensor arguments are declared const; tosa_tensor_t arguments are not.
// Returns tosa_status_valid on success; any failing runner step returns
// immediately with a non-valid tosa_status_t.
113 tosa_status_t tosa_run_{{ operator.name }} (
114 {%- for arg in operator.arguments: -%}
115 {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
116 {% if loop.index < operator.arguments|length %},{% endif %}
Grant Watsoneff70382023-09-12 10:46:36 +0100117 {%- endfor -%},const func_ctx_t& func_ctx
Grant Watson64285a12022-11-16 15:32:39 +0000118 )
119 {
120 // Create operator attributes
// Attributes with a non-empty `init` snippet are prepared as locals first and
// passed by name; the rest are forwarded directly from their client_ argument.
Grant Watsoneff70382023-09-12 10:46:36 +0100121 {% for att in operator.serialLibAtts: -%}
122 {{att.init}}
Grant Watson64285a12022-11-16 15:32:39 +0000123 {%- endfor -%}
124
125 Tosa{{operator.serializeAttType}}Attribute attr
Grant Watsoneff70382023-09-12 10:46:36 +0100126 {%- if operator.serialLibAtts|length > 0 -%}
Grant Watson64285a12022-11-16 15:32:39 +0000127 (
Grant Watsoneff70382023-09-12 10:46:36 +0100128 {%- for att in operator.serialLibAtts: -%}
129 {%- if att.init == "" -%}
130 client_{{att.name}}
131 {%- else -%}
132 {{att.name}}
133 {%- endif -%}
134 {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
Grant Watson64285a12022-11-16 15:32:39 +0000135 {%- endfor -%}
136 )
137 {%- endif -%};
138
139 // Create tensors
// Only name, shape and data type are recorded here; tensor data is handed to the
// runner with setInput() after initialisation.
140 {% for input in operator.inputs: -%}
141 tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
142 {%- endfor -%}
143 {% for output in operator.outputs: %}
144 tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
145 {%- endfor %}
146
147 // Create operator
// Inputs and outputs are referenced by tensor name.
// NOTE(review): `attr` is a stack local passed by address; confirm that
// TosaSerializationOperator copies it rather than retaining the pointer.
148 auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
149 {%- if operator.serializeAttType != "None" -%}
150 tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
151 {%- else -%}
152 tosa::Attribute::Attribute_NONE
153 {%- endif -%},
154 &attr, {
155 {%- for input in operator.inputs: -%}
156 {{input}}->GetName()
157 {%- if loop.index < operator.inputs|length -%},{%- endif -%}
158 {%- endfor -%}
159 },
160 {
161 {%- for output in operator.outputs: -%}
162 {{output}}->GetName()
163 {%- if loop.index < operator.outputs|length -%},{%- endif -%}
164 {%- endfor -%}
165 });
166
167 // Create a tosa single-op basic block
// The block is named after the operator and receives `op`, every input and output
// tensor, and the input/output tensor names.
// NOTE(review): `op` and the tensors were allocated with new and are not deleted
// in this function; confirm that the basic block takes ownership of them.
Jerry Ge9e94af82022-10-27 09:57:00 -0700168 tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
Grant Watson64285a12022-11-16 15:32:39 +0000169 {
170 {%- for input in operator.inputs: -%}
171 {{input}},
172 {%- endfor -%}
173 {%- for output in operator.outputs: -%}
174 {{output}}
175 {%- if loop.index < operator.outputs|length -%},{%- endif -%}
176 {%- endfor -%}
177 },
178 {
179 {%- for input in operator.inputs: -%}
180 {{input}}->GetName()
181 {%- if loop.index < operator.inputs|length -%},{%- endif -%}
182 {%- endfor -%}
183 },
184 {
185 {%- for output in operator.outputs: -%}
186 {{output}}->GetName()
187 {%- if loop.index < operator.outputs|length -%},{%- endif -%}
188 {%- endfor -%}
189 });
190
191 // Setup model
// Each failing step returns from this function immediately: graph-status errors
// are converted to tosa_status_t, non-zero setInput() results give tosa_status_error.
Grant Watsoneff70382023-09-12 10:46:36 +0100192 TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
Grant Watson64285a12022-11-16 15:32:39 +0000193 TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
194 {% for input in operator.inputs: -%}
195 TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
196 {%- endfor %}
197
198 // Execute
199 TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
200
201 // Extract outputs
// Each output is requested by tensor name using the client tensor's data pointer
// and size; a non-zero getOutput() result returns tosa_status_error.
202 {% for output in operator.outputs: -%}
203 TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
204 {%- endfor %}
205
206 return tosa_status_valid;
207 }
208 {% endfor %}
209
210} // extern "C"