blob: 3e04a19df4b14cbaf16817262b76802c78eef505 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <array>
Colm Donelan02300aa2024-04-04 11:20:29 +010018#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000019
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
// Legacy per-data-type dispatch helper for the reference backend.
// Forwards to the generic dispatcher, supplying the caller's predicates for the
// Float32 and QAsymmU8 slots and FalseFunc for every other slot (presumably the
// Float16 / Signed32 / Boolean positions — confirm against LayerSupportCommon.hpp),
// so only Float32 and QAsymmU8 tensors can be reported as supported.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
/// Builds the standard "wrong tensor rank" diagnostic used by reference layer checks.
/// @param expected   Rank the layer requires.
/// @param actual     Rank the caller supplied.
/// @param layerStr   Human-readable layer name inserted into the message.
/// @param tensorName Name of the offending tensor.
/// @return Formatted error message.
/// Note: parameters are now const references — the function never mutates them,
/// and this also allows temporaries to be passed (backward compatible for all callers).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100103 case LayerType::BroadcastTo:
104 return IsBroadcastToSupported(infos[0],
105 infos[1],
106 *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
107 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000108 case LayerType::Comparison:
109 return IsComparisonSupported(infos[0],
110 infos[1],
111 infos[2],
112 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
113 reasonIfUnsupported);
114 case LayerType::Concat:
115 {
116 std::vector<const TensorInfo*> inputInfos;
117 for (uint32_t i = 0; i < (infos.size() - 1); i++)
118 {
119 inputInfos.push_back(&infos[i]);
120 }
121 return IsConcatSupported(inputInfos,
122 infos[infos.size() - 1],
123 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
124 reasonIfUnsupported);
125 }
126 case LayerType::Constant:
127 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000128 case LayerType::ConvertFp16ToFp32:
129 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000130 case LayerType::ConvertFp32ToFp16:
131 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
132 case LayerType::Convolution2d:
133 {
134 if (infos.size() != 4)
135 {
136 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
137 "TensorInfos should be of format: {input, output, weights, biases}.");
138 }
139
140 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
141 if (infos[3] == TensorInfo())
142 {
143 return IsConvolution2dSupported(infos[0],
144 infos[1],
145 desc,
146 infos[2],
147 EmptyOptional(),
148 reasonIfUnsupported);
149 }
150 else
151 {
152 return IsConvolution2dSupported(infos[0],
153 infos[1],
154 desc,
155 infos[2],
156 infos[3],
157 reasonIfUnsupported);
158 }
159 }
160 case LayerType::DepthToSpace:
161 return IsDepthToSpaceSupported(infos[0],
162 infos[1],
163 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
164 reasonIfUnsupported);
165 case LayerType::DepthwiseConvolution2d:
166 {
167 if (infos.size() != 4)
168 {
169 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
170 "TensorInfos should be of format: {input, output, weights, biases}.");
171 }
172
173 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
174 if (infos[3] == TensorInfo())
175 {
176 return IsDepthwiseConvolutionSupported(infos[0],
177 infos[1],
178 desc,
179 infos[2],
180 EmptyOptional(),
181 reasonIfUnsupported);
182 }
183 else
184 {
185 return IsDepthwiseConvolutionSupported(infos[0],
186 infos[1],
187 desc,
188 infos[2],
189 infos[3],
190 reasonIfUnsupported);
191 }
192 }
193 case LayerType::Dequantize:
194 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
195 case LayerType::Division:
196 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000197 case LayerType::ElementwiseBinary:
198 {
199 std::array<DataType, 7> supportedTypes =
200 {
201 DataType::Float32,
202 DataType::Float16,
203 DataType::QAsymmS8,
204 DataType::QAsymmU8,
205 DataType::QSymmS16,
206 DataType::Signed32
207 };
208
209 bool supported = true;
210 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
211 "Reference elementwise unary: input type not supported");
212
213 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
214 "Reference elementwise unary: input type not supported");
215
216 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
217 "Reference elementwise unary: output type not supported");
218
219 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
220 "Reference elementwise unary: input types not matching");
221
222 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
223 "Reference elementwise unary: input and output types not matching");
224
225 return supported;
226 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000227 case LayerType::ElementwiseUnary:
228 return IsElementwiseUnarySupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Fill:
233 return IsFillSupported(infos[0],
234 infos[1],
235 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
236 reasonIfUnsupported);
237 case LayerType::Floor:
238 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
239 case LayerType::FullyConnected:
240 return IsFullyConnectedSupported(infos[0],
241 infos[1],
242 infos[2],
243 infos[3],
244 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Gather:
247 return IsGatherSupported(infos[0],
248 infos[1],
249 infos[2],
250 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
251 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100252 case LayerType::GatherNd:
253 return IsGatherNdSupported(infos[0],
254 infos[1],
255 infos[2],
256 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000257 case LayerType::Input:
258 return IsInputSupported(infos[0], reasonIfUnsupported);
259 case LayerType::InstanceNormalization:
260 return IsInstanceNormalizationSupported(infos[0],
261 infos[1],
262 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
263 (&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::L2Normalization:
266 return IsL2NormalizationSupported(infos[0],
267 infos[1],
268 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
269 reasonIfUnsupported);
270 case LayerType::LogicalBinary:
271 return IsLogicalBinarySupported(infos[0],
272 infos[1],
273 infos[2],
274 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::LogSoftmax:
277 return IsLogSoftmaxSupported(infos[0],
278 infos[1],
279 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
280 reasonIfUnsupported);
281 case LayerType::Lstm:
282 return IsLstmSupported(infos[0],
283 infos[1],
284 infos[2],
285 infos[3],
286 infos[4],
287 infos[5],
288 infos[6],
289 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
290 lstmParamsInfo.value(),
291 reasonIfUnsupported);
292 case LayerType::QLstm:
293 return IsQLstmSupported(infos[0],
294 infos[1],
295 infos[2],
296 infos[3],
297 infos[4],
298 infos[5],
299 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
300 lstmParamsInfo.value(),
301 reasonIfUnsupported);
302 case LayerType::Maximum:
303 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
304 case LayerType::Mean:
305 return IsMeanSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Minimum:
310 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
311 case LayerType::Multiplication:
312 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
313 case LayerType::Normalization:
314 return IsNormalizationSupported(infos[0],
315 infos[1],
316 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
317 reasonIfUnsupported);
318 case LayerType::Output:
319 return IsOutputSupported(infos[0], reasonIfUnsupported);
320 case LayerType::Pad:
321 return IsPadSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Permute:
326 return IsPermuteSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Pooling2d:
331 return IsPooling2dSupported(infos[0],
332 infos[1],
333 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
334 reasonIfUnsupported);
335 case LayerType::Prelu:
336 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
337 case LayerType::Quantize:
338 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
339 case LayerType::Reshape:
340 return IsReshapeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Resize:
345 return IsResizeSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100349 case LayerType::ReverseV2:
350 return IsReverseV2Supported(infos[0],
351 infos[1],
Tracy Narinebb8d7592023-07-13 16:50:54 +0100352 infos[2],
Tianle Cheng988354d2023-06-28 13:20:47 +0100353 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000354 case LayerType::Reduce:
355 return IsReduceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
Tianle Cheng28288182024-02-23 17:56:54 +0000359 case LayerType::ScatterNd:
360 return IsScatterNdSupported(infos[0],
361 infos[1],
362 infos[2],
363 infos[3],
364 *(PolymorphicDowncast<const ScatterNdDescriptor*>(&descriptor)),
365 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000366 case LayerType::Slice:
367 return IsSliceSupported(infos[0],
368 infos[1],
369 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
370 reasonIfUnsupported);
371 case LayerType::Softmax:
372 return IsSoftmaxSupported(infos[0],
373 infos[1],
374 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
375 reasonIfUnsupported);
376 case LayerType::SpaceToBatchNd:
377 return IsSpaceToBatchNdSupported(infos[0],
378 infos[1],
379 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
380 reasonIfUnsupported);
381 case LayerType::SpaceToDepth:
382 return IsSpaceToDepthSupported(infos[0],
383 infos[1],
384 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
385 reasonIfUnsupported);
386 case LayerType::Splitter:
387 {
388 std::vector<TensorInfo> outputInfos;
389 for (uint32_t i = 1; i < infos.size(); i++)
390 {
391 outputInfos.push_back(infos[i]);
392 }
393 return IsSplitterSupported(infos[0],
394 {outputInfos.begin(), outputInfos.end()},
395 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
396 reasonIfUnsupported);
397 }
398 case LayerType::Stack:
399 {
400 std::vector<const TensorInfo*> inputInfos;
401 for (uint32_t i = 0; i < infos.size() - 1; i++)
402 {
403 inputInfos.push_back(&infos[i]);
404 }
405 return IsStackSupported(inputInfos,
406 infos[infos.size() - 1],
407 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
408 reasonIfUnsupported);
409 }
410 case LayerType::StridedSlice:
411 return IsStridedSliceSupported(infos[0],
412 infos[1],
413 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
414 reasonIfUnsupported);
415 case LayerType::Subtraction:
416 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Teresa Charlin79a06a52023-07-13 17:16:45 +0100417 case LayerType::Tile:
418 return IsTileSupported(infos[0],
419 infos[1],
420 *(PolymorphicDowncast<const TileDescriptor*>(&descriptor)),
421 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000422 case LayerType::Transpose:
423 return IsTransposeSupported(infos[0],
424 infos[1],
425 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
426 reasonIfUnsupported);
427 case LayerType::TransposeConvolution2d:
428 {
429 if (infos.size() != 4)
430 {
431 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
432 "TensorInfos should be of format: {input, output, weights, biases}.");
433 }
434
435 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
436 if (infos[3] == TensorInfo())
437 {
438 return IsTransposeConvolution2dSupported(infos[0],
439 infos[1],
440 desc,
441 infos[2],
442 EmptyOptional(),
443 reasonIfUnsupported);
444 }
445 else
446 {
447 return IsTransposeConvolution2dSupported(infos[0],
448 infos[1],
449 desc,
450 infos[2],
451 infos[3],
452 reasonIfUnsupported);
453 }
454 }
455 case LayerType::Cast:
456 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
457 case LayerType::ChannelShuffle:
458 return IsChannelShuffleSupported(infos[0],
459 infos[1],
460 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
461 reasonIfUnsupported);
462 case LayerType::Convolution3d:
463 {
464 if (infos.size() != 4)
465 {
466 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
467 "TensorInfos should be of format: {input, output, weights, biases}.");
468 }
469
470 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
471 if (infos[3] == TensorInfo())
472 {
473 return IsConvolution3dSupported(infos[0],
474 infos[1],
475 desc,
476 infos[2],
477 EmptyOptional(),
478 reasonIfUnsupported);
479 }
480 else
481 {
482 return IsConvolution3dSupported(infos[0],
483 infos[1],
484 desc,
485 infos[2],
486 infos[3],
487 reasonIfUnsupported);
488 }
489 }
490 case LayerType::Debug:
491 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
492 case LayerType::DetectionPostProcess:
493 return IsDetectionPostProcessSupported(infos[0],
494 infos[1],
495 infos[2],
496 infos[3],
497 infos[4],
498 infos[5],
499 infos[6],
500 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
501 (&descriptor)),
502 reasonIfUnsupported);
503 case LayerType::FakeQuantization:
504 return IsFakeQuantizationSupported(infos[0],
505 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
506 reasonIfUnsupported);
507 case LayerType::MemCopy:
508 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
509 case LayerType::Rank:
510 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
511 case LayerType::Shape:
512 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
513 case LayerType::UnidirectionalSequenceLstm:
514 {
515 if (infos.size() != 6)
516 {
517 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
518 "should be of format: {input, outputStateIn, cellStateIn, "
519 "hiddenStateOutputVal, cellStateOutputVal, output}");
520 }
521 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100522 return IsUnidirectionalSequenceLstmSupported(infos[0],
523 infos[1],
524 infos[2],
525 infos[3],
526 infos[4],
527 infos[5],
528 desc,
529 lstmParamsInfo.value(),
530 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000531 }
532 case LayerType::Pooling3d:
533 return IsPooling3dSupported(infos[0],
534 infos[1],
535 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
536 reasonIfUnsupported);
537 case LayerType::Map:
538 return true;
539 case LayerType::Unmap:
540 return true;
541 case LayerType::MemImport:
542 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
543 case LayerType::Merge:
544 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
545 case LayerType::QuantizedLstm:
546 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
547 infos[1],
548 infos[2],
549 infos[3],
550 infos[4],
551 quantizedLstmInputParamsInfo.value(),
552 reasonIfUnsupported);
553 default:
Teresa Charlin9145e382023-08-17 18:44:58 +0100554 // layers not supported in reference by default:
555 // precompiled, standin, switch, fused
Cathal Corbett34b429c2021-12-24 12:24:40 +0000556 return false;
557 }
558}
559
arovir011c7c81b2018-10-08 11:34:28 +0100560bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
561 const TensorInfo& output,
562 const ActivationDescriptor& descriptor,
563 Optional<std::string&> reasonIfUnsupported) const
564{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000565 bool supported = true;
566
567 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000568 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000569 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100570 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000571 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000572 DataType::QAsymmU8,
573 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000574 };
575
576 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
577 "Reference activation: input type not supported.");
578
579 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
580 "Reference activation: output type not supported.");
581
582 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
583 "Reference activation: input and output types mismatched.");
584
585 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
586 "Reference activation: input and output shapes are of different rank.");
587
588
589 struct ActivationFunctionSupported : public Rule
590 {
591 ActivationFunctionSupported(const ActivationDescriptor& desc)
592 {
593 switch(desc.m_Function)
594 {
595 case ActivationFunction::Abs:
596 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000597 case ActivationFunction::Elu:
Teresa Charlin077cddb2023-09-15 15:19:21 +0100598 case ActivationFunction::Gelu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000599 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000600 case ActivationFunction::LeakyReLu:
601 case ActivationFunction::Linear:
602 case ActivationFunction::ReLu:
603 case ActivationFunction::Sigmoid:
604 case ActivationFunction::SoftReLu:
605 case ActivationFunction::Sqrt:
606 case ActivationFunction::Square:
607 case ActivationFunction::TanH:
608 {
609 m_Res = true;
610 break;
611 }
612 default:
613 {
614 m_Res = false;
615 break;
616 }
617 }
618 }
619 };
620
621 // Function is supported
622 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
623 "Reference activation: function not supported.");
624
625 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100626}
627
628bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
629 const TensorInfo& input1,
630 const TensorInfo& output,
631 Optional<std::string&> reasonIfUnsupported) const
632{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000633 bool supported = true;
634
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100635 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000636 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100637 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000638 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000639 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100640 DataType::QSymmS16,
641 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000642 };
643
644 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
645 "Reference addition: input 0 is not a supported type.");
646
647 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
648 "Reference addition: input 1 is not a supported type.");
649
650 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
651 "Reference addition: output is not a supported type.");
652
653 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
654 "Reference addition: input 0 and Input 1 types are mismatched");
655
656 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
657 "Reference addition: input and output types are mismatched");
658
659 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
660 "Reference addition: shapes are not suitable for implicit broadcast.");
661
662 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100663}
664
Nikhil Raj68c2c902019-09-19 11:21:11 +0100665bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
666 const armnn::ArgMinMaxDescriptor &descriptor,
667 armnn::Optional<std::string &> reasonIfUnsupported) const
668{
Jan Eilers8eb25602020-03-09 12:13:48 +0000669 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100670
Mike Kelly1f140f72021-04-06 12:25:55 +0100671 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100672 {
Teresa Charline300b362020-05-25 10:01:03 +0100673 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100674 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100675 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000676 DataType::QAsymmU8,
677 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100678 DataType::Signed32,
679 DataType::Signed64
680 };
681
682 std::array<DataType,2> supportedOutputTypes = {
683 DataType::Signed32,
684 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100685 };
686
687 bool supported = true;
688
Mike Kelly1f140f72021-04-06 12:25:55 +0100689 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100690 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100691 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100692 "Reference ArgMinMax: output type not supported");
693
694 return supported;
695}
696
Samuel Yap6b478092022-07-06 15:36:03 +0100697bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
698 const TensorInfo& inputY,
699 const TensorInfo& output,
700 const BatchMatMulDescriptor& descriptor,
701 Optional<std::string &> reasonIfUnsupported) const
702{
703 IgnoreUnused(descriptor);
704
705 std::array<DataType, 6> supportedTypes =
706 {
Samuel Yap6b478092022-07-06 15:36:03 +0100707 DataType::Float16,
708 DataType::Float32,
709 DataType::QAsymmS8,
710 DataType::QAsymmU8,
711 DataType::QSymmS16
712 };
713
714 bool supported = true;
715
716 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
717 "Reference batch matrix multiplication: input X is not a supported type");
718
719 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
720 "Reference batch matrix multiplication: input Y is not a supported type");
721
722 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
723 "Reference batch matrix multiplication: output is not a supported type");
724
725 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
726 "Reference batch matrix multiplication: input X and input Y types are mismatched");
727
728 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
729 "Reference batch matrix multiplication: inputs and output types are mismatched");
730
731 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
732 reasonIfUnsupported,
733 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
734
735 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
736 reasonIfUnsupported,
737 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
738
739 return supported;
740}
741
arovir011c7c81b2018-10-08 11:34:28 +0100742bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
743 const TensorInfo& output,
744 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100745 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100746 const TensorInfo& beta,
747 const TensorInfo& gamma,
748 const BatchNormalizationDescriptor& descriptor,
749 Optional<std::string&> reasonIfUnsupported) const
750{
Jan Eilers8eb25602020-03-09 12:13:48 +0000751 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100752
Sadik Armagan303980c2020-04-17 12:45:14 +0100753 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100754 {
755 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100756 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100757 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000758 DataType::QAsymmU8,
759 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100760 };
761
762 bool supported = true;
763
764 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
765 "Reference batch normalization: input is not a supported type.");
766
767 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
768 "Reference batch normalization: output is not a supported type.");
769
770 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
771 "Reference batch normalization: input and output types are mismatched");
772
773 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
774 "Reference batch normalization: mean is not a supported type.");
775
776 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
777 "Reference batch normalization: variance is not a supported type.");
778
779 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
780 "Reference batch normalization: beta is not a supported type.");
781
782 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
783 "Reference batch normalization: gamma is not a supported type.");
784
785 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100786}
787
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000788bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
789 const TensorInfo& output,
790 const BatchToSpaceNdDescriptor& descriptor,
791 Optional<std::string&> reasonIfUnsupported) const
792{
Jan Eilers8eb25602020-03-09 12:13:48 +0000793 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100794
795 bool supported = true;
796
797 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
798 std::string inputTensorStr = "input";
799 std::string outputTensorStr = "output";
800
801 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100802 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100803 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000804 DataType::Float32,
805 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100806 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000807 DataType::QAsymmU8,
808 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100809 };
810
811 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
812 "Reference BatchToSpaceNd: input type not supported.");
813
814 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
815 "Reference BatchToSpaceNd: output type not supported.");
816
817 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
818 "Reference BatchToSpaceNd: input and output types mismatched.");
819
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100820 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000821}
822
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100823bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
824 const TensorInfo& output,
825 const BroadcastToDescriptor& descriptor,
826 Optional<std::string&> reasonIfUnsupported) const
827{
828 IgnoreUnused(descriptor);
829
830 bool supported = true;
831
832 std::array<DataType, 8> supportedTypes
833 {
834 DataType::Float32,
835 DataType::Float16,
836 DataType::QAsymmS8,
837 DataType::QAsymmU8,
838 DataType::QSymmS8,
839 DataType::QSymmS16,
840 DataType::Signed32,
841 DataType::Signed64
842 };
843
844 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
845 "BroadcastTo: input type not supported.");
846
847 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
848 "BroadcastTo: output type not supported");
849
850 return supported;
851}
852
mathad01b392e982021-04-07 12:07:30 +0100853bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
854 const TensorInfo& output,
855 Optional<std::string&> reasonIfUnsupported) const
856{
Teresa Charlin5306dc82023-10-30 22:29:58 +0000857 std::array<DataType, 10> supportedInputTypes =
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100858 {
859 DataType::Float32,
860 DataType::Float16,
861 DataType::QSymmS8,
862 DataType::QAsymmS8,
863 DataType::QAsymmU8,
864 DataType::QSymmS16,
Teresa Charlin5306dc82023-10-30 22:29:58 +0000865 DataType::Signed32,
866 DataType::Signed64
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100867 };
mathad01b392e982021-04-07 12:07:30 +0100868
869 bool supported = true;
870 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
871 "Reference cast: input is not a supported type");
872
873
874 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
875 "Reference cast: output is not a supported type");
876
877 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
878 "Reference cast: input and output shapes have different number of total elements");
879
880 return supported;
881}
882
Simon Obute51f67772021-09-03 15:50:13 +0100883bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
884 const TensorInfo& output,
885 const ChannelShuffleDescriptor& descriptor,
886 Optional<std::string&> reasonIfUnsupported) const
887{
888 IgnoreUnused(descriptor);
889 bool supported = true;
890
891 // Define supported output and inputs types.
892 std::array<DataType, 7> supportedTypes =
893 {
Simon Obute51f67772021-09-03 15:50:13 +0100894 DataType::Float32,
895 DataType::Float16,
896 DataType::QAsymmS8,
897 DataType::QAsymmU8,
898 DataType::QSymmS8,
899 DataType::QSymmS16
900 };
901
902 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
903 "Reference ChannelShuffle: input is not a supported type.");
904
905 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
906 "Reference ChannelShuffle: output is not a supported type.");
907
908 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
909 "Reference ChannelShuffle: input and output types are mismatched.");
910
911 return supported;
912}
913
914
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100915bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
916 const TensorInfo& input1,
917 const TensorInfo& output,
918 const ComparisonDescriptor& descriptor,
919 Optional<std::string&> reasonIfUnsupported) const
920{
Jan Eilers8eb25602020-03-09 12:13:48 +0000921 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100922 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100923 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000924 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100925 DataType::Float32,
926 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100927 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000928 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000929 DataType::QSymmS16,
930 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100931 };
932
933 bool supported = true;
934 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
935 "Reference comparison: input 0 is not a supported type");
936
937 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
938 "Reference comparison: input 0 and Input 1 types are mismatched");
939
940 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
941 "Reference comparison: output is not of type Boolean");
942
Colm Donelan02300aa2024-04-04 11:20:29 +0100943 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
944 "Reference comparison: shapes are not suitable for implicit broadcast.");
945
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100946 return supported;
947}
948
Jim Flynn906f9462019-05-10 13:55:21 +0100949bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
950 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000951 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100952 Optional<std::string&> reasonIfUnsupported) const
953{
Jan Eilers8eb25602020-03-09 12:13:48 +0000954 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100955
956 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000957 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100958 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000959 DataType::Float32,
960 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000961 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100962 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000963 DataType::QSymmS16,
964 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100965 };
966
967 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
968 "Reference concatenation: output type not supported");
969 for (const TensorInfo* input : inputs)
970 {
971 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
972 "Reference concatenation: input type not supported");
973
974 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
975 "Reference concatenation: input and output types mismatched.");
976 }
977
978 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100979}
980
arovir011c7c81b2018-10-08 11:34:28 +0100981bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
982 Optional<std::string&> reasonIfUnsupported) const
983{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100984 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100985 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100986 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100987 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000988 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100989 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000990 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100991 DataType::QSymmS16,
992 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100993 };
994
995 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
996 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100997}
998
999bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
1000 const TensorInfo& output,
1001 Optional<std::string&> reasonIfUnsupported) const
1002{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001003 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1004 input.GetDataType(),
1005 &TrueFunc<>,
1006 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001007 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001008 &FalseFuncI32<>,
1009 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001010 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1011 output.GetDataType(),
1012 &FalseOutputFuncF16<>,
1013 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001014 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001015 &FalseFuncI32<>,
1016 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001017}
1018
1019bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
1020 const TensorInfo& output,
1021 Optional<std::string&> reasonIfUnsupported) const
1022{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001023 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1024 input.GetDataType(),
1025 &FalseInputFuncF16<>,
1026 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001027 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001028 &FalseFuncI32<>,
1029 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001030 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1031 output.GetDataType(),
1032 &TrueFunc<>,
1033 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001034 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001035 &FalseFuncI32<>,
1036 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001037}
1038
1039bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1040 const TensorInfo& output,
1041 const Convolution2dDescriptor& descriptor,
1042 const TensorInfo& weights,
1043 const Optional<TensorInfo>& biases,
1044 Optional<std::string&> reasonIfUnsupported) const
1045{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001046 bool supported = true;
1047
1048 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001049 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001050 {
1051 DataType::Float32,
1052 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001053 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001054 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001055 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001056 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001057 };
1058
1059 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001060 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001061
1062 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001063 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001064
Ryan OShea31441592022-11-07 16:20:48 +00001065 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1066 "Reference Convolution2d: input and output types mismatched.");
1067
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001068
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001069 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001070 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001071 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001072 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001073 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001074 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001075 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001076 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001077 };
1078
1079 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001080 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001081 }
1082 else
1083 {
1084 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001085 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001086
1087 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001088 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001089 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001090
1091 if (biases.has_value())
1092 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001093 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001094 {
1095 DataType::Float32,
1096 DataType::Float16,
1097 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001098 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001099
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001100 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001101 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001102 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001103 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001104
1105 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001106}
1107
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001108bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1109 const TensorInfo& output,
1110 const Convolution3dDescriptor& descriptor,
1111 const TensorInfo& weights,
1112 const Optional<TensorInfo>& biases,
1113 Optional<std::string&> reasonIfUnsupported) const
1114{
1115 bool supported = true;
1116
1117 // Define supported types.
1118 std::array<DataType,7> supportedTypes =
1119 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001120 DataType::Float32,
1121 DataType::Float16,
1122 DataType::QAsymmS8,
1123 DataType::QAsymmU8,
1124 DataType::QSymmS8,
1125 DataType::QSymmS16
1126 };
1127
1128 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1129 "Reference Convolution3d: input is not a supported type.");
1130
1131 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1132 "Reference Convolution3d: output is not a supported type.");
1133
1134 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1135 "Reference Convolution3d: input and output types mismatched.");
1136
1137 const DataType inputType = input.GetDataType();
1138 if (IsQuantized8BitType(inputType))
1139 {
1140 std::array<DataType, 3> supportedWeightTypes =
1141 {
1142 DataType::QAsymmS8,
1143 DataType::QAsymmU8,
1144 DataType::QSymmS8
1145 };
1146
1147 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1148 "Reference Convolution3d: weights type not supported for quantized input.");
1149 }
1150 else
1151 {
1152 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1153 "Reference Convolution3d: weights is not a supported type.");
1154
1155 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1156 "Reference Convolution3d: input and weights types mismatched.");
1157 }
1158
1159 if (biases.has_value())
1160 {
1161 std::array<DataType,4> biasesSupportedTypes =
1162 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001163 DataType::Float32,
1164 DataType::Float16,
1165 DataType::Signed32
1166 };
1167
1168 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1169 "Reference Convolution3d: biases is not a supported type.");
1170 }
1171 IgnoreUnused(descriptor);
1172
1173 return supported;
1174}
1175
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001176bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1177 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001178 Optional<std::string&> reasonIfUnsupported) const
1179{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001180 bool supported = true;
1181
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001182 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001183 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001184 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001185 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001186 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001187 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001188 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001189 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001190 DataType::QSymmS16,
1191 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001192 };
1193
1194 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001195 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001196
1197 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001198 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001199
1200 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001201 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001202
1203 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001204}
1205
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001206bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1207 const TensorInfo& output,
1208 const DepthToSpaceDescriptor& descriptor,
1209 Optional<std::string&> reasonIfUnsupported) const
1210{
Jan Eilers8eb25602020-03-09 12:13:48 +00001211 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001212 bool supported = true;
1213
Sadik Armagan303980c2020-04-17 12:45:14 +01001214 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001215 {
1216 DataType::Float32,
1217 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001218 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001219 DataType::QAsymmU8,
1220 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001221 };
1222
1223 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1224 "Reference DepthToSpace: input type not supported");
1225
1226 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1227 "Reference DepthToSpace: output type not supported");
1228
1229 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1230 "Reference DepthToSpace: input and output types are mismatched");
1231
1232 return supported;
1233}
1234
arovir011c7c81b2018-10-08 11:34:28 +01001235bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1236 const TensorInfo& output,
1237 const DepthwiseConvolution2dDescriptor& descriptor,
1238 const TensorInfo& weights,
1239 const Optional<TensorInfo>& biases,
1240 Optional<std::string&> reasonIfUnsupported) const
1241{
Sadik Armagan303980c2020-04-17 12:45:14 +01001242 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001243 bool supported = true;
1244
1245 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001246 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001247 {
1248 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001249 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001250 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001251 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001252 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001253 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001254 };
1255
1256 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1257 "Reference DepthwiseConvolution2d: input is not a supported type.");
1258
1259 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1260 "Reference DepthwiseConvolution2d: output is not a supported type.");
1261
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001262 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1263 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1264
Teresa Charlind8df0262019-11-11 12:28:15 +00001265 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001266 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001267 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001268 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001269 {
1270 DataType::QAsymmS8,
1271 DataType::QAsymmU8,
1272 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001273 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001274
1275 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001276 "Reference DepthwiseConvolution2d: weights type not supported for "
1277 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001278 }
1279 else
1280 {
1281 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1282 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1283
1284 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1285 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1286 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001287
1288 if (biases.has_value())
1289 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001290 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001291 {
1292 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001293 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001294 DataType::Signed32
1295 };
1296 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1297 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1298 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001299
1300 return supported;
1301
arovir011c7c81b2018-10-08 11:34:28 +01001302}
1303
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001304bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1305 const TensorInfo& output,
1306 Optional<std::string&> reasonIfUnsupported) const
1307{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001308 bool supported = true;
1309
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001310 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001311 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001312 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001313 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001314 DataType::QSymmS16,
1315 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001316 };
1317
1318 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001319 "Reference for Dequantize layer: input type not supported.");
1320
Derek Lambertid466a542020-01-22 15:37:29 +00001321 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001322 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001323
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001324 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001325 DataType::Float32,
1326 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001327 };
1328
1329 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001330 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001331
1332 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001333 "Reference for Dequantize layer: input/output shapes have different num total "
1334 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001335
1336 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001337}
1338
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001339bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1340 const TensorInfo& scores,
1341 const TensorInfo& anchors,
1342 const TensorInfo& detectionBoxes,
1343 const TensorInfo& detectionClasses,
1344 const TensorInfo& detectionScores,
1345 const TensorInfo& numDetections,
1346 const DetectionPostProcessDescriptor& descriptor,
1347 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001348{
Jan Eilers8eb25602020-03-09 12:13:48 +00001349 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001350
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001351 bool supported = true;
1352
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001353 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001354 {
1355 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001356 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001357 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001358 DataType::QAsymmU8,
1359 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001360 };
1361
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001362 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001363 "Reference DetectionPostProcess: input 0 is not a supported type.");
1364
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001365 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001366 "Reference DetectionPostProcess: input 1 is not a supported type.");
1367
1368 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001369}
1370
// Dilated depthwise convolution shares the support criteria of the regular
// depthwise convolution in the reference backend, so this simply forwards.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1380
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001381bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001382 const TensorInfo& input1,
1383 const TensorInfo& output,
1384 Optional<std::string&> reasonIfUnsupported) const
1385{
Sadik Armagan2999a022019-04-09 14:20:12 +01001386 bool supported = true;
1387
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001388 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001389 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001390 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001391 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001392 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001393 DataType::QSymmS16,
1394 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001395 };
1396
1397 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1398 "Reference division: input 0 is not a supported type.");
1399
1400 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1401 "Reference division: input 1 is not a supported type.");
1402
1403 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1404 "Reference division: output is not a supported type.");
1405
1406 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1407 "Reference division: input 0 and Input 1 types are mismatched");
1408
1409 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1410 "Reference division: input and output types are mismatched");
1411
1412 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1413 "Reference division: shapes are not suitable for implicit broadcast.");
1414
1415 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001416}
1417
josh minor4a3c6102020-01-06 16:40:46 -06001418bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1419 const TensorInfo& output,
1420 const ElementwiseUnaryDescriptor& descriptor,
1421 Optional<std::string&> reasonIfUnsupported) const
1422{
Jan Eilers8eb25602020-03-09 12:13:48 +00001423 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001424
Sadik Armagan303980c2020-04-17 12:45:14 +01001425 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001426 {
1427 DataType::Float32,
1428 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001429 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001430 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001431 DataType::QSymmS16,
1432 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001433 };
1434
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001435 std::array<DataType, 1> logicalSupportedTypes =
1436 {
1437 DataType::Boolean
1438 };
1439
josh minor4a3c6102020-01-06 16:40:46 -06001440 bool supported = true;
1441
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001442 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1443 {
1444 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1445 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001446
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001447 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1448 "Reference elementwise unary: output type not supported");
1449 }
1450 else
1451 {
1452 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1453 "Reference elementwise unary: input type not supported");
1454
1455 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1456 "Reference elementwise unary: output type not supported");
1457 }
josh minor4a3c6102020-01-06 16:40:46 -06001458
1459 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1460 "Reference elementwise unary: input and output types not matching");
1461
1462 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1463 "Reference elementwise unary: input and output shapes"
1464 "have different number of total elements");
1465
1466 return supported;
1467}
1468
arovir011c7c81b2018-10-08 11:34:28 +01001469bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1470 const FakeQuantizationDescriptor& descriptor,
1471 Optional<std::string&> reasonIfUnsupported) const
1472{
Jan Eilers8eb25602020-03-09 12:13:48 +00001473 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001474 bool supported = true;
1475
1476 std::array<DataType,1> supportedTypes =
1477 {
1478 DataType::Float32
1479 };
1480
1481 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1482 "Reference fake quantization: input type not supported.");
1483
1484 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001485}
1486
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001487bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1488 const TensorInfo& output,
1489 const FillDescriptor& descriptor,
1490 Optional<std::string&> reasonIfUnsupported) const
1491{
1492 IgnoreUnused(descriptor);
1493 IgnoreUnused(output);
1494
1495 bool supported = true;
1496
Sadik Armagana792a052020-06-23 16:22:23 +01001497 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001498 {
1499 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001500 DataType::Float16,
1501 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001502 };
1503
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001504 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001505 "Reference Fill: input type not supported.");
1506
Teresa Charlin44088502020-07-27 11:27:19 +01001507 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1508 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001509 return supported;
1510}
1511
arovir011c7c81b2018-10-08 11:34:28 +01001512bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1513 const TensorInfo& output,
1514 Optional<std::string&> reasonIfUnsupported) const
1515{
Jan Eilers8eb25602020-03-09 12:13:48 +00001516 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001517 bool supported = true;
1518
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001519 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001520 {
James Conroyb40d7102019-06-04 12:32:09 +01001521 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001522 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001523 };
1524
1525 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1526 "Reference Floor: input type not supported.");
1527
1528 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1529 "Reference Floor: output type not supported.");
1530
1531 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001532}
1533
1534bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1535 const TensorInfo& output,
1536 const TensorInfo& weights,
1537 const TensorInfo& biases,
1538 const FullyConnectedDescriptor& descriptor,
1539 Optional<std::string&> reasonIfUnsupported) const
1540{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001541 bool supported = true;
1542
1543 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001544 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001545 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001546 DataType::Float32,
1547 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001548 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001549 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001550 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001551 };
1552
1553 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1554 "Reference Fully Connected: input type not supported.");
1555
1556 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1557 "Reference Fully Connected: output type not supported.");
1558
Francis Murtagh46c09d02019-05-28 08:15:28 +01001559 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1560 "Reference Fully Connected: weights type not supported.");
1561
Ryan OShea31441592022-11-07 16:20:48 +00001562 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1563 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001564
Jan Eilers1f45dc32020-06-15 11:43:03 +01001565 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1566 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001567
Jan Eilers1f45dc32020-06-15 11:43:03 +01001568 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1569 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001570
1571 if (descriptor.m_BiasEnabled)
1572 {
1573 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001574 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001575 supportedBiasTypes =
1576 {
1577 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001578 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001579 DataType::Signed32,
1580 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001581 };
1582
1583 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1584 "Reference Fully Connected: bias type not supported.");
1585
1586 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1587 "Reference Fully Connected: bias and weight types mismatch.");
1588
1589 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1590 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1591
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001592 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1593 "Reference Fully Connected: bias must have 1 dimension.");
1594
Francis Murtagh46c09d02019-05-28 08:15:28 +01001595 }
1596
1597 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001598}
1599
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001600bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1601 const armnn::TensorInfo& input1,
1602 const armnn::TensorInfo& output,
1603 armnn::Optional<std::string&> reasonIfUnsupported) const
1604{
1605 bool supported = true;
1606 std::array<DataType,7> supportedTypes =
1607 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001608 DataType::Float32,
1609 DataType::Float16,
1610 DataType::QAsymmS8,
1611 DataType::QAsymmU8,
1612 DataType::QSymmS16,
1613 DataType::Signed32
1614 };
1615
1616 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1617 "Reference GatherNd: input type not supported");
1618
1619 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1620 "Reference GatherNd: output type not supported");
1621
1622 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1623 "Reference GatherNd: indices (input1) type not supported");
1624
1625 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1626 "Reference GatherNd: input and output types not matching");
1627
1628 return supported;
1629}
1630
narpra014951d842019-01-18 16:53:53 +00001631bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1632 const armnn::TensorInfo& input1,
1633 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001634 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001635 armnn::Optional<std::string&> reasonIfUnsupported) const
1636{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001637 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001638 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001639 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001640 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001641 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001642 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001643 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001644 DataType::QSymmS16,
1645 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001646 };
1647
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001648 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001649 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1650 "Reference Gather: input type not supported");
1651
1652 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1653 "Reference Gather: output type not supported");
1654
1655 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1656 "Reference Gather: indices (input1) type not supported");
1657
1658 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1659 "Reference Gather: input and output types not matching");
1660
1661 return supported;
narpra014951d842019-01-18 16:53:53 +00001662}
1663
Derek Lamberti901ea112019-12-10 22:07:09 +00001664bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1665 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001666{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001667 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001668}
1669
Kevin May09ca49c2019-10-09 12:37:34 +01001670bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1671 const TensorInfo& output,
1672 const InstanceNormalizationDescriptor& descriptor,
1673 Optional<std::string&> reasonIfUnsupported) const
1674{
Jan Eilers8eb25602020-03-09 12:13:48 +00001675 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001676 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001677 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001678 {
1679 DataType::Float32,
1680 DataType::Float16
1681 };
1682
1683 bool supported = true;
1684
1685 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1686 "Reference Instance Normalization: input type not supported.");
1687
1688 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1689 "Reference Instance Normalization: output type not supported.");
1690
1691 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1692 "Reference Instance Normalization: input and output types mismatched.");
1693
1694 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1695 "Reference Instance Normalization: input and output shapes have different "
1696 "num total elements.");
1697
1698 return supported;
1699}
1700
arovir011c7c81b2018-10-08 11:34:28 +01001701bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1702 const TensorInfo& output,
1703 const L2NormalizationDescriptor& descriptor,
1704 Optional<std::string&> reasonIfUnsupported) const
1705{
Jan Eilers8eb25602020-03-09 12:13:48 +00001706 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001707 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001708 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001709 {
1710 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001711 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001712 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001713 DataType::QAsymmU8,
1714 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001715 };
1716
1717 bool supported = true;
1718
1719 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1720 "Reference L2normalization: input type not supported.");
1721
1722 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1723 "Reference L2normalization: output type not supported.");
1724
1725 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1726 "Reference L2normalization: input and output types mismatched.");
1727
1728 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1729 "Reference L2normalization: input and output shapes have different "
1730 "num total elements.");
1731
1732 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001733}
1734
James Conroyaba90cd2020-11-06 16:28:18 +00001735bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1736 const TensorInfo& input1,
1737 const TensorInfo& output,
1738 const LogicalBinaryDescriptor& descriptor,
1739 Optional<std::string&> reasonIfUnsupported) const
1740{
1741 IgnoreUnused(descriptor);
1742
1743 std::array<DataType, 1> supportedTypes =
1744 {
1745 DataType::Boolean
1746 };
1747
1748 bool supported = true;
1749 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1750 "Reference LogicalBinary: input 0 type not supported");
1751 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1752 "Reference LogicalBinary: input 1 type not supported");
1753
1754 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1755 "Reference LogicalBinary: input and output types do not match");
1756
Colm Donelan02300aa2024-04-04 11:20:29 +01001757 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1758 "Reference LogicalBinary: shapes are not suitable for implicit broadcast.");
1759
James Conroyaba90cd2020-11-06 16:28:18 +00001760 return supported;
1761}
1762
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001763bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1764 const TensorInfo& output,
1765 const LogSoftmaxDescriptor& descriptor,
1766 Optional<std::string&> reasonIfUnsupported) const
1767{
Jan Eilers8eb25602020-03-09 12:13:48 +00001768 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001769
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001770 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001771 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001772 DataType::Float32,
1773 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001774 };
1775
1776 bool supported = true;
1777 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1778 "Reference LogSoftmax: input type not supported");
1779
1780 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1781 "Reference LogSoftmax: output type not supported");
1782
1783 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1784 "Reference LogSoftmax: input and output types do not match");
1785
1786 return supported;
1787}
1788
arovir011c7c81b2018-10-08 11:34:28 +01001789bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1790 const TensorInfo& outputStateIn,
1791 const TensorInfo& cellStateIn,
1792 const TensorInfo& scratchBuffer,
1793 const TensorInfo& outputStateOut,
1794 const TensorInfo& cellStateOut,
1795 const TensorInfo& output,
1796 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001797 const LstmInputParamsInfo& paramsInfo,
1798 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001799{
Jan Eilers8eb25602020-03-09 12:13:48 +00001800 IgnoreUnused(descriptor);
1801 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001802
1803 bool supported = true;
1804
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001805 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001806 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001807 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001808 };
1809
Jan Eilersd01a83c2019-07-03 18:20:40 +01001810 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001811 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1812 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001813 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1814 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001815 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1816 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001817 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1818 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001819 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1820 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001821 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1822 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001823
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001824 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1825 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001826 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001827 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001828 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001829 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001830 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001831 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001832 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001833 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001834 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001835 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001836 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001837 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001838 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001839 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001840 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001841 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001842 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001843 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001844 "Reference Lstm: input and OutputGateBias types are mismatched");
1845 if (!descriptor.m_CifgEnabled)
1846 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001847 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001848 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001849 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001850 reasonIfUnsupported,
1851 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001852 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001853 "Reference Lstm: input and InputGateBias types are mismatched");
1854 if (descriptor.m_PeepholeEnabled)
1855 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001856 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001857 reasonIfUnsupported,
1858 "Reference Lstm: input and CellToInputWeights types are mismatched");
1859 }
1860 }
1861 if (descriptor.m_PeepholeEnabled)
1862 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001863 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001864 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001865 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001866 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1867 }
1868 if (descriptor.m_ProjectionEnabled)
1869 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001870 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001871 "Reference Lstm: input and mProjectionWeights types are mismatched");
1872 if (paramsInfo.m_ProjectionBias != nullptr)
1873 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001874 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001875 "Reference Lstm: input and ProjectionBias types are mismatched");
1876 }
1877 }
1878 if (descriptor.m_LayerNormEnabled)
1879 {
1880 if (!descriptor.m_CifgEnabled)
1881 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001882 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001883 reasonIfUnsupported,
1884 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1885 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001886 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001887 reasonIfUnsupported,
1888 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001889 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001890 reasonIfUnsupported,
1891 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001892 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001893 reasonIfUnsupported,
1894 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1895 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001896
1897 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001898}
1899
saoste012df12b32018-11-28 16:57:20 +00001900bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1901 const TensorInfo& input1,
1902 const TensorInfo& output,
1903 Optional<std::string&> reasonIfUnsupported) const
1904{
Sadik Armagan2999a022019-04-09 14:20:12 +01001905 bool supported = true;
1906
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001907 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001908 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001909 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001910 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001911 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001912 DataType::QSymmS16,
1913 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001914 };
1915
1916 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1917 "Reference maximum: input 0 is not a supported type.");
1918
1919 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1920 "Reference maximum: input 1 is not a supported type.");
1921
1922 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1923 "Reference maximum: output is not a supported type.");
1924
1925 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1926 "Reference maximum: input 0 and Input 1 types are mismatched");
1927
1928 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1929 "Reference maximum: input and output types are mismatched");
1930
1931 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1932 "Reference maximum: shapes are not suitable for implicit broadcast.");
1933
1934 return supported;
saoste012df12b32018-11-28 16:57:20 +00001935}
1936
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001937bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1938 const TensorInfo& output,
1939 const MeanDescriptor& descriptor,
1940 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001941{
James Conroy4d1ff582019-06-10 17:06:39 +01001942 bool supported = true;
1943 std::string meanLayerStr = "Mean";
1944 std::string outputTensorStr = "output";
1945
Sadik Armagan303980c2020-04-17 12:45:14 +01001946 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001947 {
1948 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001949 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001950 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001951 DataType::QAsymmU8,
1952 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001953 };
1954
1955 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1956 "Reference Mean: input type not supported.");
1957
1958 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1959 "Reference Mean: input and output types are mismatched");
1960
1961 if (descriptor.m_KeepDims)
1962 {
1963 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1964 reasonIfUnsupported,
1965 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1966 output.GetNumDimensions(),
1967 meanLayerStr, outputTensorStr).data());
1968 }
1969 else if (descriptor.m_Axis.empty())
1970 {
1971 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1972 reasonIfUnsupported,
1973 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1974 meanLayerStr, outputTensorStr).data());
1975 }
1976 else
1977 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001978 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001979
1980 if (outputDim > 0)
1981 {
1982 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1983 reasonIfUnsupported,
1984 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1985 meanLayerStr, outputTensorStr).data());
1986 }
1987 else
1988 {
1989 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1990 reasonIfUnsupported,
1991 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1992 meanLayerStr, outputTensorStr).data());
1993 }
1994 }
1995
1996 return supported;
narpra0132b90462018-09-13 11:07:48 +01001997}
1998
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001999bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
2000 const TensorInfo &output,
2001 Optional<std::string &> reasonIfUnsupported) const
2002{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002003 bool supported = true;
2004
Sadik Armagan303980c2020-04-17 12:45:14 +01002005 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002006 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002007 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002008 DataType::Float32,
2009 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002010 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002011 DataType::QAsymmU8,
2012 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002013 DataType::Boolean
2014 };
2015
2016 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2017 "Reference MemCopy: input type not supported");
2018
2019 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2020 "Reference MemCopy: output type not supported");
2021
2022 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2023 "Reference MemCopy: input and output types are mismatched");
2024
2025 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002026}
2027
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002028bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2029 const TensorInfo& input1,
2030 const TensorInfo& output,
2031 Optional<std::string&> reasonIfUnsupported) const
2032{
Sadik Armagan2999a022019-04-09 14:20:12 +01002033 bool supported = true;
2034
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002035 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002036 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002037 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002038 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002039 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002040 DataType::QSymmS16,
2041 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002042 };
2043
2044 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2045 "Reference minimum: input 0 is not a supported type.");
2046
2047 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2048 "Reference minimum: input 1 is not a supported type.");
2049
2050 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2051 "Reference minimum: output is not a supported type.");
2052
2053 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2054 "Reference minimum: input 0 and Input 1 types are mismatched");
2055
2056 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2057 "Reference minimum: input and output types are mismatched");
2058
2059 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2060 "Reference minimum: shapes are not suitable for implicit broadcast.");
2061
2062 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002063}
2064
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002065bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2066 const TensorInfo& input1,
2067 const TensorInfo& output,
2068 Optional<std::string&> reasonIfUnsupported) const
2069{
Sadik Armagan2999a022019-04-09 14:20:12 +01002070 bool supported = true;
2071
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002072 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002073 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002074 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002075 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002076 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002077 DataType::QSymmS16,
2078 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002079 };
2080
2081 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2082 "Reference multiplication: input 0 is not a supported type.");
2083
2084 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2085 "Reference multiplication: input 1 is not a supported type.");
2086
2087 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2088 "Reference multiplication: output is not a supported type.");
2089
2090 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2091 "Reference multiplication: input 0 and Input 1 types are mismatched");
2092
2093 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2094 "Reference multiplication: input and output types are mismatched");
2095
2096 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2097 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2098
2099 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002100}
2101
2102bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2103 const TensorInfo& output,
2104 const NormalizationDescriptor& descriptor,
2105 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002106{
Jan Eilers8eb25602020-03-09 12:13:48 +00002107 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002108
2109 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002110 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002111 {
2112 DataType::Float16,
2113 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002114 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002115 DataType::QAsymmU8,
2116 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002117 };
2118
2119 bool supported = true;
2120
2121 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2122 "Reference normalization: input type not supported.");
2123
2124 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2125 "Reference normalization: output type not supported.");
2126
2127 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2128 "Reference normalization: input and output shapes have different "
2129 "num total elements.");
2130
2131 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002132}
2133
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // Output binding layers perform no computation, so any tensor is accepted.
    return true;
}
2139
2140bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2141 const TensorInfo& output,
2142 const PadDescriptor& descriptor,
2143 Optional<std::string&> reasonIfUnsupported) const
2144{
Jan Eilers8eb25602020-03-09 12:13:48 +00002145 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002146 bool supported = true;
2147
2148 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002149 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002150 {
2151 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002152 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002153 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002154 DataType::QAsymmU8,
2155 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002156 };
2157
2158 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2159 "Reference pad: input is not a supported type.");
2160
2161 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2162 "Reference pad: output is not a supported type.");
2163
2164 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2165 "Reference pad: input and output types are mismatched.");
2166
2167 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002168}
2169
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002170bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2171 const TensorInfo& output,
2172 const PermuteDescriptor& descriptor,
2173 Optional<std::string&> reasonIfUnsupported) const
2174{
Jan Eilers8eb25602020-03-09 12:13:48 +00002175 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002176 bool supported = true;
2177
2178 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002179 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002180 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002181 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002182 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002183 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002184 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002185 DataType::QAsymmU8,
2186 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002187 };
2188
2189 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2190 "Reference permute: input is not a supported type.");
2191
2192 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2193 "Reference permute: output is not a supported type.");
2194
2195 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2196 "Reference permute: input and output types are mismatched.");
2197
2198 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002199}
2200
2201bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2202 const TensorInfo& output,
2203 const Pooling2dDescriptor& descriptor,
2204 Optional<std::string&> reasonIfUnsupported) const
2205{
Jan Eilers8eb25602020-03-09 12:13:48 +00002206 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002207 bool supported = true;
2208
2209 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002210 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002211 {
2212 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002213 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002214 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002215 DataType::QAsymmU8,
2216 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002217 };
2218
2219 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2220 "Reference poolind2d: input is not a supported type.");
2221
2222 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2223 "Reference poolind2d: output is not a supported type.");
2224
2225 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2226 "Reference poolind2d: input and output types are mismatched.");
2227
2228 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002229}
2230
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002231bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2232 const TensorInfo& output,
2233 const Pooling3dDescriptor& descriptor,
2234 Optional<std::string&> reasonIfUnsupported) const
2235{
2236 IgnoreUnused(descriptor);
2237 bool supported = true;
2238
2239 // Define supported output and inputs types.
2240 std::array<DataType,6> supportedTypes =
2241 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002242 DataType::Float32,
2243 DataType::Float16,
2244 DataType::QAsymmS8,
2245 DataType::QAsymmU8,
2246 DataType::QSymmS16
2247 };
2248
2249 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2250 "Reference poolind3d: input is not a supported type.");
2251
2252 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2253 "Reference poolind3d: output is not a supported type.");
2254
2255 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2256 "Reference poolind3d: input and output types are mismatched.");
2257
2258 return supported;
2259}
2260
2261
James Conroy4f1f8992020-04-29 20:01:10 +01002262bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2263 const TensorInfo& previousOutputIn,
2264 const TensorInfo& previousCellStateIn,
2265 const TensorInfo& outputStateOut,
2266 const TensorInfo& cellStateOut,
2267 const TensorInfo& output,
2268 const QLstmDescriptor& descriptor,
2269 const LstmInputParamsInfo& paramsInfo,
2270 Optional<std::string&> reasonIfUnsupported) const
2271{
2272 IgnoreUnused(input);
2273 IgnoreUnused(previousOutputIn);
2274 IgnoreUnused(previousCellStateIn);
2275 IgnoreUnused(outputStateOut);
2276 IgnoreUnused(cellStateOut);
2277 IgnoreUnused(output);
2278 IgnoreUnused(descriptor);
2279 IgnoreUnused(paramsInfo);
2280
2281 IgnoreUnused(reasonIfUnsupported);
2282
2283 return true;
2284}
2285
Derek Lamberti5f400d62019-03-25 15:41:58 +00002286bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2287 const TensorInfo& output,
2288 Optional<std::string&> reasonIfUnsupported) const
2289{
2290 bool supported = true;
2291
Finn Williamsfd271062019-12-04 14:27:27 +00002292 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002293 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002294 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002295 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002296 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002297 DataType::QAsymmU8,
2298 DataType::QSymmS8,
2299 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002300 };
2301
2302 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2303 "Reference quantize: input type not supported.");
2304
2305 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002306 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002307 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002308 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002309 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002310 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002311 };
2312 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2313 "Reference quantize: output type not supported.");
2314
2315 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2316 "Reference quantize: input and output shapes have different num total elements.");
2317
2318 return supported;
2319}
2320
Finn Williams2605b232020-06-10 15:53:46 +01002321bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2322 const TensorInfo& output,
2323 Optional<std::string&> reasonIfUnsupported) const
2324{
2325 IgnoreUnused(input);
2326 // Define supported output types.
2327 std::array<DataType,1> supportedOutputTypes =
2328 {
2329 DataType::Signed32,
2330 };
2331
2332 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2333 "Reference rank: input type not supported.");
2334}
2335
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002336bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2337 const TensorInfo& output,
2338 const ReduceDescriptor& descriptor,
2339 Optional<std::string&> reasonIfUnsupported) const
2340{
2341 IgnoreUnused(descriptor);
2342 bool supported = true;
2343 std::array<DataType,7> supportedTypes =
2344 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002345 DataType::Float32,
2346 DataType::Float16,
2347 DataType::QAsymmS8,
2348 DataType::QAsymmU8,
2349 DataType::QSymmS16,
2350 DataType::Signed32
2351 };
2352
2353 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2354 "Reference Reduce: input type not supported");
2355
2356 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2357 "Reference Reduce: output type not supported");
2358
2359 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2360 "Reference Reduce: input and output types not matching");
2361
2362 return supported;
2363}
2364
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002365bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002366 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002367 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002368 Optional<std::string&> reasonIfUnsupported) const
2369{
Jan Eilers8eb25602020-03-09 12:13:48 +00002370 IgnoreUnused(output);
2371 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002372 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002373 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002374 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002375 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002376 DataType::Float32,
2377 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002378 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002379 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002380 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002381 DataType::QSymmS16,
2382 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002383 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002384
Nina Drozd2f2778f2019-05-27 10:37:05 +01002385 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2386 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002387}
2388
Teresa Charlin970f43b2019-07-01 13:51:07 +01002389bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2390 const TensorInfo& output,
2391 const ResizeDescriptor& descriptor,
2392 Optional<std::string&> reasonIfUnsupported) const
2393{
Jan Eilers8eb25602020-03-09 12:13:48 +00002394 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002395 bool supported = true;
Teresa Charlince655882023-11-21 15:44:13 +00002396 std::array<DataType,7> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002397 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002398 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002399 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002400 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002401 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002402 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00002403 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002404 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002405 };
2406
2407 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2408 "Reference Resize: input type not supported");
2409
2410 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2411 "Reference Resize: output type not supported");
2412
2413 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2414 "Reference Resize: input and output types not matching");
2415
2416 return supported;
2417}
2418
Tracy Narinebb8d7592023-07-13 16:50:54 +01002419bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
2420 const TensorInfo& input1,
Tianle Cheng988354d2023-06-28 13:20:47 +01002421 const TensorInfo& output,
Tianle Cheng988354d2023-06-28 13:20:47 +01002422 Optional<std::string&> reasonIfUnsupported) const
2423{
Tianle Cheng988354d2023-06-28 13:20:47 +01002424 bool supported = true;
2425 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002426 std::array<DataType,8> supportedTypes =
Tianle Cheng988354d2023-06-28 13:20:47 +01002427 {
2428 DataType::BFloat16,
2429 DataType::Float32,
2430 DataType::Float16,
2431 DataType::QAsymmS8,
2432 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002433 DataType::QSymmS8,
2434 DataType::QSymmS16,
2435 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01002436 };
2437
Tracy Narinebb8d7592023-07-13 16:50:54 +01002438 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2439 "Reference ReverseV2: input0 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002440
2441 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2442 "Reference ReverseV2: output type not supported");
2443
Tracy Narinebb8d7592023-07-13 16:50:54 +01002444 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2445 "Reference ReverseV2: input0 and output types not matching");
2446
2447 std::array<DataType,6> input2SupportedTypes =
2448 {
2449 DataType::Signed32
2450 };
2451
2452 supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
2453 "Reference ReverseV2: input1 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002454
2455 return supported;
2456}
2457
Tianle Cheng28288182024-02-23 17:56:54 +00002458bool RefLayerSupport::IsScatterNdSupported(const TensorInfo& input,
2459 const TensorInfo& indices,
2460 const TensorInfo& updates,
2461 const TensorInfo& output,
2462 const ScatterNdDescriptor& descriptor,
2463 Optional<std::string&> reasonIfUnsupported) const
2464{
2465 IgnoreUnused(descriptor);
2466
2467 bool supported = true;
2468
2469 std::array<DataType, 7> supportedTypes
2470 {
2471 DataType::Float32,
2472 DataType::Float16,
2473 DataType::QAsymmS8,
2474 DataType::QAsymmU8,
2475 DataType::QSymmS8,
2476 DataType::QSymmS16,
2477 DataType::Signed32
2478 };
2479
2480 std::array<DataType, 1> indicesSupportedTypes =
2481 {
2482 DataType::Signed32
2483 };
2484
2485 supported &= CheckSupportRule(TypeAnyOf(indices, indicesSupportedTypes), reasonIfUnsupported,
2486 "ScatterNd: indices type not supported.");
2487
2488 supported &= CheckSupportRule(TypeAnyOf(updates, supportedTypes), reasonIfUnsupported,
2489 "ScatterNd: updates type not supported.");
2490
2491 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2492 "ScatterNd: output type not supported");
2493
2494 supported &= CheckSupportRule(TypesAreEqual(updates, output), reasonIfUnsupported,
2495 "ScatterNd: input and updates types are mismatched");
2496
2497 if (descriptor.m_InputEnabled)
2498 {
2499 // If the input slot is enabled, we have the input tensor in this slot
2500 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2501 "ScatterNd: input type not supported.");
2502
2503 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2504 "ScatterNd: input and output types are mismatched");
2505 }
2506 else
2507 {
2508 // If the input slot is not enabled, we have the shape tensor in this slot
2509 supported &= CheckSupportRule(TypeAnyOf(input, indicesSupportedTypes), reasonIfUnsupported,
2510 "ScatterNd: shape type not supported.");
2511 }
2512
2513 return supported;
2514}
2515
Keith Davis3ae3f972021-05-21 16:33:48 +01002516bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2517 const TensorInfo& output,
2518 Optional<std::string&> reasonIfUnsupported) const
2519{
2520 IgnoreUnused(input);
2521 bool supported = true;
2522
2523 std::array<DataType, 1> supportedTypes =
2524 {
2525 DataType::Signed32
2526 };
2527
2528 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2529 "Reference Shape: output type not supported");
2530
2531 return supported;
2532}
2533
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002534bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2535 const TensorInfo& output,
2536 const SliceDescriptor& descriptor,
2537 Optional<std::string&> reasonIfUnsupported) const
2538{
Jan Eilers8eb25602020-03-09 12:13:48 +00002539 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002540 bool supported = true;
2541
Sadik Armagan303980c2020-04-17 12:45:14 +01002542 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002543 {
2544 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002545 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002546 DataType::QAsymmU8,
Ryan OShea980446b2023-06-08 16:23:28 +01002547 DataType::QSymmS16,
2548 DataType::Signed32
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002549 };
2550
2551 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2552 "Reference Slice: input type not supported");
2553
2554 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2555 "Reference Slice: output type not supported");
2556
2557 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2558 "Reference Slice: input and output types are mismatched");
2559
2560 return supported;
2561}
2562
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002563bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2564 const TensorInfo& output,
2565 const SoftmaxDescriptor& descriptor,
2566 Optional<std::string&> reasonIfUnsupported) const
2567{
Jan Eilers8eb25602020-03-09 12:13:48 +00002568 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002569 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002570 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002571 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002572 DataType::Float32,
2573 DataType::Float16,
2574 DataType::QSymmS8,
2575 DataType::QAsymmS8,
2576 DataType::QAsymmU8,
2577 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002578 };
2579
2580 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002581 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002582
2583 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002584 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002585
2586 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002587 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002588
2589 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002590}
2591
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002592bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2593 const TensorInfo& output,
2594 const SpaceToBatchNdDescriptor& descriptor,
2595 Optional<std::string&> reasonIfUnsupported) const
2596{
Jan Eilers8eb25602020-03-09 12:13:48 +00002597 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002598 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002599 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002600 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002601 DataType::Float32,
2602 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002603 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002604 DataType::QAsymmU8,
2605 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002606 };
2607
2608 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2609 "Reference SpaceToBatchNd: input type not supported");
2610
2611 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2612 "Reference SpaceToBatchNd: output type not supported");
2613
2614 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2615 "Reference SpaceToBatchNd: input and output types are mismatched");
2616
2617 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002618}
2619
Keith Davisa57eccb2019-06-14 17:33:22 +01002620bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002621 const TensorInfo& output,
2622 const SpaceToDepthDescriptor& descriptor,
2623 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002624{
2625
Jan Eilers8eb25602020-03-09 12:13:48 +00002626 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002627 bool supported = true;
2628
Sadik Armagan303980c2020-04-17 12:45:14 +01002629 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002630 {
2631 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002632 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002633 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002634 DataType::QAsymmU8,
2635 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002636 };
2637
2638 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2639 "Reference SpaceToDepth: input type not supported");
2640
2641 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2642 "Reference SpaceToDepth: output type not supported");
2643
2644 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2645 "Reference SpaceToDepth: input and output types are mismatched");
2646
2647 return supported;
2648}
2649
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002650bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002651 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2652 const ViewsDescriptor& descriptor,
2653 Optional<std::string&> reasonIfUnsupported) const
2654{
Jan Eilers8eb25602020-03-09 12:13:48 +00002655 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002656 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002657 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002658 {
2659 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002660 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002661 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002662 DataType::QAsymmU8,
2663 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002664 };
2665
2666 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2667 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002668 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002669 {
2670 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2671 "Reference splitter: input type not supported");
2672
2673 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2674 "Reference splitter: input and output types mismatched.");
2675 }
2676
2677 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002678}
2679
Matthew Jackson81e601c2019-07-11 12:07:09 +01002680bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2681 const TensorInfo& output,
2682 const StackDescriptor& descriptor,
2683 Optional<std::string&> reasonIfUnsupported) const
2684{
Jan Eilers8eb25602020-03-09 12:13:48 +00002685 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002686
2687 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002688 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002689 {
2690 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002691 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002692 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002693 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002694 DataType::QSymmS16,
2695 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002696 };
2697
2698 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2699 "Reference stack: output type not supported");
2700 for (const TensorInfo* input : inputs)
2701 {
Matthew Jackson81e601c2019-07-11 12:07:09 +01002702 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2703 "Reference stack: input type not supported");
2704
2705 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2706 "Reference stack: input and output types mismatched.");
2707 }
2708
2709 return supported;
2710}
2711
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002712bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2713 const TensorInfo& output,
2714 const StridedSliceDescriptor& descriptor,
2715 Optional<std::string&> reasonIfUnsupported) const
2716{
Jan Eilers8eb25602020-03-09 12:13:48 +00002717 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002718 bool supported = true;
2719
Sadik Armagan303980c2020-04-17 12:45:14 +01002720 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002721 {
2722 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002723 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002724 DataType::QAsymmU8,
2725 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002726 };
2727
2728 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2729 "Reference StridedSlice: input type not supported");
2730
2731 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2732 "Reference StridedSlice: output type not supported");
2733
2734 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2735 "Reference StridedSlice: input and output types are mismatched");
2736
2737 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002738}
2739
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002740bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2741 const TensorInfo& input1,
2742 const TensorInfo& output,
2743 Optional<std::string&> reasonIfUnsupported) const
2744{
Sadik Armagan2999a022019-04-09 14:20:12 +01002745 bool supported = true;
2746
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002747 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002748 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002749 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002750 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002751 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002752 DataType::QSymmS16,
2753 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002754 };
2755
2756 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2757 "Reference subtraction: input 0 is not a supported type.");
2758
2759 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2760 "Reference subtraction: input 1 is not a supported type.");
2761
2762 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2763 "Reference subtraction: output is not a supported type.");
2764
2765 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2766 "Reference subtraction: input 0 and Input 1 types are mismatched");
2767
2768 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2769 "Reference subtraction: input and output types are mismatched");
2770
2771 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2772 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2773
2774 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002775}
2776
Matteo Martincighab9e5252019-06-13 17:27:46 +01002777bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2778 const TensorInfo& alpha,
2779 const TensorInfo& output,
2780 Optional<std::string&> reasonIfUnsupported) const
2781{
2782 bool supported = true;
2783
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002784 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002785 {
2786 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002787 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002788 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002789 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002790 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002791 };
2792
2793 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2794 "PReLU: input is not a supported type.");
2795
2796 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2797 "PReLU: alpha is not a supported type.");
2798
2799 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2800 "PReLU: output is not a supported type.");
2801
2802 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2803 "PReLU: input, alpha and output types are mismatched");
2804
2805 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2806 "PReLU: shapes are not suitable for implicit broadcast");
2807
2808 return supported;
2809}
2810
Teresa Charlin79a06a52023-07-13 17:16:45 +01002811bool RefLayerSupport::IsTileSupported(const TensorInfo& input,
2812 const TensorInfo& output,
2813 const TileDescriptor& descriptor,
2814 Optional<std::string&> reasonIfUnsupported) const
2815{
2816 IgnoreUnused(descriptor);
2817
2818 bool supported = true;
2819
2820 std::array<DataType, 7> supportedTypes
2821 {
2822 DataType::Float32,
2823 DataType::Float16,
2824 DataType::QAsymmS8,
2825 DataType::QAsymmU8,
2826 DataType::QSymmS8,
2827 DataType::QSymmS16,
2828 DataType::Signed32
2829 };
2830
2831 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2832 "Tile: input type not supported.");
2833
2834 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2835 "Tile: output type not supported");
2836
2837 return supported;
2838}
2839
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002840bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2841 const TensorInfo& output,
2842 const TransposeConvolution2dDescriptor& descriptor,
2843 const TensorInfo& weights,
2844 const Optional<TensorInfo>& biases,
2845 Optional<std::string&> reasonIfUnsupported) const
2846{
Jan Eilers8eb25602020-03-09 12:13:48 +00002847 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002848 bool supported = true;
2849
Sadik Armagan303980c2020-04-17 12:45:14 +01002850 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002851 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002852 DataType::Float32,
2853 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002854 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002855 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002856 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002857 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002858 };
2859
2860 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2861 "Reference TransposeConvolution2d: input is not a supported type.");
2862
2863 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2864 "Reference TransposeConvolution2d: output is not a supported type.");
2865
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002866 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2867 "Reference TransposeConvolution2d: input and output types mismatched.");
2868
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002869
2870 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002871 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002872 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002873 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002874 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002875 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002876 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002877 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002878 };
2879
2880 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2881 "Reference TransposeConvolution2d: weights type not supported for "
2882 "quantized input.");
2883 }
2884 else
2885 {
2886 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2887 "Reference TransposeConvolution2d: weights is not a supported type.");
2888
2889 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2890 "Reference TransposeConvolution2d: input and weights types mismatched.");
2891 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002892
2893 if (biases.has_value())
2894 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002895 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002896 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002897 DataType::Float32,
2898 DataType::Float16,
2899 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002900 };
2901 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2902 "Reference TransposeConvolution2d: biases is not a supported type.");
2903 }
2904
2905 return supported;
2906}
2907
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002908bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2909 const TensorInfo& output,
2910 const TransposeDescriptor& descriptor,
2911 Optional<std::string&> reasonIfUnsupported) const
2912{
Jan Eilers8eb25602020-03-09 12:13:48 +00002913 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002914 bool supported = true;
2915
2916 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002917 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002918 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002919 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002920 DataType::Float32,
2921 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002922 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002923 DataType::QAsymmU8,
2924 DataType::QSymmS16
2925 };
2926
2927 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2928 "Reference transpose: input is not a supported type.");
2929
2930 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2931 "Reference transpose: output is not a supported type.");
2932
2933 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2934 "Reference transpose: input and output types are mismatched.");
2935
2936 return supported;
2937}
2938
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002939bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2940 const TensorInfo& input,
2941 const TensorInfo& outputStateIn,
2942 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002943 const TensorInfo& outputStateOut,
2944 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002945 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002946 const UnidirectionalSequenceLstmDescriptor& descriptor,
2947 const LstmInputParamsInfo& paramsInfo,
2948 Optional<std::string&> reasonIfUnsupported) const
2949{
2950 IgnoreUnused(descriptor);
2951 IgnoreUnused(paramsInfo);
2952 IgnoreUnused(outputStateIn);
2953 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002954 IgnoreUnused(outputStateOut);
2955 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002956 bool supported = true;
2957
Mike Kelly12994962022-04-21 11:57:09 +01002958 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002959 {
Mike Kelly12994962022-04-21 11:57:09 +01002960 DataType::Float32,
2961 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002962 };
2963
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002964 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002965 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002966 DataType::Float32,
2967 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002968 };
2969
Mike Kelly12994962022-04-21 11:57:09 +01002970 std::array<DataType, 3> supportedBiasTypes =
2971 {
2972 DataType::Float32,
2973 DataType::QAsymmS8,
2974 DataType::Signed32
2975 };
2976
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002977 // check inputs and outputs
2978 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2979 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002980 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2981 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002982
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002983 // check layer parameters
2984 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2985 reasonIfUnsupported,
2986 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2987 "is not a supported type.");
2988 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2989 reasonIfUnsupported,
2990 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2991 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2992 reasonIfUnsupported,
2993 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2994 "is not a supported type.");
2995 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2996 reasonIfUnsupported,
2997 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2998 "is not a supported type.");
2999 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
3000 reasonIfUnsupported,
3001 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
3002 "is not a supported type.");
3003 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
3004 reasonIfUnsupported,
3005 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
3006 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01003007
3008 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
3009 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
3010 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
3011 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
3012 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
3013 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01003014 if (!descriptor.m_CifgEnabled)
3015 {
3016 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
3017 reasonIfUnsupported,
3018 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
3019 "is not a supported type.");
3020 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
3021 reasonIfUnsupported,
3022 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
3023 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01003024 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
3025 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01003026 if (descriptor.m_PeepholeEnabled)
3027 {
3028 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
3029 reasonIfUnsupported,
3030 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
3031 "is not a supported type.");
3032 }
3033 }
3034 if (descriptor.m_PeepholeEnabled)
3035 {
3036 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
3037 reasonIfUnsupported,
3038 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
3039 "is not a supported type.");
3040 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
3041 reasonIfUnsupported,
3042 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
3043 "is not a supported type.");
3044 }
3045 if (descriptor.m_ProjectionEnabled)
3046 {
3047 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
3048 reasonIfUnsupported,
3049 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
3050 "is not a supported type.");
3051 if (paramsInfo.m_ProjectionBias != nullptr)
3052 {
3053 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
3054 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
3055 "are mismatched");
3056 }
3057 }
3058 if (descriptor.m_LayerNormEnabled)
3059 {
3060 if (!descriptor.m_CifgEnabled)
3061 {
3062 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
3063 reasonIfUnsupported,
3064 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
3065 "is not a supported type.");
3066 }
3067 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
3068 reasonIfUnsupported,
3069 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
3070 "is not a supported type.");
3071 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
3072 reasonIfUnsupported,
3073 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
3074 "is not a supported type.");
3075 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
3076 reasonIfUnsupported,
3077 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
3078 "is not a supported type.");
3079 }
3080
3081 return supported;
3082}
3083
arovir011c7c81b2018-10-08 11:34:28 +01003084} // namespace armnn