blob: 9117df4721c7c8b4fab09c6d12865b5adfeb845a [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef OPS_OP_FACTORY_H
17#define OPS_OP_FACTORY_H
18
19#include "attribute.h"
20#include "graph_node.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "template_types.h"
22#include "tosa_serialization_handler.h"
23
// Emits one `case RANK_N:` label whose body returns a heap-allocated
// OP_T<RANK_N, DType_<ELEM>>. Must be expanded inside a `switch`, in a scope
// where `sgt`, `attribute` and `id` are visible.
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, RANK_N, ELEM) \
    case RANK_N: \
        return new OP_T<RANK_N, DType_##ELEM>(sgt, attribute, id);
Eric Kunzee5e26762020-10-13 16:11:07 -070027
// Emits one `case RANK_N:` label whose body returns a heap-allocated
// OP_T<RANK_N, DType_<ELEM_A>, DType_<ELEM_B>>. Must be expanded inside a
// `switch`, in a scope where `sgt`, `attribute` and `id` are visible.
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, RANK_N, ELEM_A, ELEM_B) \
    case RANK_N: \
        return new OP_T<RANK_N, DType_##ELEM_A, DType_##ELEM_B>(sgt, attribute, id);
Eric Kunzee5e26762020-10-13 16:11:07 -070031
// Emits one `case RANK_B:` label (note: the label is the SECOND rank) whose
// body returns a heap-allocated OP_T<RANK_A, RANK_B, DType_<ELEM>>. Must be
// expanded inside a `switch`, in a scope where `sgt`, `attribute` and `id`
// are visible.
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP_T, RANK_A, RANK_B, ELEM) \
    case RANK_B: \
        return new OP_T<RANK_A, RANK_B, DType_##ELEM>(sgt, attribute, id);
Eric Kunzee5e26762020-10-13 16:11:07 -070035
// Emits one `case RANK_B:` label (note: the label is the SECOND rank) whose
// body returns a heap-allocated
// OP_T<RANK_A, RANK_B, DType_<ELEM_A>, DType_<ELEM_B>>. Must be expanded inside
// a `switch`, in a scope where `sgt`, `attribute` and `id` are visible.
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP_T, RANK_A, RANK_B, ELEM_A, ELEM_B) \
    case RANK_B: \
        return new OP_T<RANK_A, RANK_B, DType_##ELEM_A, DType_##ELEM_B>(sgt, attribute, id);
Eric Kunzee5e26762020-10-13 16:11:07 -070039
// Dispatches on `inputRank`: for each value 0..6 returns a heap-allocated
// OP_T<rank>. Any other rank matches no case, so control continues after the
// expansion. Requires `inputRank`, `sgt`, `attribute` and `id` in scope.
#define DEF_FACTORY_ONE_RANK_0_6(OP_T) \
    switch (inputRank) \
    { \
        case 0: return new OP_T<0>(sgt, attribute, id); \
        case 1: return new OP_T<1>(sgt, attribute, id); \
        case 2: return new OP_T<2>(sgt, attribute, id); \
        case 3: return new OP_T<3>(sgt, attribute, id); \
        case 4: return new OP_T<4>(sgt, attribute, id); \
        case 5: return new OP_T<5>(sgt, attribute, id); \
        case 6: return new OP_T<6>(sgt, attribute, id); \
    }
58
// Returns a heap-allocated OP_T<DType_<ELEM>> when `inputDType` equals
// DType_<ELEM>; otherwise does nothing and control continues after the
// expansion. Requires `inputDType`, `sgt`, `attribute` and `id` in scope.
#define DEF_FACTORY_ONE_TYPE(OP_T, ELEM) \
    if (inputDType == DType_##ELEM) \
    { \
        return new OP_T<DType_##ELEM>(sgt, attribute, id); \
    }
64
// Returns a heap-allocated OP_T<DType_<ELEM>, DType_<ACC>> when `inputDType`
// equals DType_<ELEM> and the value produced by ACCUM_FROM_ATTRIBUTE(ATTR)
// equals DType_<ACC>. The accumulate lookup is only evaluated when the input
// type already matched (short-circuit `&&`). Requires `inputDType`, `sgt`,
// `attribute` and `id` in scope.
#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP_T, ATTR, ELEM, ACC) \
    if (inputDType == DType_##ELEM && ACCUM_FROM_ATTRIBUTE(ATTR) == DType_##ACC) \
    { \
        return new OP_T<DType_##ELEM, DType_##ACC>(sgt, attribute, id); \
    }
70
// Returns a heap-allocated OP_T<DType_<IN_T>, DType_<WT_T>> when `inputDType`
// equals DType_<IN_T> and `weightDType` equals DType_<WT_T>; otherwise control
// continues after the expansion. Requires `inputDType`, `weightDType`, `sgt`,
// `attribute` and `id` in scope.
#define DEF_FACTORY_TWO_TYPE(OP_T, IN_T, WT_T) \
    if (inputDType == DType_##IN_T && weightDType == DType_##WT_T) \
    { \
        return new OP_T<DType_##IN_T, DType_##WT_T>(sgt, attribute, id); \
    }
76
// Returns a heap-allocated OP_T<DType_<IN_T>, DType_<OUT_T>> when `inputDType`
// equals DType_<IN_T> and `outputDType` equals DType_<OUT_T>; otherwise control
// continues after the expansion. Requires `inputDType`, `outputDType`, `sgt`,
// `attribute` and `id` in scope.
#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP_T, IN_T, OUT_T) \
    if (inputDType == DType_##IN_T && outputDType == DType_##OUT_T) \
    { \
        return new OP_T<DType_##IN_T, DType_##OUT_T>(sgt, attribute, id); \
    }
82
// Returns a heap-allocated OP<DType_<DTYPE1>, DType_<DTYPE2>, DType_<ACCUM_DTYPE>>
// when `inputDType` equals DType_<DTYPE1>, `weightDType` equals DType_<DTYPE2>
// and ACCUM_FROM_ATTRIBUTE(ATTR_NAME) yields DType_<ACCUM_DTYPE>. The accumulate
// lookup is only evaluated once both tensor types matched (short-circuit `&&`).
// Requires `inputDType`, `weightDType`, `sgt`, `attribute` and `id` in scope.
//
// The definition deliberately ends on the closing brace with no trailing line
// continuation, so the line that follows can never be spliced into the macro.
#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
    if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
        && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
    { \
        return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
    }
89
James Wardd34b3fc2023-01-18 14:51:25 +000090#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
91 if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \
92 { \
93 return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \
94 }
95
// Statement-expression (GNU extension) that evaluates, in place, to the
// accumulate tosa::DType carried by the `attribute` visible at the expansion
// site.
//
// `attribute` is dynamic_cast to tosa::Tosa<ATTRIBUTE_NAME>Attribute*:
//   - on success, its accum_dtype() is mapped through tosa::EnumValuesDType()
//     and that value is the result of the expression;
//   - on failure, FATAL_ERROR is invoked with a message naming the attribute
//     type.
//
// The value is read straight from the cast pointer. No temporary copy of the
// attribute is allocated, so repeated evaluation (the *_ONE_ACCUM factory
// macros evaluate this once per candidate type combination) does not leak.
// NOTE(review): this assumes the attribute's copy-from-pointer constructor is a
// plain field copy, so reading through the original pointer gives the same
// accum_dtype() - confirm against the attribute definition.
#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \
    ({ \
        tosa::DType accumDType = tosa::DType_UNKNOWN; \
        if (auto typedAttr = dynamic_cast<tosa::Tosa##ATTRIBUTE_NAME##Attribute*>(attribute)) \
        { \
            accumDType = tosa::EnumValuesDType()[typedAttr->accum_dtype()]; \
        } \
        else \
        { \
            FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
                        "of this attribute is required in order to determine the accumulate type."); \
        } \
        accumDType; \
    })
113
// Returns a heap-allocated OP_T<DType_<IN_T>, DType_<OUT_T>, int16_t> when
// `inputDType` equals DType_<IN_T> and `outputDType` equals DType_<OUT_T>;
// otherwise control continues after the expansion. Requires `inputDType`,
// `outputDType`, `sgt`, `attribute` and `id` in scope.
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP_T, IN_T, OUT_T) \
    if (inputDType == DType_##IN_T && outputDType == DType_##OUT_T) \
    { \
        return new OP_T<DType_##IN_T, DType_##OUT_T, int16_t>(sgt, attribute, id); \
    }
119
// Returns a heap-allocated OP_T<DType_<IN_T>, DType_<OUT_T>, half_float::half>
// when `inputDType` equals DType_<IN_T> and `outputDType` equals
// DType_<OUT_T>; otherwise control continues after the expansion. Requires
// `inputDType`, `outputDType`, `sgt`, `attribute` and `id` in scope.
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP_T, IN_T, OUT_T) \
    if (inputDType == DType_##IN_T && outputDType == DType_##OUT_T) \
    { \
        return new OP_T<DType_##IN_T, DType_##OUT_T, half_float::half>(sgt, attribute, id); \
    }
125
// Returns a heap-allocated OP_T<DType_<IN_T>, DType_<OUT_T>, Eigen::bfloat16>
// when `inputDType` equals DType_<IN_T> and `outputDType` equals
// DType_<OUT_T>; otherwise control continues after the expansion. Requires
// `inputDType`, `outputDType`, `sgt`, `attribute` and `id` in scope.
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP_T, IN_T, OUT_T) \
    if (inputDType == DType_##IN_T && outputDType == DType_##OUT_T) \
    { \
        return new OP_T<DType_##IN_T, DType_##OUT_T, Eigen::bfloat16>(sgt, attribute, id); \
    }
131
// Returns a heap-allocated OP_T<DType_<IN_T>, DType_<OUT_T>, float> when
// `inputDType` equals DType_<IN_T> and `outputDType` equals DType_<OUT_T>;
// otherwise control continues after the expansion. Requires `inputDType`,
// `outputDType`, `sgt`, `attribute` and `id` in scope.
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP_T, IN_T, OUT_T) \
    if (inputDType == DType_##IN_T && outputDType == DType_##OUT_T) \
    { \
        return new OP_T<DType_##IN_T, DType_##OUT_T, float>(sgt, attribute, id); \
    }
137
// When `inputDType` equals DType_<ELEM>, dispatches on `inputRank` and returns
// a heap-allocated OP_T<rank, DType_<ELEM>> for rank 0..6 (one case per rank,
// each emitted by DEF_FACTORY_ONE_RANK_ONE_TYPE). A non-matching type or rank
// returns nothing and control continues after the expansion.
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP_T, ELEM) \
    if (inputDType == DType_##ELEM) \
    { \
        switch (inputRank) \
        { \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 0, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 1, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 2, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 3, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 4, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 5, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 6, ELEM) \
        } \
    }
152
// When `inputDType` equals DType_<ELEM>, dispatches on `inputRank` and returns
// a heap-allocated OP_T<rank, DType_<ELEM>> for rank 1..6 (rank 0 is
// intentionally not listed). A non-matching type or rank returns nothing and
// control continues after the expansion.
#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP_T, ELEM) \
    if (inputDType == DType_##ELEM) \
    { \
        switch (inputRank) \
        { \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 1, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 2, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 3, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 4, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 5, ELEM) \
            DEF_FACTORY_ONE_RANK_ONE_TYPE(OP_T, 6, ELEM) \
        } \
    }
166
// When `inputDType` equals DType_<IN_T> and `outputDType` equals DType_<OUT_T>,
// dispatches on `inputRank` and returns a heap-allocated
// OP_T<rank, DType_<IN_T>, DType_<OUT_T>> for rank 0..6. A non-matching type
// pair or rank returns nothing and control continues after the expansion.
#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP_T, IN_T, OUT_T) \
    if (inputDType == DType_##IN_T && outputDType == DType_##OUT_T) \
    { \
        switch (inputRank) \
        { \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 0, IN_T, OUT_T) \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 1, IN_T, OUT_T) \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 2, IN_T, OUT_T) \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 3, IN_T, OUT_T) \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 4, IN_T, OUT_T) \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 5, IN_T, OUT_T) \
            DEF_FACTORY_ONE_RANK_TWO_TYPE(OP_T, 6, IN_T, OUT_T) \
        } \
    }
181
// Helper for DEF_FACTORY_RESHAPE: for one fixed input rank IN_RANK, dispatches
// on `outputRank` and returns a heap-allocated OP<IN_RANK, outRank, DType_<DTYPE>>
// for output rank 0..6. An output rank outside 0..6 matches no case and control
// continues after the expansion.
#define DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, IN_RANK, DTYPE) \
    switch (outputRank) \
    { \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 0, DTYPE) \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 1, DTYPE) \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 2, DTYPE) \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 3, DTYPE) \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 4, DTYPE) \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 5, DTYPE) \
        DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, IN_RANK, 6, DTYPE) \
    }

// When both `inputDType` and `outputDType` equal DType_<DTYPE>, returns a
// heap-allocated OP<inputRank, outputRank, DType_<DTYPE>> for every
// (inputRank, outputRank) pair in 0..6 x 0..6. Any other type or rank
// combination returns nothing and control continues after the expansion.
//
// Each input-rank case ends in an explicit `break`: when `outputRank` matches
// no inner case there is nothing further to try, so control must not fall
// through into the next input rank's (identical) output-rank switch.
// Requires `inputDType`, `outputDType`, `inputRank`, `outputRank`, `sgt`,
// `attribute` and `id` in scope.
#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
    if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
    { \
        switch (inputRank) \
        { \
            case 0: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 0, DTYPE) \
                break; \
            } \
            case 1: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 1, DTYPE) \
                break; \
            } \
            case 2: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 2, DTYPE) \
                break; \
            } \
            case 3: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 3, DTYPE) \
                break; \
            } \
            case 4: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 4, DTYPE) \
                break; \
            } \
            case 5: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 5, DTYPE) \
                break; \
            } \
            case 6: \
            { \
                DEF_FACTORY_RESHAPE_OUTPUT_RANKS(OP, 6, DTYPE) \
                break; \
            } \
        } \
    }
280
Eric Kunzee5e26762020-10-13 16:11:07 -0700281namespace TosaReference
282{
283
Kevin Chengacb550f2021-06-29 15:32:19 -0700284class SubgraphTraverser;
285class GraphNode;
286
Eric Kunzee5e26762020-10-13 16:11:07 -0700287class OpFactory
288{
289public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700290 static GraphNode* newOp(SubgraphTraverser* sgt,
291 tosa::TosaSerializationHandler* tsh,
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 tosa::Op opType,
293 tosa::TosaAttributeBase* attribute,
Eric Kunzee5e26762020-10-13 16:11:07 -0700294 uint64_t id,
295 tosa::DType inputDType,
296 int inputRank,
297 tosa::DType outputDType,
298 int outputRank,
299 tosa::DType weightDType,
300 int weightRank);
301};
302}; // namespace TosaReference
303
304#endif