blob: 654aeb55dcf908f679c6a3c0d1eb9b30a591b3c2 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
/// Builds a human-readable error message for a tensor-rank mismatch.
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Layer name used in the message, e.g. "Concat".
/// @param tensorName Tensor name used in the message, e.g. "input".
/// @return e.g. "Reference Concat: Expected 4 dimensions but got 2 dimensions
///         instead, for the 'input' tensor."
// NOTE: parameters are now taken by const reference — they are read-only here,
// and this remains compatible with every existing call site.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100103 case LayerType::BroadcastTo:
104 return IsBroadcastToSupported(infos[0],
105 infos[1],
106 *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
107 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000108 case LayerType::Comparison:
109 return IsComparisonSupported(infos[0],
110 infos[1],
111 infos[2],
112 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
113 reasonIfUnsupported);
114 case LayerType::Concat:
115 {
116 std::vector<const TensorInfo*> inputInfos;
117 for (uint32_t i = 0; i < (infos.size() - 1); i++)
118 {
119 inputInfos.push_back(&infos[i]);
120 }
121 return IsConcatSupported(inputInfos,
122 infos[infos.size() - 1],
123 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
124 reasonIfUnsupported);
125 }
126 case LayerType::Constant:
127 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000128 case LayerType::ConvertFp16ToFp32:
129 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000130 case LayerType::ConvertFp32ToFp16:
131 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
132 case LayerType::Convolution2d:
133 {
134 if (infos.size() != 4)
135 {
136 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
137 "TensorInfos should be of format: {input, output, weights, biases}.");
138 }
139
140 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
141 if (infos[3] == TensorInfo())
142 {
143 return IsConvolution2dSupported(infos[0],
144 infos[1],
145 desc,
146 infos[2],
147 EmptyOptional(),
148 reasonIfUnsupported);
149 }
150 else
151 {
152 return IsConvolution2dSupported(infos[0],
153 infos[1],
154 desc,
155 infos[2],
156 infos[3],
157 reasonIfUnsupported);
158 }
159 }
160 case LayerType::DepthToSpace:
161 return IsDepthToSpaceSupported(infos[0],
162 infos[1],
163 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
164 reasonIfUnsupported);
165 case LayerType::DepthwiseConvolution2d:
166 {
167 if (infos.size() != 4)
168 {
169 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
170 "TensorInfos should be of format: {input, output, weights, biases}.");
171 }
172
173 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
174 if (infos[3] == TensorInfo())
175 {
176 return IsDepthwiseConvolutionSupported(infos[0],
177 infos[1],
178 desc,
179 infos[2],
180 EmptyOptional(),
181 reasonIfUnsupported);
182 }
183 else
184 {
185 return IsDepthwiseConvolutionSupported(infos[0],
186 infos[1],
187 desc,
188 infos[2],
189 infos[3],
190 reasonIfUnsupported);
191 }
192 }
193 case LayerType::Dequantize:
194 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
195 case LayerType::Division:
196 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000197 case LayerType::ElementwiseBinary:
198 {
199 std::array<DataType, 7> supportedTypes =
200 {
201 DataType::Float32,
202 DataType::Float16,
203 DataType::QAsymmS8,
204 DataType::QAsymmU8,
205 DataType::QSymmS16,
206 DataType::Signed32
207 };
208
209 bool supported = true;
210 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
211 "Reference elementwise unary: input type not supported");
212
213 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
214 "Reference elementwise unary: input type not supported");
215
216 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
217 "Reference elementwise unary: output type not supported");
218
219 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
220 "Reference elementwise unary: input types not matching");
221
222 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
223 "Reference elementwise unary: input and output types not matching");
224
225 return supported;
226 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000227 case LayerType::ElementwiseUnary:
228 return IsElementwiseUnarySupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Fill:
233 return IsFillSupported(infos[0],
234 infos[1],
235 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
236 reasonIfUnsupported);
237 case LayerType::Floor:
238 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
239 case LayerType::FullyConnected:
240 return IsFullyConnectedSupported(infos[0],
241 infos[1],
242 infos[2],
243 infos[3],
244 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Gather:
247 return IsGatherSupported(infos[0],
248 infos[1],
249 infos[2],
250 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
251 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100252 case LayerType::GatherNd:
253 return IsGatherNdSupported(infos[0],
254 infos[1],
255 infos[2],
256 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000257 case LayerType::Input:
258 return IsInputSupported(infos[0], reasonIfUnsupported);
259 case LayerType::InstanceNormalization:
260 return IsInstanceNormalizationSupported(infos[0],
261 infos[1],
262 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
263 (&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::L2Normalization:
266 return IsL2NormalizationSupported(infos[0],
267 infos[1],
268 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
269 reasonIfUnsupported);
270 case LayerType::LogicalBinary:
271 return IsLogicalBinarySupported(infos[0],
272 infos[1],
273 infos[2],
274 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::LogSoftmax:
277 return IsLogSoftmaxSupported(infos[0],
278 infos[1],
279 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
280 reasonIfUnsupported);
281 case LayerType::Lstm:
282 return IsLstmSupported(infos[0],
283 infos[1],
284 infos[2],
285 infos[3],
286 infos[4],
287 infos[5],
288 infos[6],
289 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
290 lstmParamsInfo.value(),
291 reasonIfUnsupported);
292 case LayerType::QLstm:
293 return IsQLstmSupported(infos[0],
294 infos[1],
295 infos[2],
296 infos[3],
297 infos[4],
298 infos[5],
299 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
300 lstmParamsInfo.value(),
301 reasonIfUnsupported);
302 case LayerType::Maximum:
303 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
304 case LayerType::Mean:
305 return IsMeanSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Minimum:
310 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
311 case LayerType::Multiplication:
312 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
313 case LayerType::Normalization:
314 return IsNormalizationSupported(infos[0],
315 infos[1],
316 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
317 reasonIfUnsupported);
318 case LayerType::Output:
319 return IsOutputSupported(infos[0], reasonIfUnsupported);
320 case LayerType::Pad:
321 return IsPadSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Permute:
326 return IsPermuteSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Pooling2d:
331 return IsPooling2dSupported(infos[0],
332 infos[1],
333 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
334 reasonIfUnsupported);
335 case LayerType::Prelu:
336 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
337 case LayerType::Quantize:
338 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
339 case LayerType::Reshape:
340 return IsReshapeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Resize:
345 return IsResizeSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100349 case LayerType::ReverseV2:
350 return IsReverseV2Supported(infos[0],
351 infos[1],
Tracy Narinebb8d7592023-07-13 16:50:54 +0100352 infos[2],
Tianle Cheng988354d2023-06-28 13:20:47 +0100353 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000354 case LayerType::Reduce:
355 return IsReduceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
Tianle Cheng28288182024-02-23 17:56:54 +0000359 case LayerType::ScatterNd:
360 return IsScatterNdSupported(infos[0],
361 infos[1],
362 infos[2],
363 infos[3],
364 *(PolymorphicDowncast<const ScatterNdDescriptor*>(&descriptor)),
365 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000366 case LayerType::Slice:
367 return IsSliceSupported(infos[0],
368 infos[1],
369 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
370 reasonIfUnsupported);
371 case LayerType::Softmax:
372 return IsSoftmaxSupported(infos[0],
373 infos[1],
374 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
375 reasonIfUnsupported);
376 case LayerType::SpaceToBatchNd:
377 return IsSpaceToBatchNdSupported(infos[0],
378 infos[1],
379 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
380 reasonIfUnsupported);
381 case LayerType::SpaceToDepth:
382 return IsSpaceToDepthSupported(infos[0],
383 infos[1],
384 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
385 reasonIfUnsupported);
386 case LayerType::Splitter:
387 {
388 std::vector<TensorInfo> outputInfos;
389 for (uint32_t i = 1; i < infos.size(); i++)
390 {
391 outputInfos.push_back(infos[i]);
392 }
393 return IsSplitterSupported(infos[0],
394 {outputInfos.begin(), outputInfos.end()},
395 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
396 reasonIfUnsupported);
397 }
398 case LayerType::Stack:
399 {
400 std::vector<const TensorInfo*> inputInfos;
401 for (uint32_t i = 0; i < infos.size() - 1; i++)
402 {
403 inputInfos.push_back(&infos[i]);
404 }
405 return IsStackSupported(inputInfos,
406 infos[infos.size() - 1],
407 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
408 reasonIfUnsupported);
409 }
410 case LayerType::StridedSlice:
411 return IsStridedSliceSupported(infos[0],
412 infos[1],
413 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
414 reasonIfUnsupported);
415 case LayerType::Subtraction:
416 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Teresa Charlin79a06a52023-07-13 17:16:45 +0100417 case LayerType::Tile:
418 return IsTileSupported(infos[0],
419 infos[1],
420 *(PolymorphicDowncast<const TileDescriptor*>(&descriptor)),
421 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000422 case LayerType::Transpose:
423 return IsTransposeSupported(infos[0],
424 infos[1],
425 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
426 reasonIfUnsupported);
427 case LayerType::TransposeConvolution2d:
428 {
429 if (infos.size() != 4)
430 {
431 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
432 "TensorInfos should be of format: {input, output, weights, biases}.");
433 }
434
435 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
436 if (infos[3] == TensorInfo())
437 {
438 return IsTransposeConvolution2dSupported(infos[0],
439 infos[1],
440 desc,
441 infos[2],
442 EmptyOptional(),
443 reasonIfUnsupported);
444 }
445 else
446 {
447 return IsTransposeConvolution2dSupported(infos[0],
448 infos[1],
449 desc,
450 infos[2],
451 infos[3],
452 reasonIfUnsupported);
453 }
454 }
455 case LayerType::Cast:
456 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
457 case LayerType::ChannelShuffle:
458 return IsChannelShuffleSupported(infos[0],
459 infos[1],
460 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
461 reasonIfUnsupported);
462 case LayerType::Convolution3d:
463 {
464 if (infos.size() != 4)
465 {
466 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
467 "TensorInfos should be of format: {input, output, weights, biases}.");
468 }
469
470 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
471 if (infos[3] == TensorInfo())
472 {
473 return IsConvolution3dSupported(infos[0],
474 infos[1],
475 desc,
476 infos[2],
477 EmptyOptional(),
478 reasonIfUnsupported);
479 }
480 else
481 {
482 return IsConvolution3dSupported(infos[0],
483 infos[1],
484 desc,
485 infos[2],
486 infos[3],
487 reasonIfUnsupported);
488 }
489 }
490 case LayerType::Debug:
491 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
492 case LayerType::DetectionPostProcess:
493 return IsDetectionPostProcessSupported(infos[0],
494 infos[1],
495 infos[2],
496 infos[3],
497 infos[4],
498 infos[5],
499 infos[6],
500 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
501 (&descriptor)),
502 reasonIfUnsupported);
503 case LayerType::FakeQuantization:
504 return IsFakeQuantizationSupported(infos[0],
505 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
506 reasonIfUnsupported);
507 case LayerType::MemCopy:
508 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
509 case LayerType::Rank:
510 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
511 case LayerType::Shape:
512 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
513 case LayerType::UnidirectionalSequenceLstm:
514 {
515 if (infos.size() != 6)
516 {
517 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
518 "should be of format: {input, outputStateIn, cellStateIn, "
519 "hiddenStateOutputVal, cellStateOutputVal, output}");
520 }
521 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100522 return IsUnidirectionalSequenceLstmSupported(infos[0],
523 infos[1],
524 infos[2],
525 infos[3],
526 infos[4],
527 infos[5],
528 desc,
529 lstmParamsInfo.value(),
530 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000531 }
532 case LayerType::Pooling3d:
533 return IsPooling3dSupported(infos[0],
534 infos[1],
535 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
536 reasonIfUnsupported);
537 case LayerType::Map:
538 return true;
539 case LayerType::Unmap:
540 return true;
541 case LayerType::MemImport:
542 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
543 case LayerType::Merge:
544 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
545 case LayerType::QuantizedLstm:
546 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
547 infos[1],
548 infos[2],
549 infos[3],
550 infos[4],
551 quantizedLstmInputParamsInfo.value(),
552 reasonIfUnsupported);
553 default:
Teresa Charlin9145e382023-08-17 18:44:58 +0100554 // layers not supported in reference by default:
555 // precompiled, standin, switch, fused
Cathal Corbett34b429c2021-12-24 12:24:40 +0000556 return false;
557 }
558}
559
arovir011c7c81b2018-10-08 11:34:28 +0100560bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
561 const TensorInfo& output,
562 const ActivationDescriptor& descriptor,
563 Optional<std::string&> reasonIfUnsupported) const
564{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000565 bool supported = true;
566
567 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000568 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000569 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100570 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000571 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000572 DataType::QAsymmU8,
573 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000574 };
575
576 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
577 "Reference activation: input type not supported.");
578
579 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
580 "Reference activation: output type not supported.");
581
582 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
583 "Reference activation: input and output types mismatched.");
584
585 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
586 "Reference activation: input and output shapes are of different rank.");
587
588
589 struct ActivationFunctionSupported : public Rule
590 {
591 ActivationFunctionSupported(const ActivationDescriptor& desc)
592 {
593 switch(desc.m_Function)
594 {
595 case ActivationFunction::Abs:
596 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000597 case ActivationFunction::Elu:
Teresa Charlin077cddb2023-09-15 15:19:21 +0100598 case ActivationFunction::Gelu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000599 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000600 case ActivationFunction::LeakyReLu:
601 case ActivationFunction::Linear:
602 case ActivationFunction::ReLu:
603 case ActivationFunction::Sigmoid:
604 case ActivationFunction::SoftReLu:
605 case ActivationFunction::Sqrt:
606 case ActivationFunction::Square:
607 case ActivationFunction::TanH:
608 {
609 m_Res = true;
610 break;
611 }
612 default:
613 {
614 m_Res = false;
615 break;
616 }
617 }
618 }
619 };
620
621 // Function is supported
622 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
623 "Reference activation: function not supported.");
624
625 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100626}
627
628bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
629 const TensorInfo& input1,
630 const TensorInfo& output,
631 Optional<std::string&> reasonIfUnsupported) const
632{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000633 bool supported = true;
634
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100635 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000636 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100637 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000638 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000639 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100640 DataType::QSymmS16,
641 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000642 };
643
644 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
645 "Reference addition: input 0 is not a supported type.");
646
647 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
648 "Reference addition: input 1 is not a supported type.");
649
650 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
651 "Reference addition: output is not a supported type.");
652
653 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
654 "Reference addition: input 0 and Input 1 types are mismatched");
655
656 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
657 "Reference addition: input and output types are mismatched");
658
659 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
660 "Reference addition: shapes are not suitable for implicit broadcast.");
661
662 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100663}
664
Nikhil Raj68c2c902019-09-19 11:21:11 +0100665bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
666 const armnn::ArgMinMaxDescriptor &descriptor,
667 armnn::Optional<std::string &> reasonIfUnsupported) const
668{
Jan Eilers8eb25602020-03-09 12:13:48 +0000669 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100670
Mike Kelly1f140f72021-04-06 12:25:55 +0100671 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100672 {
Teresa Charline300b362020-05-25 10:01:03 +0100673 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100674 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100675 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000676 DataType::QAsymmU8,
677 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100678 DataType::Signed32,
679 DataType::Signed64
680 };
681
682 std::array<DataType,2> supportedOutputTypes = {
683 DataType::Signed32,
684 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100685 };
686
687 bool supported = true;
688
Mike Kelly1f140f72021-04-06 12:25:55 +0100689 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100690 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100691 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100692 "Reference ArgMinMax: output type not supported");
693
694 return supported;
695}
696
Samuel Yap6b478092022-07-06 15:36:03 +0100697bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
698 const TensorInfo& inputY,
699 const TensorInfo& output,
700 const BatchMatMulDescriptor& descriptor,
701 Optional<std::string &> reasonIfUnsupported) const
702{
703 IgnoreUnused(descriptor);
704
705 std::array<DataType, 6> supportedTypes =
706 {
Samuel Yap6b478092022-07-06 15:36:03 +0100707 DataType::Float16,
708 DataType::Float32,
709 DataType::QAsymmS8,
710 DataType::QAsymmU8,
711 DataType::QSymmS16
712 };
713
714 bool supported = true;
715
716 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
717 "Reference batch matrix multiplication: input X is not a supported type");
718
719 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
720 "Reference batch matrix multiplication: input Y is not a supported type");
721
722 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
723 "Reference batch matrix multiplication: output is not a supported type");
724
725 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
726 "Reference batch matrix multiplication: input X and input Y types are mismatched");
727
728 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
729 "Reference batch matrix multiplication: inputs and output types are mismatched");
730
731 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
732 reasonIfUnsupported,
733 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
734
735 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
736 reasonIfUnsupported,
737 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
738
739 return supported;
740}
741
arovir011c7c81b2018-10-08 11:34:28 +0100742bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
743 const TensorInfo& output,
744 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100745 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100746 const TensorInfo& beta,
747 const TensorInfo& gamma,
748 const BatchNormalizationDescriptor& descriptor,
749 Optional<std::string&> reasonIfUnsupported) const
750{
Jan Eilers8eb25602020-03-09 12:13:48 +0000751 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100752
Sadik Armagan303980c2020-04-17 12:45:14 +0100753 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100754 {
755 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100756 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100757 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000758 DataType::QAsymmU8,
759 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100760 };
761
762 bool supported = true;
763
764 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
765 "Reference batch normalization: input is not a supported type.");
766
767 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
768 "Reference batch normalization: output is not a supported type.");
769
770 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
771 "Reference batch normalization: input and output types are mismatched");
772
773 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
774 "Reference batch normalization: mean is not a supported type.");
775
776 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
777 "Reference batch normalization: variance is not a supported type.");
778
779 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
780 "Reference batch normalization: beta is not a supported type.");
781
782 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
783 "Reference batch normalization: gamma is not a supported type.");
784
785 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100786}
787
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000788bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
789 const TensorInfo& output,
790 const BatchToSpaceNdDescriptor& descriptor,
791 Optional<std::string&> reasonIfUnsupported) const
792{
Jan Eilers8eb25602020-03-09 12:13:48 +0000793 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100794
795 bool supported = true;
796
797 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
798 std::string inputTensorStr = "input";
799 std::string outputTensorStr = "output";
800
801 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100802 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100803 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000804 DataType::Float32,
805 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100806 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000807 DataType::QAsymmU8,
808 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100809 };
810
811 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
812 "Reference BatchToSpaceNd: input type not supported.");
813
814 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
815 "Reference BatchToSpaceNd: output type not supported.");
816
817 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
818 "Reference BatchToSpaceNd: input and output types mismatched.");
819
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100820 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000821}
822
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100823bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
824 const TensorInfo& output,
825 const BroadcastToDescriptor& descriptor,
826 Optional<std::string&> reasonIfUnsupported) const
827{
828 IgnoreUnused(descriptor);
829
830 bool supported = true;
831
832 std::array<DataType, 8> supportedTypes
833 {
834 DataType::Float32,
835 DataType::Float16,
836 DataType::QAsymmS8,
837 DataType::QAsymmU8,
838 DataType::QSymmS8,
839 DataType::QSymmS16,
840 DataType::Signed32,
841 DataType::Signed64
842 };
843
844 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
845 "BroadcastTo: input type not supported.");
846
847 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
848 "BroadcastTo: output type not supported");
849
850 return supported;
851}
852
mathad01b392e982021-04-07 12:07:30 +0100853bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
854 const TensorInfo& output,
855 Optional<std::string&> reasonIfUnsupported) const
856{
Teresa Charlin5306dc82023-10-30 22:29:58 +0000857 std::array<DataType, 10> supportedInputTypes =
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100858 {
859 DataType::Float32,
860 DataType::Float16,
861 DataType::QSymmS8,
862 DataType::QAsymmS8,
863 DataType::QAsymmU8,
864 DataType::QSymmS16,
Teresa Charlin5306dc82023-10-30 22:29:58 +0000865 DataType::Signed32,
866 DataType::Signed64
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100867 };
mathad01b392e982021-04-07 12:07:30 +0100868
869 bool supported = true;
870 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
871 "Reference cast: input is not a supported type");
872
873
874 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
875 "Reference cast: output is not a supported type");
876
877 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
878 "Reference cast: input and output shapes have different number of total elements");
879
880 return supported;
881}
882
Simon Obute51f67772021-09-03 15:50:13 +0100883bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
884 const TensorInfo& output,
885 const ChannelShuffleDescriptor& descriptor,
886 Optional<std::string&> reasonIfUnsupported) const
887{
888 IgnoreUnused(descriptor);
889 bool supported = true;
890
891 // Define supported output and inputs types.
892 std::array<DataType, 7> supportedTypes =
893 {
Simon Obute51f67772021-09-03 15:50:13 +0100894 DataType::Float32,
895 DataType::Float16,
896 DataType::QAsymmS8,
897 DataType::QAsymmU8,
898 DataType::QSymmS8,
899 DataType::QSymmS16
900 };
901
902 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
903 "Reference ChannelShuffle: input is not a supported type.");
904
905 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
906 "Reference ChannelShuffle: output is not a supported type.");
907
908 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
909 "Reference ChannelShuffle: input and output types are mismatched.");
910
911 return supported;
912}
913
914
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100915bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
916 const TensorInfo& input1,
917 const TensorInfo& output,
918 const ComparisonDescriptor& descriptor,
919 Optional<std::string&> reasonIfUnsupported) const
920{
Jan Eilers8eb25602020-03-09 12:13:48 +0000921 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100922 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100923 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000924 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100925 DataType::Float32,
926 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100927 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000928 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000929 DataType::QSymmS16,
930 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100931 };
932
933 bool supported = true;
934 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
935 "Reference comparison: input 0 is not a supported type");
936
937 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
938 "Reference comparison: input 0 and Input 1 types are mismatched");
939
940 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
941 "Reference comparison: output is not of type Boolean");
942
943 return supported;
944}
945
Jim Flynn906f9462019-05-10 13:55:21 +0100946bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
947 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000948 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100949 Optional<std::string&> reasonIfUnsupported) const
950{
Jan Eilers8eb25602020-03-09 12:13:48 +0000951 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100952
953 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000954 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100955 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000956 DataType::Float32,
957 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000958 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100959 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000960 DataType::QSymmS16,
961 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100962 };
963
964 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
965 "Reference concatenation: output type not supported");
966 for (const TensorInfo* input : inputs)
967 {
968 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
969 "Reference concatenation: input type not supported");
970
971 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
972 "Reference concatenation: input and output types mismatched.");
973 }
974
975 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100976}
977
arovir011c7c81b2018-10-08 11:34:28 +0100978bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
979 Optional<std::string&> reasonIfUnsupported) const
980{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100981 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100982 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100983 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100984 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000985 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100986 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000987 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100988 DataType::QSymmS16,
989 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100990 };
991
992 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
993 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100994}
995
996bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
997 const TensorInfo& output,
998 Optional<std::string&> reasonIfUnsupported) const
999{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001000 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1001 input.GetDataType(),
1002 &TrueFunc<>,
1003 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001004 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001005 &FalseFuncI32<>,
1006 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001007 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1008 output.GetDataType(),
1009 &FalseOutputFuncF16<>,
1010 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001011 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001012 &FalseFuncI32<>,
1013 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001014}
1015
1016bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
1017 const TensorInfo& output,
1018 Optional<std::string&> reasonIfUnsupported) const
1019{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001020 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1021 input.GetDataType(),
1022 &FalseInputFuncF16<>,
1023 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001024 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001025 &FalseFuncI32<>,
1026 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001027 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1028 output.GetDataType(),
1029 &TrueFunc<>,
1030 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001031 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001032 &FalseFuncI32<>,
1033 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001034}
1035
1036bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1037 const TensorInfo& output,
1038 const Convolution2dDescriptor& descriptor,
1039 const TensorInfo& weights,
1040 const Optional<TensorInfo>& biases,
1041 Optional<std::string&> reasonIfUnsupported) const
1042{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001043 bool supported = true;
1044
1045 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001046 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001047 {
1048 DataType::Float32,
1049 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001050 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001051 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001052 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001053 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001054 };
1055
1056 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001057 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001058
1059 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001060 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001061
Ryan OShea31441592022-11-07 16:20:48 +00001062 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1063 "Reference Convolution2d: input and output types mismatched.");
1064
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001065
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001066 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001067 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001068 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001069 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001070 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001071 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001072 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001073 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001074 };
1075
1076 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001077 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001078 }
1079 else
1080 {
1081 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001082 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001083
1084 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001085 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001086 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001087
1088 if (biases.has_value())
1089 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001090 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001091 {
1092 DataType::Float32,
1093 DataType::Float16,
1094 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001095 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001096
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001097 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001098 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001099 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001100 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001101
1102 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001103}
1104
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001105bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1106 const TensorInfo& output,
1107 const Convolution3dDescriptor& descriptor,
1108 const TensorInfo& weights,
1109 const Optional<TensorInfo>& biases,
1110 Optional<std::string&> reasonIfUnsupported) const
1111{
1112 bool supported = true;
1113
1114 // Define supported types.
1115 std::array<DataType,7> supportedTypes =
1116 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001117 DataType::Float32,
1118 DataType::Float16,
1119 DataType::QAsymmS8,
1120 DataType::QAsymmU8,
1121 DataType::QSymmS8,
1122 DataType::QSymmS16
1123 };
1124
1125 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1126 "Reference Convolution3d: input is not a supported type.");
1127
1128 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1129 "Reference Convolution3d: output is not a supported type.");
1130
1131 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1132 "Reference Convolution3d: input and output types mismatched.");
1133
1134 const DataType inputType = input.GetDataType();
1135 if (IsQuantized8BitType(inputType))
1136 {
1137 std::array<DataType, 3> supportedWeightTypes =
1138 {
1139 DataType::QAsymmS8,
1140 DataType::QAsymmU8,
1141 DataType::QSymmS8
1142 };
1143
1144 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1145 "Reference Convolution3d: weights type not supported for quantized input.");
1146 }
1147 else
1148 {
1149 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1150 "Reference Convolution3d: weights is not a supported type.");
1151
1152 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1153 "Reference Convolution3d: input and weights types mismatched.");
1154 }
1155
1156 if (biases.has_value())
1157 {
1158 std::array<DataType,4> biasesSupportedTypes =
1159 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001160 DataType::Float32,
1161 DataType::Float16,
1162 DataType::Signed32
1163 };
1164
1165 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1166 "Reference Convolution3d: biases is not a supported type.");
1167 }
1168 IgnoreUnused(descriptor);
1169
1170 return supported;
1171}
1172
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001173bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1174 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001175 Optional<std::string&> reasonIfUnsupported) const
1176{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001177 bool supported = true;
1178
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001179 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001180 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001181 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001182 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001183 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001184 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001185 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001186 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001187 DataType::QSymmS16,
1188 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001189 };
1190
1191 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001192 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001193
1194 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001195 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001196
1197 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001198 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001199
1200 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001201}
1202
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001203bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1204 const TensorInfo& output,
1205 const DepthToSpaceDescriptor& descriptor,
1206 Optional<std::string&> reasonIfUnsupported) const
1207{
Jan Eilers8eb25602020-03-09 12:13:48 +00001208 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001209 bool supported = true;
1210
Sadik Armagan303980c2020-04-17 12:45:14 +01001211 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001212 {
1213 DataType::Float32,
1214 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001215 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001216 DataType::QAsymmU8,
1217 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001218 };
1219
1220 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1221 "Reference DepthToSpace: input type not supported");
1222
1223 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1224 "Reference DepthToSpace: output type not supported");
1225
1226 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1227 "Reference DepthToSpace: input and output types are mismatched");
1228
1229 return supported;
1230}
1231
arovir011c7c81b2018-10-08 11:34:28 +01001232bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1233 const TensorInfo& output,
1234 const DepthwiseConvolution2dDescriptor& descriptor,
1235 const TensorInfo& weights,
1236 const Optional<TensorInfo>& biases,
1237 Optional<std::string&> reasonIfUnsupported) const
1238{
Sadik Armagan303980c2020-04-17 12:45:14 +01001239 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001240 bool supported = true;
1241
1242 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001243 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001244 {
1245 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001246 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001247 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001248 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001249 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001250 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001251 };
1252
1253 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1254 "Reference DepthwiseConvolution2d: input is not a supported type.");
1255
1256 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1257 "Reference DepthwiseConvolution2d: output is not a supported type.");
1258
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001259 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1260 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1261
Teresa Charlind8df0262019-11-11 12:28:15 +00001262 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001263 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001264 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001265 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001266 {
1267 DataType::QAsymmS8,
1268 DataType::QAsymmU8,
1269 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001270 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001271
1272 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001273 "Reference DepthwiseConvolution2d: weights type not supported for "
1274 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001275 }
1276 else
1277 {
1278 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1279 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1280
1281 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1282 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1283 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001284
1285 if (biases.has_value())
1286 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001287 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001288 {
1289 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001290 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001291 DataType::Signed32
1292 };
1293 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1294 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1295 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001296
1297 return supported;
1298
arovir011c7c81b2018-10-08 11:34:28 +01001299}
1300
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001301bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1302 const TensorInfo& output,
1303 Optional<std::string&> reasonIfUnsupported) const
1304{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001305 bool supported = true;
1306
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001307 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001308 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001309 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001310 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001311 DataType::QSymmS16,
1312 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001313 };
1314
1315 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001316 "Reference for Dequantize layer: input type not supported.");
1317
Derek Lambertid466a542020-01-22 15:37:29 +00001318 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001319 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001320
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001321 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001322 DataType::Float32,
1323 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001324 };
1325
1326 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001327 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001328
1329 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001330 "Reference for Dequantize layer: input/output shapes have different num total "
1331 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001332
1333 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001334}
1335
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001336bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1337 const TensorInfo& scores,
1338 const TensorInfo& anchors,
1339 const TensorInfo& detectionBoxes,
1340 const TensorInfo& detectionClasses,
1341 const TensorInfo& detectionScores,
1342 const TensorInfo& numDetections,
1343 const DetectionPostProcessDescriptor& descriptor,
1344 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001345{
Jan Eilers8eb25602020-03-09 12:13:48 +00001346 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001347
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001348 bool supported = true;
1349
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001350 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001351 {
1352 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001353 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001354 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001355 DataType::QAsymmU8,
1356 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001357 };
1358
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001359 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001360 "Reference DetectionPostProcess: input 0 is not a supported type.");
1361
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001362 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001363 "Reference DetectionPostProcess: input 1 is not a supported type.");
1364
1365 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001366}
1367
Pablo Tellof0bd6832019-04-26 17:58:13 +01001368bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1369 const TensorInfo& output,
1370 const DepthwiseConvolution2dDescriptor& descriptor,
1371 const TensorInfo& weights,
1372 const Optional<TensorInfo>& biases,
1373 Optional<std::string&> reasonIfUnsupported) const
1374{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001375 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +01001376}
1377
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001378bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001379 const TensorInfo& input1,
1380 const TensorInfo& output,
1381 Optional<std::string&> reasonIfUnsupported) const
1382{
Sadik Armagan2999a022019-04-09 14:20:12 +01001383 bool supported = true;
1384
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001385 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001386 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001387 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001388 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001389 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001390 DataType::QSymmS16,
1391 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001392 };
1393
1394 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1395 "Reference division: input 0 is not a supported type.");
1396
1397 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1398 "Reference division: input 1 is not a supported type.");
1399
1400 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1401 "Reference division: output is not a supported type.");
1402
1403 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1404 "Reference division: input 0 and Input 1 types are mismatched");
1405
1406 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1407 "Reference division: input and output types are mismatched");
1408
1409 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1410 "Reference division: shapes are not suitable for implicit broadcast.");
1411
1412 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001413}
1414
josh minor4a3c6102020-01-06 16:40:46 -06001415bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1416 const TensorInfo& output,
1417 const ElementwiseUnaryDescriptor& descriptor,
1418 Optional<std::string&> reasonIfUnsupported) const
1419{
Jan Eilers8eb25602020-03-09 12:13:48 +00001420 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001421
Sadik Armagan303980c2020-04-17 12:45:14 +01001422 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001423 {
1424 DataType::Float32,
1425 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001426 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001427 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001428 DataType::QSymmS16,
1429 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001430 };
1431
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001432 std::array<DataType, 1> logicalSupportedTypes =
1433 {
1434 DataType::Boolean
1435 };
1436
josh minor4a3c6102020-01-06 16:40:46 -06001437 bool supported = true;
1438
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001439 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1440 {
1441 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1442 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001443
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001444 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1445 "Reference elementwise unary: output type not supported");
1446 }
1447 else
1448 {
1449 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1450 "Reference elementwise unary: input type not supported");
1451
1452 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1453 "Reference elementwise unary: output type not supported");
1454 }
josh minor4a3c6102020-01-06 16:40:46 -06001455
1456 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1457 "Reference elementwise unary: input and output types not matching");
1458
1459 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1460 "Reference elementwise unary: input and output shapes"
1461 "have different number of total elements");
1462
1463 return supported;
1464}
1465
arovir011c7c81b2018-10-08 11:34:28 +01001466bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1467 const FakeQuantizationDescriptor& descriptor,
1468 Optional<std::string&> reasonIfUnsupported) const
1469{
Jan Eilers8eb25602020-03-09 12:13:48 +00001470 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001471 bool supported = true;
1472
1473 std::array<DataType,1> supportedTypes =
1474 {
1475 DataType::Float32
1476 };
1477
1478 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1479 "Reference fake quantization: input type not supported.");
1480
1481 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001482}
1483
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001484bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1485 const TensorInfo& output,
1486 const FillDescriptor& descriptor,
1487 Optional<std::string&> reasonIfUnsupported) const
1488{
1489 IgnoreUnused(descriptor);
1490 IgnoreUnused(output);
1491
1492 bool supported = true;
1493
Sadik Armagana792a052020-06-23 16:22:23 +01001494 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001495 {
1496 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001497 DataType::Float16,
1498 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001499 };
1500
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001501 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001502 "Reference Fill: input type not supported.");
1503
Teresa Charlin44088502020-07-27 11:27:19 +01001504 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1505 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001506 return supported;
1507}
1508
arovir011c7c81b2018-10-08 11:34:28 +01001509bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1510 const TensorInfo& output,
1511 Optional<std::string&> reasonIfUnsupported) const
1512{
Jan Eilers8eb25602020-03-09 12:13:48 +00001513 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001514 bool supported = true;
1515
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001516 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001517 {
James Conroyb40d7102019-06-04 12:32:09 +01001518 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001519 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001520 };
1521
1522 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1523 "Reference Floor: input type not supported.");
1524
1525 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1526 "Reference Floor: output type not supported.");
1527
1528 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001529}
1530
1531bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1532 const TensorInfo& output,
1533 const TensorInfo& weights,
1534 const TensorInfo& biases,
1535 const FullyConnectedDescriptor& descriptor,
1536 Optional<std::string&> reasonIfUnsupported) const
1537{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001538 bool supported = true;
1539
1540 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001541 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001542 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001543 DataType::Float32,
1544 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001545 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001546 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001547 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001548 };
1549
1550 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1551 "Reference Fully Connected: input type not supported.");
1552
1553 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1554 "Reference Fully Connected: output type not supported.");
1555
Francis Murtagh46c09d02019-05-28 08:15:28 +01001556 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1557 "Reference Fully Connected: weights type not supported.");
1558
Ryan OShea31441592022-11-07 16:20:48 +00001559 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1560 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001561
Jan Eilers1f45dc32020-06-15 11:43:03 +01001562 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1563 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001564
Jan Eilers1f45dc32020-06-15 11:43:03 +01001565 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1566 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001567
1568 if (descriptor.m_BiasEnabled)
1569 {
1570 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001571 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001572 supportedBiasTypes =
1573 {
1574 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001575 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001576 DataType::Signed32,
1577 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001578 };
1579
1580 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1581 "Reference Fully Connected: bias type not supported.");
1582
1583 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1584 "Reference Fully Connected: bias and weight types mismatch.");
1585
1586 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1587 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1588
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001589 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1590 "Reference Fully Connected: bias must have 1 dimension.");
1591
Francis Murtagh46c09d02019-05-28 08:15:28 +01001592 }
1593
1594 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001595}
1596
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001597bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1598 const armnn::TensorInfo& input1,
1599 const armnn::TensorInfo& output,
1600 armnn::Optional<std::string&> reasonIfUnsupported) const
1601{
1602 bool supported = true;
1603 std::array<DataType,7> supportedTypes =
1604 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001605 DataType::Float32,
1606 DataType::Float16,
1607 DataType::QAsymmS8,
1608 DataType::QAsymmU8,
1609 DataType::QSymmS16,
1610 DataType::Signed32
1611 };
1612
1613 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1614 "Reference GatherNd: input type not supported");
1615
1616 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1617 "Reference GatherNd: output type not supported");
1618
1619 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1620 "Reference GatherNd: indices (input1) type not supported");
1621
1622 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1623 "Reference GatherNd: input and output types not matching");
1624
1625 return supported;
1626}
1627
narpra014951d842019-01-18 16:53:53 +00001628bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1629 const armnn::TensorInfo& input1,
1630 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001631 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001632 armnn::Optional<std::string&> reasonIfUnsupported) const
1633{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001634 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001635 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001636 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001637 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001638 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001639 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001640 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001641 DataType::QSymmS16,
1642 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001643 };
1644
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001645 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001646 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1647 "Reference Gather: input type not supported");
1648
1649 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1650 "Reference Gather: output type not supported");
1651
1652 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1653 "Reference Gather: indices (input1) type not supported");
1654
1655 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1656 "Reference Gather: input and output types not matching");
1657
1658 return supported;
narpra014951d842019-01-18 16:53:53 +00001659}
1660
Derek Lamberti901ea112019-12-10 22:07:09 +00001661bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1662 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001663{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001664 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001665}
1666
Kevin May09ca49c2019-10-09 12:37:34 +01001667bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1668 const TensorInfo& output,
1669 const InstanceNormalizationDescriptor& descriptor,
1670 Optional<std::string&> reasonIfUnsupported) const
1671{
Jan Eilers8eb25602020-03-09 12:13:48 +00001672 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001673 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001674 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001675 {
1676 DataType::Float32,
1677 DataType::Float16
1678 };
1679
1680 bool supported = true;
1681
1682 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1683 "Reference Instance Normalization: input type not supported.");
1684
1685 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1686 "Reference Instance Normalization: output type not supported.");
1687
1688 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1689 "Reference Instance Normalization: input and output types mismatched.");
1690
1691 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1692 "Reference Instance Normalization: input and output shapes have different "
1693 "num total elements.");
1694
1695 return supported;
1696}
1697
arovir011c7c81b2018-10-08 11:34:28 +01001698bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1699 const TensorInfo& output,
1700 const L2NormalizationDescriptor& descriptor,
1701 Optional<std::string&> reasonIfUnsupported) const
1702{
Jan Eilers8eb25602020-03-09 12:13:48 +00001703 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001704 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001705 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001706 {
1707 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001708 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001709 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001710 DataType::QAsymmU8,
1711 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001712 };
1713
1714 bool supported = true;
1715
1716 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1717 "Reference L2normalization: input type not supported.");
1718
1719 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1720 "Reference L2normalization: output type not supported.");
1721
1722 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1723 "Reference L2normalization: input and output types mismatched.");
1724
1725 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1726 "Reference L2normalization: input and output shapes have different "
1727 "num total elements.");
1728
1729 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001730}
1731
James Conroyaba90cd2020-11-06 16:28:18 +00001732bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1733 const TensorInfo& input1,
1734 const TensorInfo& output,
1735 const LogicalBinaryDescriptor& descriptor,
1736 Optional<std::string&> reasonIfUnsupported) const
1737{
1738 IgnoreUnused(descriptor);
1739
1740 std::array<DataType, 1> supportedTypes =
1741 {
1742 DataType::Boolean
1743 };
1744
1745 bool supported = true;
1746 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1747 "Reference LogicalBinary: input 0 type not supported");
1748 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1749 "Reference LogicalBinary: input 1 type not supported");
1750
1751 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1752 "Reference LogicalBinary: input and output types do not match");
1753
1754 return supported;
1755}
1756
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001757bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1758 const TensorInfo& output,
1759 const LogSoftmaxDescriptor& descriptor,
1760 Optional<std::string&> reasonIfUnsupported) const
1761{
Jan Eilers8eb25602020-03-09 12:13:48 +00001762 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001763
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001764 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001765 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001766 DataType::Float32,
1767 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001768 };
1769
1770 bool supported = true;
1771 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1772 "Reference LogSoftmax: input type not supported");
1773
1774 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1775 "Reference LogSoftmax: output type not supported");
1776
1777 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1778 "Reference LogSoftmax: input and output types do not match");
1779
1780 return supported;
1781}
1782
arovir011c7c81b2018-10-08 11:34:28 +01001783bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1784 const TensorInfo& outputStateIn,
1785 const TensorInfo& cellStateIn,
1786 const TensorInfo& scratchBuffer,
1787 const TensorInfo& outputStateOut,
1788 const TensorInfo& cellStateOut,
1789 const TensorInfo& output,
1790 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001791 const LstmInputParamsInfo& paramsInfo,
1792 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001793{
Jan Eilers8eb25602020-03-09 12:13:48 +00001794 IgnoreUnused(descriptor);
1795 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001796
1797 bool supported = true;
1798
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001799 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001800 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001801 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001802 };
1803
Jan Eilersd01a83c2019-07-03 18:20:40 +01001804 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001805 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1806 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1808 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001809 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1810 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1812 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001813 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1814 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001815 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1816 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001817
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001818 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1819 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001820 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001821 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001822 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001823 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001824 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001825 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001826 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001827 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001828 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001829 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001830 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001831 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001832 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001833 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001834 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001835 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001836 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001837 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001838 "Reference Lstm: input and OutputGateBias types are mismatched");
1839 if (!descriptor.m_CifgEnabled)
1840 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001841 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001842 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001843 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001844 reasonIfUnsupported,
1845 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001846 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001847 "Reference Lstm: input and InputGateBias types are mismatched");
1848 if (descriptor.m_PeepholeEnabled)
1849 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001850 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001851 reasonIfUnsupported,
1852 "Reference Lstm: input and CellToInputWeights types are mismatched");
1853 }
1854 }
1855 if (descriptor.m_PeepholeEnabled)
1856 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001857 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001858 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001859 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001860 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1861 }
1862 if (descriptor.m_ProjectionEnabled)
1863 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001864 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001865 "Reference Lstm: input and mProjectionWeights types are mismatched");
1866 if (paramsInfo.m_ProjectionBias != nullptr)
1867 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001868 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001869 "Reference Lstm: input and ProjectionBias types are mismatched");
1870 }
1871 }
1872 if (descriptor.m_LayerNormEnabled)
1873 {
1874 if (!descriptor.m_CifgEnabled)
1875 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001876 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001877 reasonIfUnsupported,
1878 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1879 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001880 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001881 reasonIfUnsupported,
1882 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001883 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001884 reasonIfUnsupported,
1885 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001886 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001887 reasonIfUnsupported,
1888 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1889 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001890
1891 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001892}
1893
saoste012df12b32018-11-28 16:57:20 +00001894bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1895 const TensorInfo& input1,
1896 const TensorInfo& output,
1897 Optional<std::string&> reasonIfUnsupported) const
1898{
Sadik Armagan2999a022019-04-09 14:20:12 +01001899 bool supported = true;
1900
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001901 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001902 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001903 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001904 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001905 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001906 DataType::QSymmS16,
1907 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001908 };
1909
1910 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1911 "Reference maximum: input 0 is not a supported type.");
1912
1913 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1914 "Reference maximum: input 1 is not a supported type.");
1915
1916 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1917 "Reference maximum: output is not a supported type.");
1918
1919 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1920 "Reference maximum: input 0 and Input 1 types are mismatched");
1921
1922 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1923 "Reference maximum: input and output types are mismatched");
1924
1925 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1926 "Reference maximum: shapes are not suitable for implicit broadcast.");
1927
1928 return supported;
saoste012df12b32018-11-28 16:57:20 +00001929}
1930
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001931bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1932 const TensorInfo& output,
1933 const MeanDescriptor& descriptor,
1934 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001935{
James Conroy4d1ff582019-06-10 17:06:39 +01001936 bool supported = true;
1937 std::string meanLayerStr = "Mean";
1938 std::string outputTensorStr = "output";
1939
Sadik Armagan303980c2020-04-17 12:45:14 +01001940 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001941 {
1942 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001943 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001944 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001945 DataType::QAsymmU8,
1946 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001947 };
1948
1949 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1950 "Reference Mean: input type not supported.");
1951
1952 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1953 "Reference Mean: input and output types are mismatched");
1954
1955 if (descriptor.m_KeepDims)
1956 {
1957 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1958 reasonIfUnsupported,
1959 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1960 output.GetNumDimensions(),
1961 meanLayerStr, outputTensorStr).data());
1962 }
1963 else if (descriptor.m_Axis.empty())
1964 {
1965 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1966 reasonIfUnsupported,
1967 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1968 meanLayerStr, outputTensorStr).data());
1969 }
1970 else
1971 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001972 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001973
1974 if (outputDim > 0)
1975 {
1976 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1977 reasonIfUnsupported,
1978 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1979 meanLayerStr, outputTensorStr).data());
1980 }
1981 else
1982 {
1983 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1984 reasonIfUnsupported,
1985 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1986 meanLayerStr, outputTensorStr).data());
1987 }
1988 }
1989
1990 return supported;
narpra0132b90462018-09-13 11:07:48 +01001991}
1992
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001993bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1994 const TensorInfo &output,
1995 Optional<std::string &> reasonIfUnsupported) const
1996{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001997 bool supported = true;
1998
Sadik Armagan303980c2020-04-17 12:45:14 +01001999 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002000 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002001 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002002 DataType::Float32,
2003 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002004 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002005 DataType::QAsymmU8,
2006 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002007 DataType::Boolean
2008 };
2009
2010 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2011 "Reference MemCopy: input type not supported");
2012
2013 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2014 "Reference MemCopy: output type not supported");
2015
2016 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2017 "Reference MemCopy: input and output types are mismatched");
2018
2019 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002020}
2021
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002022bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2023 const TensorInfo& input1,
2024 const TensorInfo& output,
2025 Optional<std::string&> reasonIfUnsupported) const
2026{
Sadik Armagan2999a022019-04-09 14:20:12 +01002027 bool supported = true;
2028
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002029 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002030 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002031 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002032 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002033 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002034 DataType::QSymmS16,
2035 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002036 };
2037
2038 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2039 "Reference minimum: input 0 is not a supported type.");
2040
2041 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2042 "Reference minimum: input 1 is not a supported type.");
2043
2044 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2045 "Reference minimum: output is not a supported type.");
2046
2047 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2048 "Reference minimum: input 0 and Input 1 types are mismatched");
2049
2050 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2051 "Reference minimum: input and output types are mismatched");
2052
2053 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2054 "Reference minimum: shapes are not suitable for implicit broadcast.");
2055
2056 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002057}
2058
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002059bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2060 const TensorInfo& input1,
2061 const TensorInfo& output,
2062 Optional<std::string&> reasonIfUnsupported) const
2063{
Sadik Armagan2999a022019-04-09 14:20:12 +01002064 bool supported = true;
2065
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002066 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002067 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002068 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002069 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002070 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002071 DataType::QSymmS16,
2072 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002073 };
2074
2075 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2076 "Reference multiplication: input 0 is not a supported type.");
2077
2078 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2079 "Reference multiplication: input 1 is not a supported type.");
2080
2081 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2082 "Reference multiplication: output is not a supported type.");
2083
2084 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2085 "Reference multiplication: input 0 and Input 1 types are mismatched");
2086
2087 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2088 "Reference multiplication: input and output types are mismatched");
2089
2090 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2091 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2092
2093 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002094}
2095
2096bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2097 const TensorInfo& output,
2098 const NormalizationDescriptor& descriptor,
2099 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002100{
Jan Eilers8eb25602020-03-09 12:13:48 +00002101 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002102
2103 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002104 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002105 {
2106 DataType::Float16,
2107 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002108 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002109 DataType::QAsymmU8,
2110 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002111 };
2112
2113 bool supported = true;
2114
2115 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2116 "Reference normalization: input type not supported.");
2117
2118 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2119 "Reference normalization: output type not supported.");
2120
2121 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2122 "Reference normalization: input and output shapes have different "
2123 "num total elements.");
2124
2125 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002126}
2127
// An Output layer is a binding point only; the reference backend accepts it
// unconditionally, so both parameters are intentionally unnamed.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
2133
2134bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2135 const TensorInfo& output,
2136 const PadDescriptor& descriptor,
2137 Optional<std::string&> reasonIfUnsupported) const
2138{
Jan Eilers8eb25602020-03-09 12:13:48 +00002139 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002140 bool supported = true;
2141
2142 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002143 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002144 {
2145 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002146 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002147 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002148 DataType::QAsymmU8,
2149 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002150 };
2151
2152 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2153 "Reference pad: input is not a supported type.");
2154
2155 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2156 "Reference pad: output is not a supported type.");
2157
2158 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2159 "Reference pad: input and output types are mismatched.");
2160
2161 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002162}
2163
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002164bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2165 const TensorInfo& output,
2166 const PermuteDescriptor& descriptor,
2167 Optional<std::string&> reasonIfUnsupported) const
2168{
Jan Eilers8eb25602020-03-09 12:13:48 +00002169 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002170 bool supported = true;
2171
2172 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002173 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002174 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002175 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002176 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002177 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002178 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002179 DataType::QAsymmU8,
2180 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002181 };
2182
2183 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2184 "Reference permute: input is not a supported type.");
2185
2186 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2187 "Reference permute: output is not a supported type.");
2188
2189 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2190 "Reference permute: input and output types are mismatched.");
2191
2192 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002193}
2194
2195bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2196 const TensorInfo& output,
2197 const Pooling2dDescriptor& descriptor,
2198 Optional<std::string&> reasonIfUnsupported) const
2199{
Jan Eilers8eb25602020-03-09 12:13:48 +00002200 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002201 bool supported = true;
2202
2203 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002204 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002205 {
2206 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002207 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002208 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002209 DataType::QAsymmU8,
2210 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002211 };
2212
2213 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2214 "Reference poolind2d: input is not a supported type.");
2215
2216 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2217 "Reference poolind2d: output is not a supported type.");
2218
2219 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2220 "Reference poolind2d: input and output types are mismatched.");
2221
2222 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002223}
2224
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002225bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2226 const TensorInfo& output,
2227 const Pooling3dDescriptor& descriptor,
2228 Optional<std::string&> reasonIfUnsupported) const
2229{
2230 IgnoreUnused(descriptor);
2231 bool supported = true;
2232
2233 // Define supported output and inputs types.
2234 std::array<DataType,6> supportedTypes =
2235 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002236 DataType::Float32,
2237 DataType::Float16,
2238 DataType::QAsymmS8,
2239 DataType::QAsymmU8,
2240 DataType::QSymmS16
2241 };
2242
2243 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2244 "Reference poolind3d: input is not a supported type.");
2245
2246 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2247 "Reference poolind3d: output is not a supported type.");
2248
2249 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2250 "Reference poolind3d: input and output types are mismatched.");
2251
2252 return supported;
2253}
2254
2255
James Conroy4f1f8992020-04-29 20:01:10 +01002256bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2257 const TensorInfo& previousOutputIn,
2258 const TensorInfo& previousCellStateIn,
2259 const TensorInfo& outputStateOut,
2260 const TensorInfo& cellStateOut,
2261 const TensorInfo& output,
2262 const QLstmDescriptor& descriptor,
2263 const LstmInputParamsInfo& paramsInfo,
2264 Optional<std::string&> reasonIfUnsupported) const
2265{
2266 IgnoreUnused(input);
2267 IgnoreUnused(previousOutputIn);
2268 IgnoreUnused(previousCellStateIn);
2269 IgnoreUnused(outputStateOut);
2270 IgnoreUnused(cellStateOut);
2271 IgnoreUnused(output);
2272 IgnoreUnused(descriptor);
2273 IgnoreUnused(paramsInfo);
2274
2275 IgnoreUnused(reasonIfUnsupported);
2276
2277 return true;
2278}
2279
Derek Lamberti5f400d62019-03-25 15:41:58 +00002280bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2281 const TensorInfo& output,
2282 Optional<std::string&> reasonIfUnsupported) const
2283{
2284 bool supported = true;
2285
Finn Williamsfd271062019-12-04 14:27:27 +00002286 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002287 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002288 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002289 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002290 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002291 DataType::QAsymmU8,
2292 DataType::QSymmS8,
2293 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002294 };
2295
2296 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2297 "Reference quantize: input type not supported.");
2298
2299 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002300 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002301 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002302 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002303 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002304 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002305 };
2306 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2307 "Reference quantize: output type not supported.");
2308
2309 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2310 "Reference quantize: input and output shapes have different num total elements.");
2311
2312 return supported;
2313}
2314
Finn Williams2605b232020-06-10 15:53:46 +01002315bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2316 const TensorInfo& output,
2317 Optional<std::string&> reasonIfUnsupported) const
2318{
2319 IgnoreUnused(input);
2320 // Define supported output types.
2321 std::array<DataType,1> supportedOutputTypes =
2322 {
2323 DataType::Signed32,
2324 };
2325
2326 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2327 "Reference rank: input type not supported.");
2328}
2329
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002330bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2331 const TensorInfo& output,
2332 const ReduceDescriptor& descriptor,
2333 Optional<std::string&> reasonIfUnsupported) const
2334{
2335 IgnoreUnused(descriptor);
2336 bool supported = true;
2337 std::array<DataType,7> supportedTypes =
2338 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002339 DataType::Float32,
2340 DataType::Float16,
2341 DataType::QAsymmS8,
2342 DataType::QAsymmU8,
2343 DataType::QSymmS16,
2344 DataType::Signed32
2345 };
2346
2347 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2348 "Reference Reduce: input type not supported");
2349
2350 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2351 "Reference Reduce: output type not supported");
2352
2353 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2354 "Reference Reduce: input and output types not matching");
2355
2356 return supported;
2357}
2358
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002359bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002360 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002361 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002362 Optional<std::string&> reasonIfUnsupported) const
2363{
Jan Eilers8eb25602020-03-09 12:13:48 +00002364 IgnoreUnused(output);
2365 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002366 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002367 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002368 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002369 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002370 DataType::Float32,
2371 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002372 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002373 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002374 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002375 DataType::QSymmS16,
2376 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002377 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002378
Nina Drozd2f2778f2019-05-27 10:37:05 +01002379 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2380 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002381}
2382
Teresa Charlin970f43b2019-07-01 13:51:07 +01002383bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2384 const TensorInfo& output,
2385 const ResizeDescriptor& descriptor,
2386 Optional<std::string&> reasonIfUnsupported) const
2387{
Jan Eilers8eb25602020-03-09 12:13:48 +00002388 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002389 bool supported = true;
Teresa Charlince655882023-11-21 15:44:13 +00002390 std::array<DataType,7> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002391 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002392 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002393 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002394 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002395 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002396 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00002397 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002398 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002399 };
2400
2401 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2402 "Reference Resize: input type not supported");
2403
2404 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2405 "Reference Resize: output type not supported");
2406
2407 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2408 "Reference Resize: input and output types not matching");
2409
2410 return supported;
2411}
2412
Tracy Narinebb8d7592023-07-13 16:50:54 +01002413bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
2414 const TensorInfo& input1,
Tianle Cheng988354d2023-06-28 13:20:47 +01002415 const TensorInfo& output,
Tianle Cheng988354d2023-06-28 13:20:47 +01002416 Optional<std::string&> reasonIfUnsupported) const
2417{
Tianle Cheng988354d2023-06-28 13:20:47 +01002418 bool supported = true;
2419 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002420 std::array<DataType,8> supportedTypes =
Tianle Cheng988354d2023-06-28 13:20:47 +01002421 {
2422 DataType::BFloat16,
2423 DataType::Float32,
2424 DataType::Float16,
2425 DataType::QAsymmS8,
2426 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002427 DataType::QSymmS8,
2428 DataType::QSymmS16,
2429 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01002430 };
2431
Tracy Narinebb8d7592023-07-13 16:50:54 +01002432 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2433 "Reference ReverseV2: input0 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002434
2435 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2436 "Reference ReverseV2: output type not supported");
2437
Tracy Narinebb8d7592023-07-13 16:50:54 +01002438 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2439 "Reference ReverseV2: input0 and output types not matching");
2440
2441 std::array<DataType,6> input2SupportedTypes =
2442 {
2443 DataType::Signed32
2444 };
2445
2446 supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
2447 "Reference ReverseV2: input1 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002448
2449 return supported;
2450}
2451
Tianle Cheng28288182024-02-23 17:56:54 +00002452bool RefLayerSupport::IsScatterNdSupported(const TensorInfo& input,
2453 const TensorInfo& indices,
2454 const TensorInfo& updates,
2455 const TensorInfo& output,
2456 const ScatterNdDescriptor& descriptor,
2457 Optional<std::string&> reasonIfUnsupported) const
2458{
2459 IgnoreUnused(descriptor);
2460
2461 bool supported = true;
2462
2463 std::array<DataType, 7> supportedTypes
2464 {
2465 DataType::Float32,
2466 DataType::Float16,
2467 DataType::QAsymmS8,
2468 DataType::QAsymmU8,
2469 DataType::QSymmS8,
2470 DataType::QSymmS16,
2471 DataType::Signed32
2472 };
2473
2474 std::array<DataType, 1> indicesSupportedTypes =
2475 {
2476 DataType::Signed32
2477 };
2478
2479 supported &= CheckSupportRule(TypeAnyOf(indices, indicesSupportedTypes), reasonIfUnsupported,
2480 "ScatterNd: indices type not supported.");
2481
2482 supported &= CheckSupportRule(TypeAnyOf(updates, supportedTypes), reasonIfUnsupported,
2483 "ScatterNd: updates type not supported.");
2484
2485 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2486 "ScatterNd: output type not supported");
2487
2488 supported &= CheckSupportRule(TypesAreEqual(updates, output), reasonIfUnsupported,
2489 "ScatterNd: input and updates types are mismatched");
2490
2491 if (descriptor.m_InputEnabled)
2492 {
2493 // If the input slot is enabled, we have the input tensor in this slot
2494 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2495 "ScatterNd: input type not supported.");
2496
2497 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2498 "ScatterNd: input and output types are mismatched");
2499 }
2500 else
2501 {
2502 // If the input slot is not enabled, we have the shape tensor in this slot
2503 supported &= CheckSupportRule(TypeAnyOf(input, indicesSupportedTypes), reasonIfUnsupported,
2504 "ScatterNd: shape type not supported.");
2505 }
2506
2507 return supported;
2508}
2509
Keith Davis3ae3f972021-05-21 16:33:48 +01002510bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2511 const TensorInfo& output,
2512 Optional<std::string&> reasonIfUnsupported) const
2513{
2514 IgnoreUnused(input);
2515 bool supported = true;
2516
2517 std::array<DataType, 1> supportedTypes =
2518 {
2519 DataType::Signed32
2520 };
2521
2522 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2523 "Reference Shape: output type not supported");
2524
2525 return supported;
2526}
2527
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002528bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2529 const TensorInfo& output,
2530 const SliceDescriptor& descriptor,
2531 Optional<std::string&> reasonIfUnsupported) const
2532{
Jan Eilers8eb25602020-03-09 12:13:48 +00002533 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002534 bool supported = true;
2535
Sadik Armagan303980c2020-04-17 12:45:14 +01002536 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002537 {
2538 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002539 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002540 DataType::QAsymmU8,
Ryan OShea980446b2023-06-08 16:23:28 +01002541 DataType::QSymmS16,
2542 DataType::Signed32
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002543 };
2544
2545 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2546 "Reference Slice: input type not supported");
2547
2548 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2549 "Reference Slice: output type not supported");
2550
2551 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2552 "Reference Slice: input and output types are mismatched");
2553
2554 return supported;
2555}
2556
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002557bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2558 const TensorInfo& output,
2559 const SoftmaxDescriptor& descriptor,
2560 Optional<std::string&> reasonIfUnsupported) const
2561{
Jan Eilers8eb25602020-03-09 12:13:48 +00002562 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002563 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002564 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002565 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002566 DataType::Float32,
2567 DataType::Float16,
2568 DataType::QSymmS8,
2569 DataType::QAsymmS8,
2570 DataType::QAsymmU8,
2571 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002572 };
2573
2574 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002575 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002576
2577 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002578 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002579
2580 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002581 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002582
2583 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002584}
2585
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002586bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2587 const TensorInfo& output,
2588 const SpaceToBatchNdDescriptor& descriptor,
2589 Optional<std::string&> reasonIfUnsupported) const
2590{
Jan Eilers8eb25602020-03-09 12:13:48 +00002591 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002592 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002593 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002594 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002595 DataType::Float32,
2596 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002597 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002598 DataType::QAsymmU8,
2599 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002600 };
2601
2602 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2603 "Reference SpaceToBatchNd: input type not supported");
2604
2605 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2606 "Reference SpaceToBatchNd: output type not supported");
2607
2608 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2609 "Reference SpaceToBatchNd: input and output types are mismatched");
2610
2611 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002612}
2613
Keith Davisa57eccb2019-06-14 17:33:22 +01002614bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002615 const TensorInfo& output,
2616 const SpaceToDepthDescriptor& descriptor,
2617 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002618{
2619
Jan Eilers8eb25602020-03-09 12:13:48 +00002620 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002621 bool supported = true;
2622
Sadik Armagan303980c2020-04-17 12:45:14 +01002623 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002624 {
2625 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002626 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002627 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002628 DataType::QAsymmU8,
2629 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002630 };
2631
2632 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2633 "Reference SpaceToDepth: input type not supported");
2634
2635 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2636 "Reference SpaceToDepth: output type not supported");
2637
2638 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2639 "Reference SpaceToDepth: input and output types are mismatched");
2640
2641 return supported;
2642}
2643
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002644bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002645 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2646 const ViewsDescriptor& descriptor,
2647 Optional<std::string&> reasonIfUnsupported) const
2648{
Jan Eilers8eb25602020-03-09 12:13:48 +00002649 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002650 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002651 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002652 {
2653 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002654 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002655 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002656 DataType::QAsymmU8,
2657 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002658 };
2659
2660 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2661 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002662 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002663 {
2664 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2665 "Reference splitter: input type not supported");
2666
2667 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2668 "Reference splitter: input and output types mismatched.");
2669 }
2670
2671 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002672}
2673
Matthew Jackson81e601c2019-07-11 12:07:09 +01002674bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2675 const TensorInfo& output,
2676 const StackDescriptor& descriptor,
2677 Optional<std::string&> reasonIfUnsupported) const
2678{
Jan Eilers8eb25602020-03-09 12:13:48 +00002679 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002680
2681 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002682 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002683 {
2684 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002685 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002686 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002687 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002688 DataType::QSymmS16,
2689 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002690 };
2691
2692 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2693 "Reference stack: output type not supported");
2694 for (const TensorInfo* input : inputs)
2695 {
Matthew Jackson81e601c2019-07-11 12:07:09 +01002696 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2697 "Reference stack: input type not supported");
2698
2699 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2700 "Reference stack: input and output types mismatched.");
2701 }
2702
2703 return supported;
2704}
2705
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002706bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2707 const TensorInfo& output,
2708 const StridedSliceDescriptor& descriptor,
2709 Optional<std::string&> reasonIfUnsupported) const
2710{
Jan Eilers8eb25602020-03-09 12:13:48 +00002711 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002712 bool supported = true;
2713
Sadik Armagan303980c2020-04-17 12:45:14 +01002714 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002715 {
2716 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002717 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002718 DataType::QAsymmU8,
2719 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002720 };
2721
2722 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2723 "Reference StridedSlice: input type not supported");
2724
2725 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2726 "Reference StridedSlice: output type not supported");
2727
2728 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2729 "Reference StridedSlice: input and output types are mismatched");
2730
2731 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002732}
2733
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002734bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2735 const TensorInfo& input1,
2736 const TensorInfo& output,
2737 Optional<std::string&> reasonIfUnsupported) const
2738{
Sadik Armagan2999a022019-04-09 14:20:12 +01002739 bool supported = true;
2740
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002741 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002742 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002743 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002744 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002745 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002746 DataType::QSymmS16,
2747 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002748 };
2749
2750 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2751 "Reference subtraction: input 0 is not a supported type.");
2752
2753 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2754 "Reference subtraction: input 1 is not a supported type.");
2755
2756 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2757 "Reference subtraction: output is not a supported type.");
2758
2759 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2760 "Reference subtraction: input 0 and Input 1 types are mismatched");
2761
2762 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2763 "Reference subtraction: input and output types are mismatched");
2764
2765 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2766 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2767
2768 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002769}
2770
Matteo Martincighab9e5252019-06-13 17:27:46 +01002771bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2772 const TensorInfo& alpha,
2773 const TensorInfo& output,
2774 Optional<std::string&> reasonIfUnsupported) const
2775{
2776 bool supported = true;
2777
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002778 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002779 {
2780 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002781 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002782 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002783 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002784 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002785 };
2786
2787 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2788 "PReLU: input is not a supported type.");
2789
2790 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2791 "PReLU: alpha is not a supported type.");
2792
2793 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2794 "PReLU: output is not a supported type.");
2795
2796 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2797 "PReLU: input, alpha and output types are mismatched");
2798
2799 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2800 "PReLU: shapes are not suitable for implicit broadcast");
2801
2802 return supported;
2803}
2804
Teresa Charlin79a06a52023-07-13 17:16:45 +01002805bool RefLayerSupport::IsTileSupported(const TensorInfo& input,
2806 const TensorInfo& output,
2807 const TileDescriptor& descriptor,
2808 Optional<std::string&> reasonIfUnsupported) const
2809{
2810 IgnoreUnused(descriptor);
2811
2812 bool supported = true;
2813
2814 std::array<DataType, 7> supportedTypes
2815 {
2816 DataType::Float32,
2817 DataType::Float16,
2818 DataType::QAsymmS8,
2819 DataType::QAsymmU8,
2820 DataType::QSymmS8,
2821 DataType::QSymmS16,
2822 DataType::Signed32
2823 };
2824
2825 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2826 "Tile: input type not supported.");
2827
2828 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2829 "Tile: output type not supported");
2830
2831 return supported;
2832}
2833
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002834bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2835 const TensorInfo& output,
2836 const TransposeConvolution2dDescriptor& descriptor,
2837 const TensorInfo& weights,
2838 const Optional<TensorInfo>& biases,
2839 Optional<std::string&> reasonIfUnsupported) const
2840{
Jan Eilers8eb25602020-03-09 12:13:48 +00002841 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002842 bool supported = true;
2843
Sadik Armagan303980c2020-04-17 12:45:14 +01002844 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002845 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002846 DataType::Float32,
2847 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002848 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002849 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002850 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002851 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002852 };
2853
2854 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2855 "Reference TransposeConvolution2d: input is not a supported type.");
2856
2857 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2858 "Reference TransposeConvolution2d: output is not a supported type.");
2859
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002860 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2861 "Reference TransposeConvolution2d: input and output types mismatched.");
2862
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002863
2864 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002865 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002866 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002867 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002868 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002869 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002870 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002871 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002872 };
2873
2874 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2875 "Reference TransposeConvolution2d: weights type not supported for "
2876 "quantized input.");
2877 }
2878 else
2879 {
2880 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2881 "Reference TransposeConvolution2d: weights is not a supported type.");
2882
2883 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2884 "Reference TransposeConvolution2d: input and weights types mismatched.");
2885 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002886
2887 if (biases.has_value())
2888 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002889 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002890 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002891 DataType::Float32,
2892 DataType::Float16,
2893 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002894 };
2895 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2896 "Reference TransposeConvolution2d: biases is not a supported type.");
2897 }
2898
2899 return supported;
2900}
2901
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002902bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2903 const TensorInfo& output,
2904 const TransposeDescriptor& descriptor,
2905 Optional<std::string&> reasonIfUnsupported) const
2906{
Jan Eilers8eb25602020-03-09 12:13:48 +00002907 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002908 bool supported = true;
2909
2910 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002911 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002912 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002913 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002914 DataType::Float32,
2915 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002916 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002917 DataType::QAsymmU8,
2918 DataType::QSymmS16
2919 };
2920
2921 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2922 "Reference transpose: input is not a supported type.");
2923
2924 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2925 "Reference transpose: output is not a supported type.");
2926
2927 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2928 "Reference transpose: input and output types are mismatched.");
2929
2930 return supported;
2931}
2932
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002933bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2934 const TensorInfo& input,
2935 const TensorInfo& outputStateIn,
2936 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002937 const TensorInfo& outputStateOut,
2938 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002939 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002940 const UnidirectionalSequenceLstmDescriptor& descriptor,
2941 const LstmInputParamsInfo& paramsInfo,
2942 Optional<std::string&> reasonIfUnsupported) const
2943{
2944 IgnoreUnused(descriptor);
2945 IgnoreUnused(paramsInfo);
2946 IgnoreUnused(outputStateIn);
2947 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002948 IgnoreUnused(outputStateOut);
2949 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002950 bool supported = true;
2951
Mike Kelly12994962022-04-21 11:57:09 +01002952 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002953 {
Mike Kelly12994962022-04-21 11:57:09 +01002954 DataType::Float32,
2955 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002956 };
2957
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002958 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002959 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002960 DataType::Float32,
2961 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002962 };
2963
Mike Kelly12994962022-04-21 11:57:09 +01002964 std::array<DataType, 3> supportedBiasTypes =
2965 {
2966 DataType::Float32,
2967 DataType::QAsymmS8,
2968 DataType::Signed32
2969 };
2970
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002971 // check inputs and outputs
2972 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2973 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002974 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2975 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002976
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002977 // check layer parameters
2978 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2979 reasonIfUnsupported,
2980 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2981 "is not a supported type.");
2982 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2983 reasonIfUnsupported,
2984 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2985 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2986 reasonIfUnsupported,
2987 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2988 "is not a supported type.");
2989 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2990 reasonIfUnsupported,
2991 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2992 "is not a supported type.");
2993 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2994 reasonIfUnsupported,
2995 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2996 "is not a supported type.");
2997 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2998 reasonIfUnsupported,
2999 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
3000 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01003001
3002 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
3003 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
3004 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
3005 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
3006 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
3007 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01003008 if (!descriptor.m_CifgEnabled)
3009 {
3010 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
3011 reasonIfUnsupported,
3012 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
3013 "is not a supported type.");
3014 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
3015 reasonIfUnsupported,
3016 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
3017 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01003018 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
3019 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01003020 if (descriptor.m_PeepholeEnabled)
3021 {
3022 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
3023 reasonIfUnsupported,
3024 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
3025 "is not a supported type.");
3026 }
3027 }
3028 if (descriptor.m_PeepholeEnabled)
3029 {
3030 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
3031 reasonIfUnsupported,
3032 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
3033 "is not a supported type.");
3034 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
3035 reasonIfUnsupported,
3036 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
3037 "is not a supported type.");
3038 }
3039 if (descriptor.m_ProjectionEnabled)
3040 {
3041 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
3042 reasonIfUnsupported,
3043 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
3044 "is not a supported type.");
3045 if (paramsInfo.m_ProjectionBias != nullptr)
3046 {
3047 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
3048 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
3049 "are mismatched");
3050 }
3051 }
3052 if (descriptor.m_LayerNormEnabled)
3053 {
3054 if (!descriptor.m_CifgEnabled)
3055 {
3056 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
3057 reasonIfUnsupported,
3058 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
3059 "is not a supported type.");
3060 }
3061 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
3062 reasonIfUnsupported,
3063 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
3064 "is not a supported type.");
3065 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
3066 reasonIfUnsupported,
3067 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
3068 "is not a supported type.");
3069 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
3070 reasonIfUnsupported,
3071 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
3072 "is not a supported type.");
3073 }
3074
3075 return supported;
3076}
3077
arovir011c7c81b2018-10-08 11:34:28 +01003078} // namespace armnn