blob: 6b0ed6eea75bb63393f8d8a1e20b8be29bd7e62d [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001
2// Copyright (c) 2022, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// THIS FILE IS GENERATED. DO NOT EDIT!
17// See scripts/operator_api/generate_api.py
18
19#include "operators.h"
20#include "model_runner_impl.h"
21#include "ops/op_factory.h"
22
// Early-return helper for calls that signal success with 0 (used below around
// runner.setInput()/runner.getOutput()): any non-zero value makes the enclosing
// function return tosa_status_error. The argument is parenthesized so that an
// expression such as `flags & mask` is compared against 0 as a whole, and it is
// evaluated exactly once.
#define TOSA_RETURN_ON_ERROR(status)                                                                                   \
    do                                                                                                                 \
    {                                                                                                                  \
        if ((status) != 0)                                                                                             \
        {                                                                                                              \
            return tosa_status_error;                                                                                  \
        }                                                                                                              \
    } while (false)
31
// Early-return helper for GraphStatus-returning calls (used below around
// runner.initialize() and runner.run()). If the status is not
// GraphStatus::TOSA_VALID, the enclosing function returns it converted to
// tosa_status_t through the enum's underlying integer value.
// NOTE(review): this relies on GraphStatus and tosa_status_t sharing numeric
// values - confirm if either enum changes.
// The argument is evaluated exactly once, so a failing call is not re-executed
// merely to compute the return code.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status)                                                                      \
    do                                                                                                                 \
    {                                                                                                                  \
        const auto tosa_graph_status_macro_value_ = (status);                                                          \
        if (tosa_graph_status_macro_value_ != GraphStatus::TOSA_VALID)                                                 \
        {                                                                                                              \
            auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status_macro_value_);           \
            return static_cast<tosa_status_t>(ustatus);                                                                \
        }                                                                                                              \
    } while (false)
41
42namespace {
43
44tosa::DType translate_client_datatype(tosa_datatype_t type)
45{
46 switch (type)
47 {
48 case tosa_datatype_fp16_t:
49 return tosa::DType::DType_FP16;
50 case tosa_datatype_fp32_t:
51 return tosa::DType::DType_FP32;
52 default:
53 return tosa::DType::DType_UNKNOWN;
54 }
55};
56
57tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
58{
59 std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
60 return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
61}
62
63tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
64 switch(mode) {
65 case tosa_mode_nearest:
66 return tosa::ResizeMode_NEAREST;
67 case tosa_mode_max:
68 case tosa_mode_bilinear:
69 return tosa::ResizeMode_BILINEAR;
70 default:
71 return tosa::ResizeMode_UNKNOWN;
72 }
73}
74
75} // namespace
76
// One C-linkage tosa_run_<operator>() function is generated per operator
// description. Each function wraps a single TOSA operator:
// - it translates the client tensors into serialization tensors;
// - it builds a one-operator basic block and runs it on a
//   TosaReference::ModelRunnerImpl;
// - it reads the results back via runner.getOutput() into the client output
//   tensors' data/size.
extern "C"
{
    {% for operator in operators: %}
    // Returns tosa_status_valid on success.
    // A GraphStatus other than TOSA_VALID from runner.initialize()/runner.run()
    // is returned converted to tosa_status_t. A non-zero result from
    // runner.setInput()/runner.getOutput() returns tosa_status_error.
    // Arguments of type tosa_tensor_t are not const-qualified; all other
    // arguments are.
    tosa_status_t tosa_run_{{ operator.name }} (
        {%- for arg in operator.arguments: -%}
        {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
        {% if loop.index < operator.arguments|length %},{% endif %}
        {%- endfor -%}
    )
    {
        // Create operator attributes. Vector-valued serialization arguments
        // become const std::vector locals and the others const scalars. They are
        // then passed, in order, to the attribute constructor (which gets no
        // arguments when the operator has none).
        {% for arg in operator.serializeArgs: %}
        {%- if arg.SV == "V": -%}
        const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
        {%- else: -%}
        const {{arg.dType}} {{arg.name}}{{arg.init}};
        {%- endif -%}
        {%- endfor -%}

        Tosa{{operator.serializeAttType}}Attribute attr
        {%- if operator.serializeArgs|length > 0 -%}
        (
        {%- for arg in operator.serializeArgs: -%}
        {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
        {%- endfor -%}
        )
        {%- endif -%};

        // Create tensors
        // translate_client_tensor() heap-allocates one serialization tensor per
        // client input/output, named after the argument.
        {% for input in operator.inputs: -%}
        tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
        {%- endfor -%}
        {% for output in operator.outputs: %}
        tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
        {%- endfor %}

        // Create operator
        // The operator and the tensors above are raw heap allocations handed to
        // `block` below.
        // NOTE(review): this assumes TosaSerializationBasicBlock takes ownership
        // and frees them, and that the operator copies `attr` rather than keeping
        // the pointer. Confirm against serialization_lib.
        auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
        {%- if operator.serializeAttType != "None" -%}
        tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
        {%- else -%}
        tosa::Attribute::Attribute_NONE
        {%- endif -%},
        &attr, {
        {%- for input in operator.inputs: -%}
        {{input}}->GetName()
        {%- if loop.index < operator.inputs|length -%},{%- endif -%}
        {%- endfor -%}
        },
        {
        {%- for output in operator.outputs: -%}
        {{output}}->GetName()
        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
        {%- endfor -%}
        });

        // Create a tosa single-op basic block
        // Passed, in order:
        // - the operator name;
        // - { op };
        // - every tensor (inputs, then outputs);
        // - the input tensor names;
        // - the output tensor names.
        tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
        {
        {%- for input in operator.inputs: -%}
        {{input}},
        {%- endfor -%}
        {%- for output in operator.outputs: -%}
        {{output}}
        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
        {%- endfor -%}
        },
        {
        {%- for input in operator.inputs: -%}
        {{input}}->GetName()
        {%- if loop.index < operator.inputs|length -%},{%- endif -%}
        {%- endfor -%}
        },
        {
        {%- for output in operator.outputs: -%}
        {{output}}->GetName()
        {%- if loop.index < operator.outputs|length -%},{%- endif -%}
        {%- endfor -%}
        });

        // Setup model
        // Input data is supplied from each client input tensor's data/size.
        TosaReference::ModelRunnerImpl runner;
        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
        {% for input in operator.inputs: -%}
        TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
        {%- endfor %}

        // Execute
        TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());

        // Extract outputs
        // Results are read via each client output tensor's data/size
        // (presumably a caller-owned buffer of that size - confirm).
        {% for output in operator.outputs: -%}
        TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
        {%- endfor %}

        return tosa_status_valid;
    }
    {% endfor %}

} // extern "C"