blob: f97d03a26e3a29ab2bd53dbc711b129b1d75a9b7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
// Thin forwarding wrapper used by the reference backend: dispatches a support
// query to IsSupportedForDataTypeGeneric, enabling only the Float32 and QAsymmU8
// handlers and mapping every other data-type slot to FalseFunc (unsupported).
//
// floatFuncPtr  - handler invoked when dataType is Float32.
// uint8FuncPtr  - handler invoked when dataType is QAsymmU8.
// params        - extra arguments perfect-forwarded to the selected handler.
// Returns the selected handler's verdict; reasonIfUnsupported is filled on failure.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,   // Float16: not handled by this helper
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,   // Signed32: not handled
                                         &FalseFunc<Params...>,   // Boolean: not handled
                                         std::forward<Params>(params)...);
}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
// Builds the standard "wrong number of dimensions" diagnostic for a tensor,
// e.g.: Reference Conv: Expected 2 dimensions but got 3 dimensions instead,
// for the 'input' tensor.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string message{"Reference "};
    message += layerStr;
    message += ": Expected ";
    message += std::to_string(expected);
    message += " dimensions but got ";
    message += std::to_string(actual);
    message += " dimensions instead, for the '";
    message += tensorName;
    message += "' tensor.";
    return message;
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100103 case LayerType::BroadcastTo:
104 return IsBroadcastToSupported(infos[0],
105 infos[1],
106 *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
107 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000108 case LayerType::Comparison:
109 return IsComparisonSupported(infos[0],
110 infos[1],
111 infos[2],
112 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
113 reasonIfUnsupported);
114 case LayerType::Concat:
115 {
116 std::vector<const TensorInfo*> inputInfos;
117 for (uint32_t i = 0; i < (infos.size() - 1); i++)
118 {
119 inputInfos.push_back(&infos[i]);
120 }
121 return IsConcatSupported(inputInfos,
122 infos[infos.size() - 1],
123 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
124 reasonIfUnsupported);
125 }
126 case LayerType::Constant:
127 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000128 case LayerType::ConvertFp16ToFp32:
129 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000130 case LayerType::ConvertFp32ToFp16:
131 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
132 case LayerType::Convolution2d:
133 {
134 if (infos.size() != 4)
135 {
136 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
137 "TensorInfos should be of format: {input, output, weights, biases}.");
138 }
139
140 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
141 if (infos[3] == TensorInfo())
142 {
143 return IsConvolution2dSupported(infos[0],
144 infos[1],
145 desc,
146 infos[2],
147 EmptyOptional(),
148 reasonIfUnsupported);
149 }
150 else
151 {
152 return IsConvolution2dSupported(infos[0],
153 infos[1],
154 desc,
155 infos[2],
156 infos[3],
157 reasonIfUnsupported);
158 }
159 }
160 case LayerType::DepthToSpace:
161 return IsDepthToSpaceSupported(infos[0],
162 infos[1],
163 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
164 reasonIfUnsupported);
165 case LayerType::DepthwiseConvolution2d:
166 {
167 if (infos.size() != 4)
168 {
169 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
170 "TensorInfos should be of format: {input, output, weights, biases}.");
171 }
172
173 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
174 if (infos[3] == TensorInfo())
175 {
176 return IsDepthwiseConvolutionSupported(infos[0],
177 infos[1],
178 desc,
179 infos[2],
180 EmptyOptional(),
181 reasonIfUnsupported);
182 }
183 else
184 {
185 return IsDepthwiseConvolutionSupported(infos[0],
186 infos[1],
187 desc,
188 infos[2],
189 infos[3],
190 reasonIfUnsupported);
191 }
192 }
193 case LayerType::Dequantize:
194 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
195 case LayerType::Division:
196 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000197 case LayerType::ElementwiseBinary:
198 {
199 std::array<DataType, 7> supportedTypes =
200 {
201 DataType::Float32,
202 DataType::Float16,
203 DataType::QAsymmS8,
204 DataType::QAsymmU8,
205 DataType::QSymmS16,
206 DataType::Signed32
207 };
208
209 bool supported = true;
210 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
211 "Reference elementwise unary: input type not supported");
212
213 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
214 "Reference elementwise unary: input type not supported");
215
216 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
217 "Reference elementwise unary: output type not supported");
218
219 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
220 "Reference elementwise unary: input types not matching");
221
222 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
223 "Reference elementwise unary: input and output types not matching");
224
225 return supported;
226 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000227 case LayerType::ElementwiseUnary:
228 return IsElementwiseUnarySupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Fill:
233 return IsFillSupported(infos[0],
234 infos[1],
235 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
236 reasonIfUnsupported);
237 case LayerType::Floor:
238 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
239 case LayerType::FullyConnected:
240 return IsFullyConnectedSupported(infos[0],
241 infos[1],
242 infos[2],
243 infos[3],
244 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Gather:
247 return IsGatherSupported(infos[0],
248 infos[1],
249 infos[2],
250 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
251 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100252 case LayerType::GatherNd:
253 return IsGatherNdSupported(infos[0],
254 infos[1],
255 infos[2],
256 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000257 case LayerType::Input:
258 return IsInputSupported(infos[0], reasonIfUnsupported);
259 case LayerType::InstanceNormalization:
260 return IsInstanceNormalizationSupported(infos[0],
261 infos[1],
262 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
263 (&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::L2Normalization:
266 return IsL2NormalizationSupported(infos[0],
267 infos[1],
268 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
269 reasonIfUnsupported);
270 case LayerType::LogicalBinary:
271 return IsLogicalBinarySupported(infos[0],
272 infos[1],
273 infos[2],
274 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::LogSoftmax:
277 return IsLogSoftmaxSupported(infos[0],
278 infos[1],
279 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
280 reasonIfUnsupported);
281 case LayerType::Lstm:
282 return IsLstmSupported(infos[0],
283 infos[1],
284 infos[2],
285 infos[3],
286 infos[4],
287 infos[5],
288 infos[6],
289 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
290 lstmParamsInfo.value(),
291 reasonIfUnsupported);
292 case LayerType::QLstm:
293 return IsQLstmSupported(infos[0],
294 infos[1],
295 infos[2],
296 infos[3],
297 infos[4],
298 infos[5],
299 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
300 lstmParamsInfo.value(),
301 reasonIfUnsupported);
302 case LayerType::Maximum:
303 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
304 case LayerType::Mean:
305 return IsMeanSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Minimum:
310 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
311 case LayerType::Multiplication:
312 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
313 case LayerType::Normalization:
314 return IsNormalizationSupported(infos[0],
315 infos[1],
316 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
317 reasonIfUnsupported);
318 case LayerType::Output:
319 return IsOutputSupported(infos[0], reasonIfUnsupported);
320 case LayerType::Pad:
321 return IsPadSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Permute:
326 return IsPermuteSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Pooling2d:
331 return IsPooling2dSupported(infos[0],
332 infos[1],
333 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
334 reasonIfUnsupported);
335 case LayerType::Prelu:
336 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
337 case LayerType::Quantize:
338 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
339 case LayerType::Reshape:
340 return IsReshapeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Resize:
345 return IsResizeSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100349 case LayerType::ReverseV2:
350 return IsReverseV2Supported(infos[0],
351 infos[1],
Tracy Narinebb8d7592023-07-13 16:50:54 +0100352 infos[2],
Tianle Cheng988354d2023-06-28 13:20:47 +0100353 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000354 case LayerType::Reduce:
355 return IsReduceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::Slice:
360 return IsSliceSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::Softmax:
365 return IsSoftmaxSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::SpaceToBatchNd:
370 return IsSpaceToBatchNdSupported(infos[0],
371 infos[1],
372 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
373 reasonIfUnsupported);
374 case LayerType::SpaceToDepth:
375 return IsSpaceToDepthSupported(infos[0],
376 infos[1],
377 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
378 reasonIfUnsupported);
379 case LayerType::Splitter:
380 {
381 std::vector<TensorInfo> outputInfos;
382 for (uint32_t i = 1; i < infos.size(); i++)
383 {
384 outputInfos.push_back(infos[i]);
385 }
386 return IsSplitterSupported(infos[0],
387 {outputInfos.begin(), outputInfos.end()},
388 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
389 reasonIfUnsupported);
390 }
391 case LayerType::Stack:
392 {
393 std::vector<const TensorInfo*> inputInfos;
394 for (uint32_t i = 0; i < infos.size() - 1; i++)
395 {
396 inputInfos.push_back(&infos[i]);
397 }
398 return IsStackSupported(inputInfos,
399 infos[infos.size() - 1],
400 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
401 reasonIfUnsupported);
402 }
403 case LayerType::StridedSlice:
404 return IsStridedSliceSupported(infos[0],
405 infos[1],
406 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
407 reasonIfUnsupported);
408 case LayerType::Subtraction:
409 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Teresa Charlin79a06a52023-07-13 17:16:45 +0100410 case LayerType::Tile:
411 return IsTileSupported(infos[0],
412 infos[1],
413 *(PolymorphicDowncast<const TileDescriptor*>(&descriptor)),
414 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000415 case LayerType::Transpose:
416 return IsTransposeSupported(infos[0],
417 infos[1],
418 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
419 reasonIfUnsupported);
420 case LayerType::TransposeConvolution2d:
421 {
422 if (infos.size() != 4)
423 {
424 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
425 "TensorInfos should be of format: {input, output, weights, biases}.");
426 }
427
428 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
429 if (infos[3] == TensorInfo())
430 {
431 return IsTransposeConvolution2dSupported(infos[0],
432 infos[1],
433 desc,
434 infos[2],
435 EmptyOptional(),
436 reasonIfUnsupported);
437 }
438 else
439 {
440 return IsTransposeConvolution2dSupported(infos[0],
441 infos[1],
442 desc,
443 infos[2],
444 infos[3],
445 reasonIfUnsupported);
446 }
447 }
448 case LayerType::Cast:
449 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
450 case LayerType::ChannelShuffle:
451 return IsChannelShuffleSupported(infos[0],
452 infos[1],
453 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
454 reasonIfUnsupported);
455 case LayerType::Convolution3d:
456 {
457 if (infos.size() != 4)
458 {
459 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
460 "TensorInfos should be of format: {input, output, weights, biases}.");
461 }
462
463 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
464 if (infos[3] == TensorInfo())
465 {
466 return IsConvolution3dSupported(infos[0],
467 infos[1],
468 desc,
469 infos[2],
470 EmptyOptional(),
471 reasonIfUnsupported);
472 }
473 else
474 {
475 return IsConvolution3dSupported(infos[0],
476 infos[1],
477 desc,
478 infos[2],
479 infos[3],
480 reasonIfUnsupported);
481 }
482 }
483 case LayerType::Debug:
484 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
485 case LayerType::DetectionPostProcess:
486 return IsDetectionPostProcessSupported(infos[0],
487 infos[1],
488 infos[2],
489 infos[3],
490 infos[4],
491 infos[5],
492 infos[6],
493 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
494 (&descriptor)),
495 reasonIfUnsupported);
496 case LayerType::FakeQuantization:
497 return IsFakeQuantizationSupported(infos[0],
498 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
499 reasonIfUnsupported);
500 case LayerType::MemCopy:
501 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
502 case LayerType::Rank:
503 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
504 case LayerType::Shape:
505 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
506 case LayerType::UnidirectionalSequenceLstm:
507 {
508 if (infos.size() != 6)
509 {
510 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
511 "should be of format: {input, outputStateIn, cellStateIn, "
512 "hiddenStateOutputVal, cellStateOutputVal, output}");
513 }
514 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100515 return IsUnidirectionalSequenceLstmSupported(infos[0],
516 infos[1],
517 infos[2],
518 infos[3],
519 infos[4],
520 infos[5],
521 desc,
522 lstmParamsInfo.value(),
523 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000524 }
525 case LayerType::Pooling3d:
526 return IsPooling3dSupported(infos[0],
527 infos[1],
528 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
529 reasonIfUnsupported);
530 case LayerType::Map:
531 return true;
532 case LayerType::Unmap:
533 return true;
534 case LayerType::MemImport:
535 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
536 case LayerType::Merge:
537 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
538 case LayerType::QuantizedLstm:
539 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
540 infos[1],
541 infos[2],
542 infos[3],
543 infos[4],
544 quantizedLstmInputParamsInfo.value(),
545 reasonIfUnsupported);
546 default:
Teresa Charlin9145e382023-08-17 18:44:58 +0100547 // layers not supported in reference by default:
548 // precompiled, standin, switch, fused
Cathal Corbett34b429c2021-12-24 12:24:40 +0000549 return false;
550 }
551}
552
arovir011c7c81b2018-10-08 11:34:28 +0100553bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
554 const TensorInfo& output,
555 const ActivationDescriptor& descriptor,
556 Optional<std::string&> reasonIfUnsupported) const
557{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000558 bool supported = true;
559
560 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000561 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000562 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100563 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000564 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000565 DataType::QAsymmU8,
566 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000567 };
568
569 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
570 "Reference activation: input type not supported.");
571
572 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
573 "Reference activation: output type not supported.");
574
575 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
576 "Reference activation: input and output types mismatched.");
577
578 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
579 "Reference activation: input and output shapes are of different rank.");
580
581
582 struct ActivationFunctionSupported : public Rule
583 {
584 ActivationFunctionSupported(const ActivationDescriptor& desc)
585 {
586 switch(desc.m_Function)
587 {
588 case ActivationFunction::Abs:
589 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000590 case ActivationFunction::Elu:
Teresa Charlin077cddb2023-09-15 15:19:21 +0100591 case ActivationFunction::Gelu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000592 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000593 case ActivationFunction::LeakyReLu:
594 case ActivationFunction::Linear:
595 case ActivationFunction::ReLu:
596 case ActivationFunction::Sigmoid:
597 case ActivationFunction::SoftReLu:
598 case ActivationFunction::Sqrt:
599 case ActivationFunction::Square:
600 case ActivationFunction::TanH:
601 {
602 m_Res = true;
603 break;
604 }
605 default:
606 {
607 m_Res = false;
608 break;
609 }
610 }
611 }
612 };
613
614 // Function is supported
615 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
616 "Reference activation: function not supported.");
617
618 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100619}
620
621bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
622 const TensorInfo& input1,
623 const TensorInfo& output,
624 Optional<std::string&> reasonIfUnsupported) const
625{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000626 bool supported = true;
627
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100628 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000629 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100630 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000631 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000632 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100633 DataType::QSymmS16,
634 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000635 };
636
637 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
638 "Reference addition: input 0 is not a supported type.");
639
640 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
641 "Reference addition: input 1 is not a supported type.");
642
643 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
644 "Reference addition: output is not a supported type.");
645
646 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
647 "Reference addition: input 0 and Input 1 types are mismatched");
648
649 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
650 "Reference addition: input and output types are mismatched");
651
652 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
653 "Reference addition: shapes are not suitable for implicit broadcast.");
654
655 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100656}
657
Nikhil Raj68c2c902019-09-19 11:21:11 +0100658bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
659 const armnn::ArgMinMaxDescriptor &descriptor,
660 armnn::Optional<std::string &> reasonIfUnsupported) const
661{
Jan Eilers8eb25602020-03-09 12:13:48 +0000662 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100663
Mike Kelly1f140f72021-04-06 12:25:55 +0100664 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100665 {
Teresa Charline300b362020-05-25 10:01:03 +0100666 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100667 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100668 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000669 DataType::QAsymmU8,
670 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100671 DataType::Signed32,
672 DataType::Signed64
673 };
674
675 std::array<DataType,2> supportedOutputTypes = {
676 DataType::Signed32,
677 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100678 };
679
680 bool supported = true;
681
Mike Kelly1f140f72021-04-06 12:25:55 +0100682 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100683 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100684 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100685 "Reference ArgMinMax: output type not supported");
686
687 return supported;
688}
689
Samuel Yap6b478092022-07-06 15:36:03 +0100690bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
691 const TensorInfo& inputY,
692 const TensorInfo& output,
693 const BatchMatMulDescriptor& descriptor,
694 Optional<std::string &> reasonIfUnsupported) const
695{
696 IgnoreUnused(descriptor);
697
698 std::array<DataType, 6> supportedTypes =
699 {
Samuel Yap6b478092022-07-06 15:36:03 +0100700 DataType::Float16,
701 DataType::Float32,
702 DataType::QAsymmS8,
703 DataType::QAsymmU8,
704 DataType::QSymmS16
705 };
706
707 bool supported = true;
708
709 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
710 "Reference batch matrix multiplication: input X is not a supported type");
711
712 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
713 "Reference batch matrix multiplication: input Y is not a supported type");
714
715 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
716 "Reference batch matrix multiplication: output is not a supported type");
717
718 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
719 "Reference batch matrix multiplication: input X and input Y types are mismatched");
720
721 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
722 "Reference batch matrix multiplication: inputs and output types are mismatched");
723
724 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
725 reasonIfUnsupported,
726 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
727
728 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
729 reasonIfUnsupported,
730 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
731
732 return supported;
733}
734
arovir011c7c81b2018-10-08 11:34:28 +0100735bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
736 const TensorInfo& output,
737 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100738 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100739 const TensorInfo& beta,
740 const TensorInfo& gamma,
741 const BatchNormalizationDescriptor& descriptor,
742 Optional<std::string&> reasonIfUnsupported) const
743{
Jan Eilers8eb25602020-03-09 12:13:48 +0000744 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100745
Sadik Armagan303980c2020-04-17 12:45:14 +0100746 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100747 {
748 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100749 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100750 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000751 DataType::QAsymmU8,
752 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100753 };
754
755 bool supported = true;
756
757 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
758 "Reference batch normalization: input is not a supported type.");
759
760 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
761 "Reference batch normalization: output is not a supported type.");
762
763 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
764 "Reference batch normalization: input and output types are mismatched");
765
766 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
767 "Reference batch normalization: mean is not a supported type.");
768
769 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
770 "Reference batch normalization: variance is not a supported type.");
771
772 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
773 "Reference batch normalization: beta is not a supported type.");
774
775 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
776 "Reference batch normalization: gamma is not a supported type.");
777
778 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100779}
780
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000781bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
782 const TensorInfo& output,
783 const BatchToSpaceNdDescriptor& descriptor,
784 Optional<std::string&> reasonIfUnsupported) const
785{
Jan Eilers8eb25602020-03-09 12:13:48 +0000786 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100787
788 bool supported = true;
789
790 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
791 std::string inputTensorStr = "input";
792 std::string outputTensorStr = "output";
793
794 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100795 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100796 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000797 DataType::Float32,
798 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100799 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000800 DataType::QAsymmU8,
801 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100802 };
803
804 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
805 "Reference BatchToSpaceNd: input type not supported.");
806
807 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
808 "Reference BatchToSpaceNd: output type not supported.");
809
810 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
811 "Reference BatchToSpaceNd: input and output types mismatched.");
812
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100813 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000814}
815
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100816bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
817 const TensorInfo& output,
818 const BroadcastToDescriptor& descriptor,
819 Optional<std::string&> reasonIfUnsupported) const
820{
821 IgnoreUnused(descriptor);
822
823 bool supported = true;
824
825 std::array<DataType, 8> supportedTypes
826 {
827 DataType::Float32,
828 DataType::Float16,
829 DataType::QAsymmS8,
830 DataType::QAsymmU8,
831 DataType::QSymmS8,
832 DataType::QSymmS16,
833 DataType::Signed32,
834 DataType::Signed64
835 };
836
837 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
838 "BroadcastTo: input type not supported.");
839
840 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
841 "BroadcastTo: output type not supported");
842
843 return supported;
844}
845
mathad01b392e982021-04-07 12:07:30 +0100846bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
847 const TensorInfo& output,
848 Optional<std::string&> reasonIfUnsupported) const
849{
Teresa Charlin5306dc82023-10-30 22:29:58 +0000850 std::array<DataType, 10> supportedInputTypes =
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100851 {
852 DataType::Float32,
853 DataType::Float16,
854 DataType::QSymmS8,
855 DataType::QAsymmS8,
856 DataType::QAsymmU8,
857 DataType::QSymmS16,
Teresa Charlin5306dc82023-10-30 22:29:58 +0000858 DataType::Signed32,
859 DataType::Signed64
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100860 };
mathad01b392e982021-04-07 12:07:30 +0100861
862 bool supported = true;
863 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
864 "Reference cast: input is not a supported type");
865
866
867 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
868 "Reference cast: output is not a supported type");
869
870 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
871 "Reference cast: input and output shapes have different number of total elements");
872
873 return supported;
874}
875
Simon Obute51f67772021-09-03 15:50:13 +0100876bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
877 const TensorInfo& output,
878 const ChannelShuffleDescriptor& descriptor,
879 Optional<std::string&> reasonIfUnsupported) const
880{
881 IgnoreUnused(descriptor);
882 bool supported = true;
883
884 // Define supported output and inputs types.
885 std::array<DataType, 7> supportedTypes =
886 {
Simon Obute51f67772021-09-03 15:50:13 +0100887 DataType::Float32,
888 DataType::Float16,
889 DataType::QAsymmS8,
890 DataType::QAsymmU8,
891 DataType::QSymmS8,
892 DataType::QSymmS16
893 };
894
895 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
896 "Reference ChannelShuffle: input is not a supported type.");
897
898 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
899 "Reference ChannelShuffle: output is not a supported type.");
900
901 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
902 "Reference ChannelShuffle: input and output types are mismatched.");
903
904 return supported;
905}
906
907
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100908bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
909 const TensorInfo& input1,
910 const TensorInfo& output,
911 const ComparisonDescriptor& descriptor,
912 Optional<std::string&> reasonIfUnsupported) const
913{
Jan Eilers8eb25602020-03-09 12:13:48 +0000914 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100915 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100916 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000917 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100918 DataType::Float32,
919 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100920 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000921 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000922 DataType::QSymmS16,
923 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100924 };
925
926 bool supported = true;
927 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
928 "Reference comparison: input 0 is not a supported type");
929
930 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
931 "Reference comparison: input 0 and Input 1 types are mismatched");
932
933 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
934 "Reference comparison: output is not of type Boolean");
935
936 return supported;
937}
938
Jim Flynn906f9462019-05-10 13:55:21 +0100939bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
940 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000941 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100942 Optional<std::string&> reasonIfUnsupported) const
943{
Jan Eilers8eb25602020-03-09 12:13:48 +0000944 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100945
946 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000947 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100948 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000949 DataType::Float32,
950 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000951 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100952 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000953 DataType::QSymmS16,
954 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100955 };
956
957 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
958 "Reference concatenation: output type not supported");
959 for (const TensorInfo* input : inputs)
960 {
961 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
962 "Reference concatenation: input type not supported");
963
964 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
965 "Reference concatenation: input and output types mismatched.");
966 }
967
968 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100969}
970
arovir011c7c81b2018-10-08 11:34:28 +0100971bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
972 Optional<std::string&> reasonIfUnsupported) const
973{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100974 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100975 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100976 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100977 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000978 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100979 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000980 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100981 DataType::QSymmS16,
982 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100983 };
984
985 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
986 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100987}
988
// Checks reference-backend support for the ConvertFp16ToFp32 layer.
// Routed through the generic per-data-type dispatcher: each function-pointer
// argument is the handler for one data-type slot, so the input check accepts
// only Float16 and the output check accepts only Float32 (the FalseInput/
// FalseOutput handlers reject the reversed pairing with a reason message).
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,           // Float16 input: accepted
                                          &FalseInputFuncF32<>,  // Float32 input: rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output: rejected with reason
                                          &TrueFunc<>,           // Float32 output: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
1008
// Checks reference-backend support for the ConvertFp32ToFp16 layer.
// Mirror image of the Fp16->Fp32 check: the generic dispatcher accepts only a
// Float32 input and a Float16 output, rejecting the reversed pairing with a
// reason message via the FalseInput/FalseOutput handlers.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,  // Float16 input: rejected with reason
                                          &TrueFunc<>,           // Float32 input: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output: accepted
                                          &FalseOutputFuncF32<>, // Float32 output: rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
1028
1029bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1030 const TensorInfo& output,
1031 const Convolution2dDescriptor& descriptor,
1032 const TensorInfo& weights,
1033 const Optional<TensorInfo>& biases,
1034 Optional<std::string&> reasonIfUnsupported) const
1035{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001036 bool supported = true;
1037
1038 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001039 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001040 {
1041 DataType::Float32,
1042 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001043 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001044 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001045 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001046 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001047 };
1048
1049 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001050 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001051
1052 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001053 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001054
Ryan OShea31441592022-11-07 16:20:48 +00001055 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1056 "Reference Convolution2d: input and output types mismatched.");
1057
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001058
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001059 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001060 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001061 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001062 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001063 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001064 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001065 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001066 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001067 };
1068
1069 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001070 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001071 }
1072 else
1073 {
1074 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001075 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001076
1077 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001078 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001079 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001080
1081 if (biases.has_value())
1082 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001083 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001084 {
1085 DataType::Float32,
1086 DataType::Float16,
1087 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001088 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001089
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001090 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001091 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001092 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001093 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001094
1095 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001096}
1097
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001098bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1099 const TensorInfo& output,
1100 const Convolution3dDescriptor& descriptor,
1101 const TensorInfo& weights,
1102 const Optional<TensorInfo>& biases,
1103 Optional<std::string&> reasonIfUnsupported) const
1104{
1105 bool supported = true;
1106
1107 // Define supported types.
1108 std::array<DataType,7> supportedTypes =
1109 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001110 DataType::Float32,
1111 DataType::Float16,
1112 DataType::QAsymmS8,
1113 DataType::QAsymmU8,
1114 DataType::QSymmS8,
1115 DataType::QSymmS16
1116 };
1117
1118 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1119 "Reference Convolution3d: input is not a supported type.");
1120
1121 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1122 "Reference Convolution3d: output is not a supported type.");
1123
1124 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1125 "Reference Convolution3d: input and output types mismatched.");
1126
1127 const DataType inputType = input.GetDataType();
1128 if (IsQuantized8BitType(inputType))
1129 {
1130 std::array<DataType, 3> supportedWeightTypes =
1131 {
1132 DataType::QAsymmS8,
1133 DataType::QAsymmU8,
1134 DataType::QSymmS8
1135 };
1136
1137 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1138 "Reference Convolution3d: weights type not supported for quantized input.");
1139 }
1140 else
1141 {
1142 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1143 "Reference Convolution3d: weights is not a supported type.");
1144
1145 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1146 "Reference Convolution3d: input and weights types mismatched.");
1147 }
1148
1149 if (biases.has_value())
1150 {
1151 std::array<DataType,4> biasesSupportedTypes =
1152 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001153 DataType::Float32,
1154 DataType::Float16,
1155 DataType::Signed32
1156 };
1157
1158 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1159 "Reference Convolution3d: biases is not a supported type.");
1160 }
1161 IgnoreUnused(descriptor);
1162
1163 return supported;
1164}
1165
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001166bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1167 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001168 Optional<std::string&> reasonIfUnsupported) const
1169{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001170 bool supported = true;
1171
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001172 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001173 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001174 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001175 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001176 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001177 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001178 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001179 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001180 DataType::QSymmS16,
1181 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001182 };
1183
1184 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001185 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001186
1187 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001188 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001189
1190 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001191 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001192
1193 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001194}
1195
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001196bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1197 const TensorInfo& output,
1198 const DepthToSpaceDescriptor& descriptor,
1199 Optional<std::string&> reasonIfUnsupported) const
1200{
Jan Eilers8eb25602020-03-09 12:13:48 +00001201 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001202 bool supported = true;
1203
Sadik Armagan303980c2020-04-17 12:45:14 +01001204 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001205 {
1206 DataType::Float32,
1207 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001208 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001209 DataType::QAsymmU8,
1210 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001211 };
1212
1213 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1214 "Reference DepthToSpace: input type not supported");
1215
1216 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1217 "Reference DepthToSpace: output type not supported");
1218
1219 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1220 "Reference DepthToSpace: input and output types are mismatched");
1221
1222 return supported;
1223}
1224
arovir011c7c81b2018-10-08 11:34:28 +01001225bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1226 const TensorInfo& output,
1227 const DepthwiseConvolution2dDescriptor& descriptor,
1228 const TensorInfo& weights,
1229 const Optional<TensorInfo>& biases,
1230 Optional<std::string&> reasonIfUnsupported) const
1231{
Sadik Armagan303980c2020-04-17 12:45:14 +01001232 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001233 bool supported = true;
1234
1235 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001236 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001237 {
1238 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001239 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001240 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001241 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001242 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001243 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001244 };
1245
1246 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1247 "Reference DepthwiseConvolution2d: input is not a supported type.");
1248
1249 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1250 "Reference DepthwiseConvolution2d: output is not a supported type.");
1251
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001252 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1253 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1254
Teresa Charlind8df0262019-11-11 12:28:15 +00001255 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001256 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001257 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001258 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001259 {
1260 DataType::QAsymmS8,
1261 DataType::QAsymmU8,
1262 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001263 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001264
1265 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001266 "Reference DepthwiseConvolution2d: weights type not supported for "
1267 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001268 }
1269 else
1270 {
1271 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1272 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1273
1274 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1275 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1276 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001277
1278 if (biases.has_value())
1279 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001280 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001281 {
1282 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001283 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001284 DataType::Signed32
1285 };
1286 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1287 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1288 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001289
1290 return supported;
1291
arovir011c7c81b2018-10-08 11:34:28 +01001292}
1293
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001294bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1295 const TensorInfo& output,
1296 Optional<std::string&> reasonIfUnsupported) const
1297{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001298 bool supported = true;
1299
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001300 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001301 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001302 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001303 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001304 DataType::QSymmS16,
1305 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001306 };
1307
1308 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001309 "Reference for Dequantize layer: input type not supported.");
1310
Derek Lambertid466a542020-01-22 15:37:29 +00001311 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001312 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001313
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001314 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001315 DataType::Float32,
1316 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001317 };
1318
1319 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001320 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001321
1322 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001323 "Reference for Dequantize layer: input/output shapes have different num total "
1324 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001325
1326 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001327}
1328
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001329bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1330 const TensorInfo& scores,
1331 const TensorInfo& anchors,
1332 const TensorInfo& detectionBoxes,
1333 const TensorInfo& detectionClasses,
1334 const TensorInfo& detectionScores,
1335 const TensorInfo& numDetections,
1336 const DetectionPostProcessDescriptor& descriptor,
1337 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001338{
Jan Eilers8eb25602020-03-09 12:13:48 +00001339 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001340
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001341 bool supported = true;
1342
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001343 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001344 {
1345 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001346 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001347 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001348 DataType::QAsymmU8,
1349 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001350 };
1351
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001352 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001353 "Reference DetectionPostProcess: input 0 is not a supported type.");
1354
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001355 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001356 "Reference DetectionPostProcess: input 1 is not a supported type.");
1357
1358 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001359}
1360
// Checks reference-backend support for dilated depthwise convolution.
// Delegates directly to the regular depthwise check: the same descriptor type
// is used and the supported type combinations are identical.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1370
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001371bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001372 const TensorInfo& input1,
1373 const TensorInfo& output,
1374 Optional<std::string&> reasonIfUnsupported) const
1375{
Sadik Armagan2999a022019-04-09 14:20:12 +01001376 bool supported = true;
1377
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001378 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001379 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001380 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001381 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001382 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001383 DataType::QSymmS16,
1384 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001385 };
1386
1387 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1388 "Reference division: input 0 is not a supported type.");
1389
1390 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1391 "Reference division: input 1 is not a supported type.");
1392
1393 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1394 "Reference division: output is not a supported type.");
1395
1396 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1397 "Reference division: input 0 and Input 1 types are mismatched");
1398
1399 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1400 "Reference division: input and output types are mismatched");
1401
1402 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1403 "Reference division: shapes are not suitable for implicit broadcast.");
1404
1405 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001406}
1407
josh minor4a3c6102020-01-06 16:40:46 -06001408bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1409 const TensorInfo& output,
1410 const ElementwiseUnaryDescriptor& descriptor,
1411 Optional<std::string&> reasonIfUnsupported) const
1412{
Jan Eilers8eb25602020-03-09 12:13:48 +00001413 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001414
Sadik Armagan303980c2020-04-17 12:45:14 +01001415 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001416 {
1417 DataType::Float32,
1418 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001419 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001420 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001421 DataType::QSymmS16,
1422 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001423 };
1424
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001425 std::array<DataType, 1> logicalSupportedTypes =
1426 {
1427 DataType::Boolean
1428 };
1429
josh minor4a3c6102020-01-06 16:40:46 -06001430 bool supported = true;
1431
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001432 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1433 {
1434 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1435 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001436
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001437 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1438 "Reference elementwise unary: output type not supported");
1439 }
1440 else
1441 {
1442 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1443 "Reference elementwise unary: input type not supported");
1444
1445 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1446 "Reference elementwise unary: output type not supported");
1447 }
josh minor4a3c6102020-01-06 16:40:46 -06001448
1449 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1450 "Reference elementwise unary: input and output types not matching");
1451
1452 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1453 "Reference elementwise unary: input and output shapes"
1454 "have different number of total elements");
1455
1456 return supported;
1457}
1458
arovir011c7c81b2018-10-08 11:34:28 +01001459bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1460 const FakeQuantizationDescriptor& descriptor,
1461 Optional<std::string&> reasonIfUnsupported) const
1462{
Jan Eilers8eb25602020-03-09 12:13:48 +00001463 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001464 bool supported = true;
1465
1466 std::array<DataType,1> supportedTypes =
1467 {
1468 DataType::Float32
1469 };
1470
1471 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1472 "Reference fake quantization: input type not supported.");
1473
1474 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001475}
1476
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001477bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1478 const TensorInfo& output,
1479 const FillDescriptor& descriptor,
1480 Optional<std::string&> reasonIfUnsupported) const
1481{
1482 IgnoreUnused(descriptor);
1483 IgnoreUnused(output);
1484
1485 bool supported = true;
1486
Sadik Armagana792a052020-06-23 16:22:23 +01001487 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001488 {
1489 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001490 DataType::Float16,
1491 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001492 };
1493
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001494 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001495 "Reference Fill: input type not supported.");
1496
Teresa Charlin44088502020-07-27 11:27:19 +01001497 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1498 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001499 return supported;
1500}
1501
arovir011c7c81b2018-10-08 11:34:28 +01001502bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1503 const TensorInfo& output,
1504 Optional<std::string&> reasonIfUnsupported) const
1505{
Jan Eilers8eb25602020-03-09 12:13:48 +00001506 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001507 bool supported = true;
1508
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001509 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001510 {
James Conroyb40d7102019-06-04 12:32:09 +01001511 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001512 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001513 };
1514
1515 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1516 "Reference Floor: input type not supported.");
1517
1518 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1519 "Reference Floor: output type not supported.");
1520
1521 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001522}
1523
1524bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1525 const TensorInfo& output,
1526 const TensorInfo& weights,
1527 const TensorInfo& biases,
1528 const FullyConnectedDescriptor& descriptor,
1529 Optional<std::string&> reasonIfUnsupported) const
1530{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001531 bool supported = true;
1532
1533 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001534 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001535 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001536 DataType::Float32,
1537 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001538 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001539 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001540 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001541 };
1542
1543 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1544 "Reference Fully Connected: input type not supported.");
1545
1546 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1547 "Reference Fully Connected: output type not supported.");
1548
Francis Murtagh46c09d02019-05-28 08:15:28 +01001549 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1550 "Reference Fully Connected: weights type not supported.");
1551
Ryan OShea31441592022-11-07 16:20:48 +00001552 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1553 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001554
Jan Eilers1f45dc32020-06-15 11:43:03 +01001555 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1556 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001557
Jan Eilers1f45dc32020-06-15 11:43:03 +01001558 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1559 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001560
1561 if (descriptor.m_BiasEnabled)
1562 {
1563 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001564 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001565 supportedBiasTypes =
1566 {
1567 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001568 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001569 DataType::Signed32,
1570 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001571 };
1572
1573 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1574 "Reference Fully Connected: bias type not supported.");
1575
1576 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1577 "Reference Fully Connected: bias and weight types mismatch.");
1578
1579 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1580 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1581
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001582 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1583 "Reference Fully Connected: bias must have 1 dimension.");
1584
Francis Murtagh46c09d02019-05-28 08:15:28 +01001585 }
1586
1587 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001588}
1589
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001590bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1591 const armnn::TensorInfo& input1,
1592 const armnn::TensorInfo& output,
1593 armnn::Optional<std::string&> reasonIfUnsupported) const
1594{
1595 bool supported = true;
1596 std::array<DataType,7> supportedTypes =
1597 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001598 DataType::Float32,
1599 DataType::Float16,
1600 DataType::QAsymmS8,
1601 DataType::QAsymmU8,
1602 DataType::QSymmS16,
1603 DataType::Signed32
1604 };
1605
1606 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1607 "Reference GatherNd: input type not supported");
1608
1609 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1610 "Reference GatherNd: output type not supported");
1611
1612 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1613 "Reference GatherNd: indices (input1) type not supported");
1614
1615 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1616 "Reference GatherNd: input and output types not matching");
1617
1618 return supported;
1619}
1620
narpra014951d842019-01-18 16:53:53 +00001621bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1622 const armnn::TensorInfo& input1,
1623 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001624 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001625 armnn::Optional<std::string&> reasonIfUnsupported) const
1626{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001627 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001628 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001629 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001630 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001631 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001632 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001633 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001634 DataType::QSymmS16,
1635 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001636 };
1637
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001638 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001639 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1640 "Reference Gather: input type not supported");
1641
1642 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1643 "Reference Gather: output type not supported");
1644
1645 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1646 "Reference Gather: indices (input1) type not supported");
1647
1648 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1649 "Reference Gather: input and output types not matching");
1650
1651 return supported;
narpra014951d842019-01-18 16:53:53 +00001652}
1653
Derek Lamberti901ea112019-12-10 22:07:09 +00001654bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1655 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001656{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001657 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001658}
1659
Kevin May09ca49c2019-10-09 12:37:34 +01001660bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1661 const TensorInfo& output,
1662 const InstanceNormalizationDescriptor& descriptor,
1663 Optional<std::string&> reasonIfUnsupported) const
1664{
Jan Eilers8eb25602020-03-09 12:13:48 +00001665 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001666 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001667 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001668 {
1669 DataType::Float32,
1670 DataType::Float16
1671 };
1672
1673 bool supported = true;
1674
1675 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1676 "Reference Instance Normalization: input type not supported.");
1677
1678 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1679 "Reference Instance Normalization: output type not supported.");
1680
1681 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1682 "Reference Instance Normalization: input and output types mismatched.");
1683
1684 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1685 "Reference Instance Normalization: input and output shapes have different "
1686 "num total elements.");
1687
1688 return supported;
1689}
1690
arovir011c7c81b2018-10-08 11:34:28 +01001691bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1692 const TensorInfo& output,
1693 const L2NormalizationDescriptor& descriptor,
1694 Optional<std::string&> reasonIfUnsupported) const
1695{
Jan Eilers8eb25602020-03-09 12:13:48 +00001696 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001697 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001698 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001699 {
1700 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001701 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001702 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001703 DataType::QAsymmU8,
1704 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001705 };
1706
1707 bool supported = true;
1708
1709 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1710 "Reference L2normalization: input type not supported.");
1711
1712 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1713 "Reference L2normalization: output type not supported.");
1714
1715 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1716 "Reference L2normalization: input and output types mismatched.");
1717
1718 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1719 "Reference L2normalization: input and output shapes have different "
1720 "num total elements.");
1721
1722 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001723}
1724
James Conroyaba90cd2020-11-06 16:28:18 +00001725bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1726 const TensorInfo& input1,
1727 const TensorInfo& output,
1728 const LogicalBinaryDescriptor& descriptor,
1729 Optional<std::string&> reasonIfUnsupported) const
1730{
1731 IgnoreUnused(descriptor);
1732
1733 std::array<DataType, 1> supportedTypes =
1734 {
1735 DataType::Boolean
1736 };
1737
1738 bool supported = true;
1739 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1740 "Reference LogicalBinary: input 0 type not supported");
1741 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1742 "Reference LogicalBinary: input 1 type not supported");
1743
1744 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1745 "Reference LogicalBinary: input and output types do not match");
1746
1747 return supported;
1748}
1749
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001750bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1751 const TensorInfo& output,
1752 const LogSoftmaxDescriptor& descriptor,
1753 Optional<std::string&> reasonIfUnsupported) const
1754{
Jan Eilers8eb25602020-03-09 12:13:48 +00001755 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001756
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001757 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001758 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001759 DataType::Float32,
1760 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001761 };
1762
1763 bool supported = true;
1764 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1765 "Reference LogSoftmax: input type not supported");
1766
1767 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1768 "Reference LogSoftmax: output type not supported");
1769
1770 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1771 "Reference LogSoftmax: input and output types do not match");
1772
1773 return supported;
1774}
1775
arovir011c7c81b2018-10-08 11:34:28 +01001776bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1777 const TensorInfo& outputStateIn,
1778 const TensorInfo& cellStateIn,
1779 const TensorInfo& scratchBuffer,
1780 const TensorInfo& outputStateOut,
1781 const TensorInfo& cellStateOut,
1782 const TensorInfo& output,
1783 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001784 const LstmInputParamsInfo& paramsInfo,
1785 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001786{
Jan Eilers8eb25602020-03-09 12:13:48 +00001787 IgnoreUnused(descriptor);
1788 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001789
1790 bool supported = true;
1791
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001792 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001793 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001794 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001795 };
1796
Jan Eilersd01a83c2019-07-03 18:20:40 +01001797 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001798 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1799 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001800 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1801 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001802 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1803 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001804 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1805 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001806 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1807 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001808 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1809 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001810
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1812 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001813 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001814 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001815 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001816 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001817 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001818 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001819 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001820 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001821 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001822 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001823 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001824 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001825 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001826 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001827 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001828 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001829 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001830 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001831 "Reference Lstm: input and OutputGateBias types are mismatched");
1832 if (!descriptor.m_CifgEnabled)
1833 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001834 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001835 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001836 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001837 reasonIfUnsupported,
1838 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001839 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001840 "Reference Lstm: input and InputGateBias types are mismatched");
1841 if (descriptor.m_PeepholeEnabled)
1842 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001843 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001844 reasonIfUnsupported,
1845 "Reference Lstm: input and CellToInputWeights types are mismatched");
1846 }
1847 }
1848 if (descriptor.m_PeepholeEnabled)
1849 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001850 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001851 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001852 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001853 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1854 }
1855 if (descriptor.m_ProjectionEnabled)
1856 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001857 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001858 "Reference Lstm: input and mProjectionWeights types are mismatched");
1859 if (paramsInfo.m_ProjectionBias != nullptr)
1860 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001861 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001862 "Reference Lstm: input and ProjectionBias types are mismatched");
1863 }
1864 }
1865 if (descriptor.m_LayerNormEnabled)
1866 {
1867 if (!descriptor.m_CifgEnabled)
1868 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001869 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001870 reasonIfUnsupported,
1871 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1872 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001873 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001874 reasonIfUnsupported,
1875 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001876 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001877 reasonIfUnsupported,
1878 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001879 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001880 reasonIfUnsupported,
1881 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1882 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001883
1884 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001885}
1886
saoste012df12b32018-11-28 16:57:20 +00001887bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1888 const TensorInfo& input1,
1889 const TensorInfo& output,
1890 Optional<std::string&> reasonIfUnsupported) const
1891{
Sadik Armagan2999a022019-04-09 14:20:12 +01001892 bool supported = true;
1893
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001894 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001895 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001896 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001897 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001898 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001899 DataType::QSymmS16,
1900 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001901 };
1902
1903 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1904 "Reference maximum: input 0 is not a supported type.");
1905
1906 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1907 "Reference maximum: input 1 is not a supported type.");
1908
1909 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1910 "Reference maximum: output is not a supported type.");
1911
1912 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1913 "Reference maximum: input 0 and Input 1 types are mismatched");
1914
1915 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1916 "Reference maximum: input and output types are mismatched");
1917
1918 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1919 "Reference maximum: shapes are not suitable for implicit broadcast.");
1920
1921 return supported;
saoste012df12b32018-11-28 16:57:20 +00001922}
1923
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001924bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1925 const TensorInfo& output,
1926 const MeanDescriptor& descriptor,
1927 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001928{
James Conroy4d1ff582019-06-10 17:06:39 +01001929 bool supported = true;
1930 std::string meanLayerStr = "Mean";
1931 std::string outputTensorStr = "output";
1932
Sadik Armagan303980c2020-04-17 12:45:14 +01001933 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001934 {
1935 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001936 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001937 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001938 DataType::QAsymmU8,
1939 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001940 };
1941
1942 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1943 "Reference Mean: input type not supported.");
1944
1945 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1946 "Reference Mean: input and output types are mismatched");
1947
1948 if (descriptor.m_KeepDims)
1949 {
1950 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1951 reasonIfUnsupported,
1952 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1953 output.GetNumDimensions(),
1954 meanLayerStr, outputTensorStr).data());
1955 }
1956 else if (descriptor.m_Axis.empty())
1957 {
1958 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1959 reasonIfUnsupported,
1960 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1961 meanLayerStr, outputTensorStr).data());
1962 }
1963 else
1964 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001965 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001966
1967 if (outputDim > 0)
1968 {
1969 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1970 reasonIfUnsupported,
1971 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1972 meanLayerStr, outputTensorStr).data());
1973 }
1974 else
1975 {
1976 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1977 reasonIfUnsupported,
1978 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1979 meanLayerStr, outputTensorStr).data());
1980 }
1981 }
1982
1983 return supported;
narpra0132b90462018-09-13 11:07:48 +01001984}
1985
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001986bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1987 const TensorInfo &output,
1988 Optional<std::string &> reasonIfUnsupported) const
1989{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001990 bool supported = true;
1991
Sadik Armagan303980c2020-04-17 12:45:14 +01001992 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001993 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001994 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001995 DataType::Float32,
1996 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001997 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001998 DataType::QAsymmU8,
1999 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002000 DataType::Boolean
2001 };
2002
2003 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2004 "Reference MemCopy: input type not supported");
2005
2006 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2007 "Reference MemCopy: output type not supported");
2008
2009 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2010 "Reference MemCopy: input and output types are mismatched");
2011
2012 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002013}
2014
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002015bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2016 const TensorInfo& input1,
2017 const TensorInfo& output,
2018 Optional<std::string&> reasonIfUnsupported) const
2019{
Sadik Armagan2999a022019-04-09 14:20:12 +01002020 bool supported = true;
2021
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002022 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002023 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002024 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002025 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002026 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002027 DataType::QSymmS16,
2028 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002029 };
2030
2031 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2032 "Reference minimum: input 0 is not a supported type.");
2033
2034 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2035 "Reference minimum: input 1 is not a supported type.");
2036
2037 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2038 "Reference minimum: output is not a supported type.");
2039
2040 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2041 "Reference minimum: input 0 and Input 1 types are mismatched");
2042
2043 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2044 "Reference minimum: input and output types are mismatched");
2045
2046 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2047 "Reference minimum: shapes are not suitable for implicit broadcast.");
2048
2049 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002050}
2051
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002052bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2053 const TensorInfo& input1,
2054 const TensorInfo& output,
2055 Optional<std::string&> reasonIfUnsupported) const
2056{
Sadik Armagan2999a022019-04-09 14:20:12 +01002057 bool supported = true;
2058
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002059 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002060 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002061 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002062 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002063 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002064 DataType::QSymmS16,
2065 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002066 };
2067
2068 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2069 "Reference multiplication: input 0 is not a supported type.");
2070
2071 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2072 "Reference multiplication: input 1 is not a supported type.");
2073
2074 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2075 "Reference multiplication: output is not a supported type.");
2076
2077 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2078 "Reference multiplication: input 0 and Input 1 types are mismatched");
2079
2080 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2081 "Reference multiplication: input and output types are mismatched");
2082
2083 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2084 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2085
2086 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002087}
2088
2089bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2090 const TensorInfo& output,
2091 const NormalizationDescriptor& descriptor,
2092 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002093{
Jan Eilers8eb25602020-03-09 12:13:48 +00002094 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002095
2096 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002097 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002098 {
2099 DataType::Float16,
2100 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002101 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002102 DataType::QAsymmU8,
2103 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002104 };
2105
2106 bool supported = true;
2107
2108 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2109 "Reference normalization: input type not supported.");
2110
2111 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2112 "Reference normalization: output type not supported.");
2113
2114 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2115 "Reference normalization: input and output shapes have different "
2116 "num total elements.");
2117
2118 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002119}
2120
// An Output layer is a graph terminator with no computation, so the reference
// backend accepts it unconditionally regardless of tensor type or shape.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
2126
2127bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2128 const TensorInfo& output,
2129 const PadDescriptor& descriptor,
2130 Optional<std::string&> reasonIfUnsupported) const
2131{
Jan Eilers8eb25602020-03-09 12:13:48 +00002132 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002133 bool supported = true;
2134
2135 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002136 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002137 {
2138 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002139 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002140 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002141 DataType::QAsymmU8,
2142 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002143 };
2144
2145 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2146 "Reference pad: input is not a supported type.");
2147
2148 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2149 "Reference pad: output is not a supported type.");
2150
2151 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2152 "Reference pad: input and output types are mismatched.");
2153
2154 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002155}
2156
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002157bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2158 const TensorInfo& output,
2159 const PermuteDescriptor& descriptor,
2160 Optional<std::string&> reasonIfUnsupported) const
2161{
Jan Eilers8eb25602020-03-09 12:13:48 +00002162 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002163 bool supported = true;
2164
2165 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002166 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002167 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002168 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002169 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002170 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002171 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002172 DataType::QAsymmU8,
2173 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002174 };
2175
2176 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2177 "Reference permute: input is not a supported type.");
2178
2179 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2180 "Reference permute: output is not a supported type.");
2181
2182 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2183 "Reference permute: input and output types are mismatched.");
2184
2185 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002186}
2187
2188bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2189 const TensorInfo& output,
2190 const Pooling2dDescriptor& descriptor,
2191 Optional<std::string&> reasonIfUnsupported) const
2192{
Jan Eilers8eb25602020-03-09 12:13:48 +00002193 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002194 bool supported = true;
2195
2196 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002197 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002198 {
2199 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002200 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002201 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002202 DataType::QAsymmU8,
2203 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002204 };
2205
2206 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2207 "Reference poolind2d: input is not a supported type.");
2208
2209 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2210 "Reference poolind2d: output is not a supported type.");
2211
2212 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2213 "Reference poolind2d: input and output types are mismatched.");
2214
2215 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002216}
2217
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002218bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2219 const TensorInfo& output,
2220 const Pooling3dDescriptor& descriptor,
2221 Optional<std::string&> reasonIfUnsupported) const
2222{
2223 IgnoreUnused(descriptor);
2224 bool supported = true;
2225
2226 // Define supported output and inputs types.
2227 std::array<DataType,6> supportedTypes =
2228 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002229 DataType::Float32,
2230 DataType::Float16,
2231 DataType::QAsymmS8,
2232 DataType::QAsymmU8,
2233 DataType::QSymmS16
2234 };
2235
2236 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2237 "Reference poolind3d: input is not a supported type.");
2238
2239 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2240 "Reference poolind3d: output is not a supported type.");
2241
2242 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2243 "Reference poolind3d: input and output types are mismatched.");
2244
2245 return supported;
2246}
2247
2248
James Conroy4f1f8992020-04-29 20:01:10 +01002249bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2250 const TensorInfo& previousOutputIn,
2251 const TensorInfo& previousCellStateIn,
2252 const TensorInfo& outputStateOut,
2253 const TensorInfo& cellStateOut,
2254 const TensorInfo& output,
2255 const QLstmDescriptor& descriptor,
2256 const LstmInputParamsInfo& paramsInfo,
2257 Optional<std::string&> reasonIfUnsupported) const
2258{
2259 IgnoreUnused(input);
2260 IgnoreUnused(previousOutputIn);
2261 IgnoreUnused(previousCellStateIn);
2262 IgnoreUnused(outputStateOut);
2263 IgnoreUnused(cellStateOut);
2264 IgnoreUnused(output);
2265 IgnoreUnused(descriptor);
2266 IgnoreUnused(paramsInfo);
2267
2268 IgnoreUnused(reasonIfUnsupported);
2269
2270 return true;
2271}
2272
Derek Lamberti5f400d62019-03-25 15:41:58 +00002273bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2274 const TensorInfo& output,
2275 Optional<std::string&> reasonIfUnsupported) const
2276{
2277 bool supported = true;
2278
Finn Williamsfd271062019-12-04 14:27:27 +00002279 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002280 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002281 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002282 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002283 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002284 DataType::QAsymmU8,
2285 DataType::QSymmS8,
2286 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002287 };
2288
2289 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2290 "Reference quantize: input type not supported.");
2291
2292 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002293 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002294 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002295 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002296 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002297 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002298 };
2299 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2300 "Reference quantize: output type not supported.");
2301
2302 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2303 "Reference quantize: input and output shapes have different num total elements.");
2304
2305 return supported;
2306}
2307
Finn Williams2605b232020-06-10 15:53:46 +01002308bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2309 const TensorInfo& output,
2310 Optional<std::string&> reasonIfUnsupported) const
2311{
2312 IgnoreUnused(input);
2313 // Define supported output types.
2314 std::array<DataType,1> supportedOutputTypes =
2315 {
2316 DataType::Signed32,
2317 };
2318
2319 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2320 "Reference rank: input type not supported.");
2321}
2322
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002323bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2324 const TensorInfo& output,
2325 const ReduceDescriptor& descriptor,
2326 Optional<std::string&> reasonIfUnsupported) const
2327{
2328 IgnoreUnused(descriptor);
2329 bool supported = true;
2330 std::array<DataType,7> supportedTypes =
2331 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002332 DataType::Float32,
2333 DataType::Float16,
2334 DataType::QAsymmS8,
2335 DataType::QAsymmU8,
2336 DataType::QSymmS16,
2337 DataType::Signed32
2338 };
2339
2340 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2341 "Reference Reduce: input type not supported");
2342
2343 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2344 "Reference Reduce: output type not supported");
2345
2346 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2347 "Reference Reduce: input and output types not matching");
2348
2349 return supported;
2350}
2351
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002352bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002353 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002354 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002355 Optional<std::string&> reasonIfUnsupported) const
2356{
Jan Eilers8eb25602020-03-09 12:13:48 +00002357 IgnoreUnused(output);
2358 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002359 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002360 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002361 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002362 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002363 DataType::Float32,
2364 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002365 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002366 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002367 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002368 DataType::QSymmS16,
2369 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002370 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002371
Nina Drozd2f2778f2019-05-27 10:37:05 +01002372 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2373 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002374}
2375
Teresa Charlin970f43b2019-07-01 13:51:07 +01002376bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2377 const TensorInfo& output,
2378 const ResizeDescriptor& descriptor,
2379 Optional<std::string&> reasonIfUnsupported) const
2380{
Jan Eilers8eb25602020-03-09 12:13:48 +00002381 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002382 bool supported = true;
Teresa Charlince655882023-11-21 15:44:13 +00002383 std::array<DataType,7> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002384 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002385 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002386 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002387 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002388 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002389 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00002390 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002391 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002392 };
2393
2394 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2395 "Reference Resize: input type not supported");
2396
2397 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2398 "Reference Resize: output type not supported");
2399
2400 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2401 "Reference Resize: input and output types not matching");
2402
2403 return supported;
2404}
2405
Tracy Narinebb8d7592023-07-13 16:50:54 +01002406bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
2407 const TensorInfo& input1,
Tianle Cheng988354d2023-06-28 13:20:47 +01002408 const TensorInfo& output,
Tianle Cheng988354d2023-06-28 13:20:47 +01002409 Optional<std::string&> reasonIfUnsupported) const
2410{
Tianle Cheng988354d2023-06-28 13:20:47 +01002411 bool supported = true;
2412 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002413 std::array<DataType,8> supportedTypes =
Tianle Cheng988354d2023-06-28 13:20:47 +01002414 {
2415 DataType::BFloat16,
2416 DataType::Float32,
2417 DataType::Float16,
2418 DataType::QAsymmS8,
2419 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002420 DataType::QSymmS8,
2421 DataType::QSymmS16,
2422 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01002423 };
2424
Tracy Narinebb8d7592023-07-13 16:50:54 +01002425 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2426 "Reference ReverseV2: input0 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002427
2428 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2429 "Reference ReverseV2: output type not supported");
2430
Tracy Narinebb8d7592023-07-13 16:50:54 +01002431 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2432 "Reference ReverseV2: input0 and output types not matching");
2433
2434 std::array<DataType,6> input2SupportedTypes =
2435 {
2436 DataType::Signed32
2437 };
2438
2439 supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
2440 "Reference ReverseV2: input1 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002441
2442 return supported;
2443}
2444
Keith Davis3ae3f972021-05-21 16:33:48 +01002445bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2446 const TensorInfo& output,
2447 Optional<std::string&> reasonIfUnsupported) const
2448{
2449 IgnoreUnused(input);
2450 bool supported = true;
2451
2452 std::array<DataType, 1> supportedTypes =
2453 {
2454 DataType::Signed32
2455 };
2456
2457 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2458 "Reference Shape: output type not supported");
2459
2460 return supported;
2461}
2462
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002463bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2464 const TensorInfo& output,
2465 const SliceDescriptor& descriptor,
2466 Optional<std::string&> reasonIfUnsupported) const
2467{
Jan Eilers8eb25602020-03-09 12:13:48 +00002468 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002469 bool supported = true;
2470
Sadik Armagan303980c2020-04-17 12:45:14 +01002471 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002472 {
2473 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002474 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002475 DataType::QAsymmU8,
Ryan OShea980446b2023-06-08 16:23:28 +01002476 DataType::QSymmS16,
2477 DataType::Signed32
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002478 };
2479
2480 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2481 "Reference Slice: input type not supported");
2482
2483 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2484 "Reference Slice: output type not supported");
2485
2486 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2487 "Reference Slice: input and output types are mismatched");
2488
2489 return supported;
2490}
2491
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002492bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2493 const TensorInfo& output,
2494 const SoftmaxDescriptor& descriptor,
2495 Optional<std::string&> reasonIfUnsupported) const
2496{
Jan Eilers8eb25602020-03-09 12:13:48 +00002497 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002498 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002499 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002500 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002501 DataType::Float32,
2502 DataType::Float16,
2503 DataType::QSymmS8,
2504 DataType::QAsymmS8,
2505 DataType::QAsymmU8,
2506 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002507 };
2508
2509 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002510 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002511
2512 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002513 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002514
2515 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002516 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002517
2518 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002519}
2520
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002521bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2522 const TensorInfo& output,
2523 const SpaceToBatchNdDescriptor& descriptor,
2524 Optional<std::string&> reasonIfUnsupported) const
2525{
Jan Eilers8eb25602020-03-09 12:13:48 +00002526 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002527 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002528 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002529 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002530 DataType::Float32,
2531 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002532 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002533 DataType::QAsymmU8,
2534 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002535 };
2536
2537 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2538 "Reference SpaceToBatchNd: input type not supported");
2539
2540 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2541 "Reference SpaceToBatchNd: output type not supported");
2542
2543 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2544 "Reference SpaceToBatchNd: input and output types are mismatched");
2545
2546 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002547}
2548
Keith Davisa57eccb2019-06-14 17:33:22 +01002549bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002550 const TensorInfo& output,
2551 const SpaceToDepthDescriptor& descriptor,
2552 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002553{
2554
Jan Eilers8eb25602020-03-09 12:13:48 +00002555 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002556 bool supported = true;
2557
Sadik Armagan303980c2020-04-17 12:45:14 +01002558 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002559 {
2560 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002561 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002562 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002563 DataType::QAsymmU8,
2564 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002565 };
2566
2567 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2568 "Reference SpaceToDepth: input type not supported");
2569
2570 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2571 "Reference SpaceToDepth: output type not supported");
2572
2573 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2574 "Reference SpaceToDepth: input and output types are mismatched");
2575
2576 return supported;
2577}
2578
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002579bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002580 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2581 const ViewsDescriptor& descriptor,
2582 Optional<std::string&> reasonIfUnsupported) const
2583{
Jan Eilers8eb25602020-03-09 12:13:48 +00002584 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002585 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002586 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002587 {
2588 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002589 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002590 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002591 DataType::QAsymmU8,
2592 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002593 };
2594
2595 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2596 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002597 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002598 {
2599 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2600 "Reference splitter: input type not supported");
2601
2602 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2603 "Reference splitter: input and output types mismatched.");
2604 }
2605
2606 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002607}
2608
Matthew Jackson81e601c2019-07-11 12:07:09 +01002609bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2610 const TensorInfo& output,
2611 const StackDescriptor& descriptor,
2612 Optional<std::string&> reasonIfUnsupported) const
2613{
Jan Eilers8eb25602020-03-09 12:13:48 +00002614 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002615
2616 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002617 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002618 {
2619 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002620 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002621 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002622 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002623 DataType::QSymmS16,
2624 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002625 };
2626
2627 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2628 "Reference stack: output type not supported");
2629 for (const TensorInfo* input : inputs)
2630 {
Matthew Jackson81e601c2019-07-11 12:07:09 +01002631 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2632 "Reference stack: input type not supported");
2633
2634 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2635 "Reference stack: input and output types mismatched.");
2636 }
2637
2638 return supported;
2639}
2640
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002641bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2642 const TensorInfo& output,
2643 const StridedSliceDescriptor& descriptor,
2644 Optional<std::string&> reasonIfUnsupported) const
2645{
Jan Eilers8eb25602020-03-09 12:13:48 +00002646 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002647 bool supported = true;
2648
Sadik Armagan303980c2020-04-17 12:45:14 +01002649 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002650 {
2651 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002652 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002653 DataType::QAsymmU8,
2654 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002655 };
2656
2657 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2658 "Reference StridedSlice: input type not supported");
2659
2660 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2661 "Reference StridedSlice: output type not supported");
2662
2663 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2664 "Reference StridedSlice: input and output types are mismatched");
2665
2666 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002667}
2668
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002669bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2670 const TensorInfo& input1,
2671 const TensorInfo& output,
2672 Optional<std::string&> reasonIfUnsupported) const
2673{
Sadik Armagan2999a022019-04-09 14:20:12 +01002674 bool supported = true;
2675
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002676 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002677 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002678 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002679 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002680 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002681 DataType::QSymmS16,
2682 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002683 };
2684
2685 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2686 "Reference subtraction: input 0 is not a supported type.");
2687
2688 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2689 "Reference subtraction: input 1 is not a supported type.");
2690
2691 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2692 "Reference subtraction: output is not a supported type.");
2693
2694 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2695 "Reference subtraction: input 0 and Input 1 types are mismatched");
2696
2697 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2698 "Reference subtraction: input and output types are mismatched");
2699
2700 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2701 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2702
2703 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002704}
2705
Matteo Martincighab9e5252019-06-13 17:27:46 +01002706bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2707 const TensorInfo& alpha,
2708 const TensorInfo& output,
2709 Optional<std::string&> reasonIfUnsupported) const
2710{
2711 bool supported = true;
2712
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002713 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002714 {
2715 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002716 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002717 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002718 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002719 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002720 };
2721
2722 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2723 "PReLU: input is not a supported type.");
2724
2725 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2726 "PReLU: alpha is not a supported type.");
2727
2728 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2729 "PReLU: output is not a supported type.");
2730
2731 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2732 "PReLU: input, alpha and output types are mismatched");
2733
2734 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2735 "PReLU: shapes are not suitable for implicit broadcast");
2736
2737 return supported;
2738}
2739
Teresa Charlin79a06a52023-07-13 17:16:45 +01002740bool RefLayerSupport::IsTileSupported(const TensorInfo& input,
2741 const TensorInfo& output,
2742 const TileDescriptor& descriptor,
2743 Optional<std::string&> reasonIfUnsupported) const
2744{
2745 IgnoreUnused(descriptor);
2746
2747 bool supported = true;
2748
2749 std::array<DataType, 7> supportedTypes
2750 {
2751 DataType::Float32,
2752 DataType::Float16,
2753 DataType::QAsymmS8,
2754 DataType::QAsymmU8,
2755 DataType::QSymmS8,
2756 DataType::QSymmS16,
2757 DataType::Signed32
2758 };
2759
2760 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2761 "Tile: input type not supported.");
2762
2763 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2764 "Tile: output type not supported");
2765
2766 return supported;
2767}
2768
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002769bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2770 const TensorInfo& output,
2771 const TransposeConvolution2dDescriptor& descriptor,
2772 const TensorInfo& weights,
2773 const Optional<TensorInfo>& biases,
2774 Optional<std::string&> reasonIfUnsupported) const
2775{
Jan Eilers8eb25602020-03-09 12:13:48 +00002776 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002777 bool supported = true;
2778
Sadik Armagan303980c2020-04-17 12:45:14 +01002779 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002780 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002781 DataType::Float32,
2782 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002783 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002784 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002785 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002786 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002787 };
2788
2789 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2790 "Reference TransposeConvolution2d: input is not a supported type.");
2791
2792 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2793 "Reference TransposeConvolution2d: output is not a supported type.");
2794
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002795 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2796 "Reference TransposeConvolution2d: input and output types mismatched.");
2797
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002798
2799 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002800 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002801 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002802 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002803 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002804 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002805 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002806 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002807 };
2808
2809 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2810 "Reference TransposeConvolution2d: weights type not supported for "
2811 "quantized input.");
2812 }
2813 else
2814 {
2815 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2816 "Reference TransposeConvolution2d: weights is not a supported type.");
2817
2818 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2819 "Reference TransposeConvolution2d: input and weights types mismatched.");
2820 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002821
2822 if (biases.has_value())
2823 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002824 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002825 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002826 DataType::Float32,
2827 DataType::Float16,
2828 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002829 };
2830 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2831 "Reference TransposeConvolution2d: biases is not a supported type.");
2832 }
2833
2834 return supported;
2835}
2836
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002837bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2838 const TensorInfo& output,
2839 const TransposeDescriptor& descriptor,
2840 Optional<std::string&> reasonIfUnsupported) const
2841{
Jan Eilers8eb25602020-03-09 12:13:48 +00002842 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002843 bool supported = true;
2844
2845 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002846 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002847 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002848 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002849 DataType::Float32,
2850 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002851 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002852 DataType::QAsymmU8,
2853 DataType::QSymmS16
2854 };
2855
2856 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2857 "Reference transpose: input is not a supported type.");
2858
2859 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2860 "Reference transpose: output is not a supported type.");
2861
2862 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2863 "Reference transpose: input and output types are mismatched.");
2864
2865 return supported;
2866}
2867
// Checks whether the reference backend supports a UnidirectionalSequenceLstm
// layer: validates the data types of the sequence input/output and of every
// weight/bias tensor that is present according to the descriptor's optional
// features (CIFG, peephole, projection, layer normalization). All rules are
// evaluated so reasonIfUnsupported (if provided) collects every failure.
bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
    const TensorInfo& input,
    const TensorInfo& outputStateIn,
    const TensorInfo& cellStateIn,
    const TensorInfo& outputStateOut,
    const TensorInfo& cellStateOut,
    const TensorInfo& output,
    const UnidirectionalSequenceLstmDescriptor& descriptor,
    const LstmInputParamsInfo& paramsInfo,
    Optional<std::string&> reasonIfUnsupported) const
{
    // IgnoreUnused is a no-op marker; note that descriptor and paramsInfo ARE
    // used below, so those two calls are redundant. The four state tensors,
    // however, are genuinely not type-checked by this function.
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);
    IgnoreUnused(outputStateIn);
    IgnoreUnused(cellStateIn);
    IgnoreUnused(outputStateOut);
    IgnoreUnused(cellStateOut);
    bool supported = true;

    // Types accepted for the layer's main input and output tensors.
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8
    };

    // Types accepted for all weight tensors (currently identical to
    // supportedTypes but kept separate so the sets can diverge).
    std::array<DataType, 2> supportedWeightTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8
    };

    // Types accepted for the gate bias tensors.
    std::array<DataType, 3> supportedBiasTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::Signed32
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: output is not a supported type.");

    // check layer parameters
    // Mandatory weights: the forget/cell/output gates always exist.
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
                                  "is not a supported type.");

    // Mandatory biases for the always-present gates.
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
    // CIFG disabled means a separate input gate exists, so its weights and bias
    // must be validated too.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: InputToInputWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
        // CellToInput peephole weights only exist when CIFG is off AND peephole is on.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
                                          reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: CellToInputWeights "
                                          "is not a supported type.");
        }
    }
    // Peephole connections from the cell state to the forget and output gates.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
                                      "is not a supported type.");
    }
    // Optional output projection.
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: ProjectionWeights "
                                      "is not a supported type.");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            // NOTE(review): the projection bias is required to match the INPUT
            // type exactly, unlike the gate biases which use supportedBiasTypes
            // — presumably intentional, but worth confirming upstream.
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
                                          "are mismatched");
        }
    }
    // Optional layer normalization weights, one set per active gate.
    if (descriptor.m_LayerNormEnabled)
    {
        // Input-gate layer norm weights only exist when CIFG is off.
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
                                          reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
                                          "is not a supported type.");
        }
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
                                      "is not a supported type.");
    }

    return supported;
}
3012
arovir011c7c81b2018-10-08 11:34:28 +01003013} // namespace armnn