blob: 3bc4affb28d8f0620146f1edafe09facb62c8c4c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
// Legacy per-data-type dispatch helper for the reference backend.
// Forwards to IsSupportedForDataTypeGeneric with only the Float32 and the
// uint8 callback slots populated; every other data-type slot is wired to
// FalseFunc, i.e. reported as unsupported.
// NOTE(review): the exact slot-to-DataType mapping is defined by
// IsSupportedForDataTypeGeneric in LayerSupportCommon.hpp — confirm there
// which types the three FalseFunc slots correspond to.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
/// Builds the standard diagnostic emitted when a tensor has the wrong number
/// of dimensions for a reference layer.
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name, e.g. "batchToSpaceNd".
/// @param tensorName Role of the offending tensor, e.g. "input".
/// @return The formatted error message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // const& instead of mutable refs: the strings are only read, and this
    // additionally permits temporaries/literals at the call site.
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
/// Dispatches a layer-support query for the reference backend to the
/// type-specific IsXxxSupported() overload for the given layer type.
///
/// @param type       Layer type being queried.
/// @param infos      Tensor infos for the layer; per-case layout follows the
///                   convention visible below (inputs first, output(s) last;
///                   conv-style layers expect {input, output, weights, biases}).
///                   NOTE(review): infos.size() is only validated for a few
///                   layer types — the remaining cases index infos[] unchecked,
///                   so a short vector would be undefined behaviour here.
/// @param descriptor Base descriptor, downcast per layer type.
/// @param lstmParamsInfo Required by the Lstm/QLstm/UnidirectionalSequenceLstm
///                   cases (.value() is called unconditionally there —
///                   presumably throws if unset; verify Optional semantics).
/// @param quantizedLstmInputParamsInfo Required by the QuantizedLstm case.
/// @param reasonIfUnsupported Optional out-string describing why a
///                   configuration is rejected.
/// @return true if the reference backend supports the layer configuration.
bool RefLayerSupport::IsLayerSupported(const LayerType& type,
                                       const std::vector<TensorInfo>& infos,
                                       const BaseDescriptor& descriptor,
                                       const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                       const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    switch (type)
    {
        case LayerType::Activation:
            return IsActivationSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Addition:
            return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ArgMinMax:
            return IsArgMinMaxSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::BatchNormalization:
            return IsBatchNormalizationSupported(infos[0],
                                                 infos[1],
                                                 infos[2],
                                                 infos[3],
                                                 infos[4],
                                                 infos[5],
                                                 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
                                                     (&descriptor)),
                                                 reasonIfUnsupported);
        case LayerType::BatchToSpaceNd:
            return IsBatchToSpaceNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Comparison:
            return IsComparisonSupported(infos[0],
                                         infos[1],
                                         infos[2],
                                         *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Concat:
        {
            // All entries except the last are inputs; the last is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < (infos.size() - 1); i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsConcatSupported(inputInfos,
                                     infos[infos.size() - 1],
                                     *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        }
        case LayerType::Constant:
            return IsConstantSupported(infos[0], reasonIfUnsupported);
        case LayerType::ConvertBf16ToFp32:
            return IsConvertBf16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp16ToFp32:
            return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToBf16:
            return IsConvertFp32ToBf16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToFp16:
            return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Convolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo marks the bias as "not provided".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::DepthToSpace:
            return IsDepthToSpaceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::DepthwiseConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
            // Bias is optional, signalled by a default-constructed TensorInfo.
            if (infos[3] == TensorInfo())
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       EmptyOptional(),
                                                       reasonIfUnsupported);
            }
            else
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       infos[3],
                                                       reasonIfUnsupported);
            }
        }
        case LayerType::Dequantize:
            return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Division:
            return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ElementwiseUnary:
            return IsElementwiseUnarySupported(infos[0],
                                               infos[1],
                                               *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::Fill:
            return IsFillSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Floor:
            return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::FullyConnected:
            return IsFullyConnectedSupported(infos[0],
                                             infos[1],
                                             infos[2],
                                             infos[3],
                                             *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Gather:
            return IsGatherSupported(infos[0],
                                     infos[1],
                                     infos[2],
                                     *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::GatherNd:
            return IsGatherNdSupported(infos[0],
                                       infos[1],
                                       infos[2],
                                       reasonIfUnsupported);
        case LayerType::Input:
            return IsInputSupported(infos[0], reasonIfUnsupported);
        case LayerType::InstanceNormalization:
            return IsInstanceNormalizationSupported(infos[0],
                                                    infos[1],
                                                    *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
                                                        (&descriptor)),
                                                    reasonIfUnsupported);
        case LayerType::L2Normalization:
            return IsL2NormalizationSupported(infos[0],
                                              infos[1],
                                              *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
                                              reasonIfUnsupported);
        case LayerType::LogicalBinary:
            return IsLogicalBinarySupported(infos[0],
                                            infos[1],
                                            infos[2],
                                            *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::LogSoftmax:
            return IsLogSoftmaxSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Lstm:
            return IsLstmSupported(infos[0],
                                   infos[1],
                                   infos[2],
                                   infos[3],
                                   infos[4],
                                   infos[5],
                                   infos[6],
                                   *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
                                   lstmParamsInfo.value(),
                                   reasonIfUnsupported);
        case LayerType::QLstm:
            return IsQLstmSupported(infos[0],
                                    infos[1],
                                    infos[2],
                                    infos[3],
                                    infos[4],
                                    infos[5],
                                    *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
                                    lstmParamsInfo.value(),
                                    reasonIfUnsupported);
        case LayerType::Maximum:
            return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Mean:
            return IsMeanSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Minimum:
            return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Multiplication:
            return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Normalization:
            return IsNormalizationSupported(infos[0],
                                            infos[1],
                                            *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::Output:
            return IsOutputSupported(infos[0], reasonIfUnsupported);
        case LayerType::Pad:
            return IsPadSupported(infos[0],
                                  infos[1],
                                  *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
                                  reasonIfUnsupported);
        case LayerType::Permute:
            return IsPermuteSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Pooling2d:
            return IsPooling2dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Prelu:
            return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Quantize:
            return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Reshape:
            return IsReshapeSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Resize:
            return IsResizeSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Reduce:
            return IsReduceSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Slice:
            return IsSliceSupported(infos[0],
                                    infos[1],
                                    *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        case LayerType::Softmax:
            return IsSoftmaxSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::SpaceToBatchNd:
            return IsSpaceToBatchNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::SpaceToDepth:
            return IsSpaceToDepthSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Splitter:
        {
            // infos[0] is the input; everything after it is an output view.
            std::vector<TensorInfo> outputInfos;
            for (uint32_t i = 1; i < infos.size(); i++)
            {
                outputInfos.push_back(infos[i]);
            }
            return IsSplitterSupported(infos[0],
                                       {outputInfos.begin(), outputInfos.end()},
                                       *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
                                       reasonIfUnsupported);
        }
        case LayerType::Stack:
        {
            // All entries except the last are inputs; the last is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < infos.size() - 1; i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsStackSupported(inputInfos,
                                    infos[infos.size() - 1],
                                    *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        }
        case LayerType::StridedSlice:
            return IsStridedSliceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Subtraction:
            return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Transpose:
            return IsTransposeSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::TransposeConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
            // Bias is optional, signalled by a default-constructed TensorInfo.
            if (infos[3] == TensorInfo())
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         EmptyOptional(),
                                                         reasonIfUnsupported);
            }
            else
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         infos[3],
                                                         reasonIfUnsupported);
            }
        }
        case LayerType::Cast:
            return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ChannelShuffle:
            return IsChannelShuffleSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Convolution3d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
            // Bias is optional, signalled by a default-constructed TensorInfo.
            if (infos[3] == TensorInfo())
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::Debug:
            return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::DetectionPostProcess:
            return IsDetectionPostProcessSupported(infos[0],
                                                   infos[1],
                                                   infos[2],
                                                   infos[3],
                                                   infos[4],
                                                   infos[5],
                                                   infos[6],
                                                   *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
                                                       (&descriptor)),
                                                   reasonIfUnsupported);
        case LayerType::FakeQuantization:
            return IsFakeQuantizationSupported(infos[0],
                                               *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::MemCopy:
            return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Rank:
            return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Shape:
            return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::UnidirectionalSequenceLstm:
        {
            if (infos.size() != 6)
            {
                throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
                                               "should be of format: {input, outputStateIn, cellStateIn, "
                                               "hiddenStateOutputVal, cellStateOutputVal, output}");
            }
            auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));

            // A default-constructed TensorInfo marks an optional output as "not provided".
            bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
            // NOTE(review): despite its name, 'isCellStateOutput' is true when the
            // cell-state output is ABSENT (infos[5] is the empty TensorInfo) —
            // consider renaming to isCellStateOutputOptional.
            bool isCellStateOutput = (infos[5] == TensorInfo());
            if (isHiddenStateOutputOptional && isCellStateOutput)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             EmptyOptional(),
                                                             EmptyOptional(),
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else if (isHiddenStateOutputOptional)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             EmptyOptional(),
                                                             infos[5],
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else if (isCellStateOutput)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             infos[4],
                                                             EmptyOptional(),
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             infos[4],
                                                             infos[5],
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
        }
        case LayerType::Pooling3d:
            return IsPooling3dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Map:
            return true;
        case LayerType::Unmap:
            return true;
        case LayerType::MemImport:
            return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Merge:
            return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::QuantizedLstm:
            return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
                                                              infos[1],
                                                              infos[2],
                                                              infos[3],
                                                              infos[4],
                                                              quantizedLstmInputParamsInfo.value(),
                                                              reasonIfUnsupported);
        default:
            // Layers not supported in the reference backend by default:
            // precompiled, standin, switch
            // (the previous comment said "neon", a copy-paste from the Neon backend)
            return false;
    }
}
547
/// Checks whether the reference backend supports an Activation layer with the
/// given input/output tensors and activation function.
/// @param input  Input tensor info.
/// @param output Output tensor info (must match input in type and rank).
/// @param descriptor Activation descriptor; only m_Function is inspected here.
/// @param reasonIfUnsupported Optional out-string describing any failure;
///        every failed rule appends its message.
/// @return true if all rules pass.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Ad-hoc rule: sets m_Res (inherited from Rule) to whether the requested
    // activation function is implemented by the reference backend.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
615
616bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
617 const TensorInfo& input1,
618 const TensorInfo& output,
619 Optional<std::string&> reasonIfUnsupported) const
620{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000621 bool supported = true;
622
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100623 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000624 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000625 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100626 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000627 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000628 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100629 DataType::QSymmS16,
630 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000631 };
632
633 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
634 "Reference addition: input 0 is not a supported type.");
635
636 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
637 "Reference addition: input 1 is not a supported type.");
638
639 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
640 "Reference addition: output is not a supported type.");
641
642 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
643 "Reference addition: input 0 and Input 1 types are mismatched");
644
645 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
646 "Reference addition: input and output types are mismatched");
647
648 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
649 "Reference addition: shapes are not suitable for implicit broadcast.");
650
651 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100652}
653
Nikhil Raj68c2c902019-09-19 11:21:11 +0100654bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
655 const armnn::ArgMinMaxDescriptor &descriptor,
656 armnn::Optional<std::string &> reasonIfUnsupported) const
657{
Jan Eilers8eb25602020-03-09 12:13:48 +0000658 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100659
Mike Kelly1f140f72021-04-06 12:25:55 +0100660 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100661 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000662 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100663 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100664 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100665 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000666 DataType::QAsymmU8,
667 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100668 DataType::Signed32,
669 DataType::Signed64
670 };
671
672 std::array<DataType,2> supportedOutputTypes = {
673 DataType::Signed32,
674 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100675 };
676
677 bool supported = true;
678
Mike Kelly1f140f72021-04-06 12:25:55 +0100679 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100680 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100681 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100682 "Reference ArgMinMax: output type not supported");
683
684 return supported;
685}
686
/// Reference-backend support check for BatchNormalization.
/// All six tensors (input, output, mean, variance, beta, gamma) must use a
/// supported data type, and input/output types must match.
/// @param reasonIfUnsupported Optional out-string; each failed rule appends
///        its message.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    // Epsilon etc. in the descriptor do not influence support.
    IgnoreUnused(descriptor);

    std::array<DataType, 6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    // NOTE(review): the parameter tensors are only checked against the type
    // list individually — they are not required to match the input type here.
    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
733
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000734bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
735 const TensorInfo& output,
736 const BatchToSpaceNdDescriptor& descriptor,
737 Optional<std::string&> reasonIfUnsupported) const
738{
Jan Eilers8eb25602020-03-09 12:13:48 +0000739 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100740
741 bool supported = true;
742
743 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
744 std::string inputTensorStr = "input";
745 std::string outputTensorStr = "output";
746
747 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100748 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100749 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000750 DataType::BFloat16,
751 DataType::Float32,
752 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100753 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000754 DataType::QAsymmU8,
755 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100756 };
757
758 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
759 "Reference BatchToSpaceNd: input type not supported.");
760
761 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
762 "Reference BatchToSpaceNd: output type not supported.");
763
764 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
765 "Reference BatchToSpaceNd: input and output types mismatched.");
766
767 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
768 reasonIfUnsupported,
769 CreateIncorrectDimensionsErrorMsg(4,
770 output.GetNumDimensions(),
771 batchToSpaceNdLayerStr,
772 outputTensorStr).data());
773
774 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
775 reasonIfUnsupported,
776 CreateIncorrectDimensionsErrorMsg(4,
777 input.GetNumDimensions(),
778 batchToSpaceNdLayerStr,
779 inputTensorStr).data());
780
781 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000782}
783
mathad01b392e982021-04-07 12:07:30 +0100784bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
785 const TensorInfo& output,
786 Optional<std::string&> reasonIfUnsupported) const
787{
788 std::array<DataType, 9> supportedInputTypes =
789 {
790 DataType::BFloat16,
791 DataType::Float32,
792 DataType::Float16,
793 DataType::QSymmS8,
794 DataType::QAsymmS8,
795 DataType::QAsymmU8,
796 DataType::QSymmS16,
797 DataType::Signed32
798 };
799
800 bool supported = true;
801 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
802 "Reference cast: input is not a supported type");
803
804
805 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
806 "Reference cast: output is not a supported type");
807
808 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
809 "Reference cast: input and output shapes have different number of total elements");
810
811 return supported;
812}
813
Simon Obute51f67772021-09-03 15:50:13 +0100814bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
815 const TensorInfo& output,
816 const ChannelShuffleDescriptor& descriptor,
817 Optional<std::string&> reasonIfUnsupported) const
818{
819 IgnoreUnused(descriptor);
820 bool supported = true;
821
822 // Define supported output and inputs types.
823 std::array<DataType, 7> supportedTypes =
824 {
825 DataType::BFloat16,
826 DataType::Float32,
827 DataType::Float16,
828 DataType::QAsymmS8,
829 DataType::QAsymmU8,
830 DataType::QSymmS8,
831 DataType::QSymmS16
832 };
833
834 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
835 "Reference ChannelShuffle: input is not a supported type.");
836
837 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
838 "Reference ChannelShuffle: output is not a supported type.");
839
840 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
841 "Reference ChannelShuffle: input and output types are mismatched.");
842
843 return supported;
844}
845
846
/// Reference-backend support check for Comparison (Equal, Greater, etc.).
/// The output of a comparison is always Boolean; input1's type is only
/// validated indirectly, via the equality check against input0.
bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    // The comparison operation in the descriptor does not affect support.
    IgnoreUnused(descriptor);
    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::Boolean,
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}
878
Jim Flynn906f9462019-05-10 13:55:21 +0100879bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
880 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000881 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100882 Optional<std::string&> reasonIfUnsupported) const
883{
Jan Eilers8eb25602020-03-09 12:13:48 +0000884 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100885
886 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000887 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100888 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000889 DataType::BFloat16,
890 DataType::Float32,
891 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000892 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100893 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000894 DataType::QSymmS16,
895 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100896 };
897
898 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
899 "Reference concatenation: output type not supported");
900 for (const TensorInfo* input : inputs)
901 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100902 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100903 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
904 "Reference concatenation: input type not supported");
905
906 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
907 "Reference concatenation: input and output types mismatched.");
908 }
909
910 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100911}
912
arovir011c7c81b2018-10-08 11:34:28 +0100913bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
914 Optional<std::string&> reasonIfUnsupported) const
915{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100916 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100917 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000918 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100919 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100920 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000921 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100922 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000923 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100924 DataType::QSymmS16,
925 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100926 };
927
928 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
929 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100930}
931
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000932bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
933 const TensorInfo& output,
934 Optional<std::string&> reasonIfUnsupported) const
935{
936 bool supported = true;
937
938 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
939 "Reference for ConvertBf16ToFp32 layer: input type not supported");
940
941 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
942 "Reference for ConvertBf16ToFp32 layer: output type not supported");
943
944 return supported;
945}
946
arovir011c7c81b2018-10-08 11:34:28 +0100947bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
948 const TensorInfo& output,
949 Optional<std::string&> reasonIfUnsupported) const
950{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100951 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
952 input.GetDataType(),
953 &TrueFunc<>,
954 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000955 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000956 &FalseFuncI32<>,
957 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100958 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
959 output.GetDataType(),
960 &FalseOutputFuncF16<>,
961 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000962 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000963 &FalseFuncI32<>,
964 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100965}
966
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000967bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
968 const TensorInfo& output,
969 Optional<std::string&> reasonIfUnsupported) const
970{
971 bool supported = true;
972
973 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
974 "Reference for ConvertFp32ToBf16 layer: input type not supported");
975
976 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
977 "Reference for ConvertFp32ToBf16 layer: output type not supported");
978
979 return supported;
980}
981
arovir011c7c81b2018-10-08 11:34:28 +0100982bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
983 const TensorInfo& output,
984 Optional<std::string&> reasonIfUnsupported) const
985{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100986 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
987 input.GetDataType(),
988 &FalseInputFuncF16<>,
989 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000990 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000991 &FalseFuncI32<>,
992 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100993 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
994 output.GetDataType(),
995 &TrueFunc<>,
996 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000997 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000998 &FalseFuncI32<>,
999 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001000}
1001
1002bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1003 const TensorInfo& output,
1004 const Convolution2dDescriptor& descriptor,
1005 const TensorInfo& weights,
1006 const Optional<TensorInfo>& biases,
1007 Optional<std::string&> reasonIfUnsupported) const
1008{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001009 bool supported = true;
1010
1011 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001012 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001013 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001014 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001015 DataType::Float32,
1016 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001017 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001018 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001019 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001020 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001021 };
1022
1023 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001024 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001025
1026 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001027 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001028
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001029 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1030 if (input.GetDataType() == DataType::BFloat16)
1031 {
1032 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1033 {
1034 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1035 supported = false;
1036 }
1037 }
1038 else
1039 {
1040 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001041 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001042 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001043
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001044 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001045 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001046 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001047 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001048 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001049 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001050 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001051 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001052 };
1053
1054 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001055 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001056 }
1057 else
1058 {
1059 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001060 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001061
1062 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001063 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001064 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001065
1066 if (biases.has_value())
1067 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001068 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001069 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001070 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001071 DataType::Float32,
1072 DataType::Float16,
1073 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001074 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001075
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001076 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001077 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001078 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001079 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001080
1081 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001082}
1083
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001084bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1085 const TensorInfo& output,
1086 const Convolution3dDescriptor& descriptor,
1087 const TensorInfo& weights,
1088 const Optional<TensorInfo>& biases,
1089 Optional<std::string&> reasonIfUnsupported) const
1090{
1091 bool supported = true;
1092
1093 // Define supported types.
1094 std::array<DataType,7> supportedTypes =
1095 {
1096 DataType::BFloat16,
1097 DataType::Float32,
1098 DataType::Float16,
1099 DataType::QAsymmS8,
1100 DataType::QAsymmU8,
1101 DataType::QSymmS8,
1102 DataType::QSymmS16
1103 };
1104
1105 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1106 "Reference Convolution3d: input is not a supported type.");
1107
1108 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1109 "Reference Convolution3d: output is not a supported type.");
1110
1111 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1112 "Reference Convolution3d: input and output types mismatched.");
1113
1114 const DataType inputType = input.GetDataType();
1115 if (IsQuantized8BitType(inputType))
1116 {
1117 std::array<DataType, 3> supportedWeightTypes =
1118 {
1119 DataType::QAsymmS8,
1120 DataType::QAsymmU8,
1121 DataType::QSymmS8
1122 };
1123
1124 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1125 "Reference Convolution3d: weights type not supported for quantized input.");
1126 }
1127 else
1128 {
1129 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1130 "Reference Convolution3d: weights is not a supported type.");
1131
1132 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1133 "Reference Convolution3d: input and weights types mismatched.");
1134 }
1135
1136 if (biases.has_value())
1137 {
1138 std::array<DataType,4> biasesSupportedTypes =
1139 {
1140 DataType::BFloat16,
1141 DataType::Float32,
1142 DataType::Float16,
1143 DataType::Signed32
1144 };
1145
1146 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1147 "Reference Convolution3d: biases is not a supported type.");
1148 }
1149 IgnoreUnused(descriptor);
1150
1151 return supported;
1152}
1153
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001154bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1155 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001156 Optional<std::string&> reasonIfUnsupported) const
1157{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001158 bool supported = true;
1159
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001160 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001161 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001162 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001163 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001164 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001165 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001166 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001167 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001168 DataType::QSymmS16,
1169 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001170 };
1171
1172 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001173 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001174
1175 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001176 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001177
1178 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001179 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001180
1181 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001182}
1183
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001184bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1185 const TensorInfo& output,
1186 const DepthToSpaceDescriptor& descriptor,
1187 Optional<std::string&> reasonIfUnsupported) const
1188{
Jan Eilers8eb25602020-03-09 12:13:48 +00001189 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001190 bool supported = true;
1191
Sadik Armagan303980c2020-04-17 12:45:14 +01001192 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001193 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001194 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001195 DataType::Float32,
1196 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001197 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001198 DataType::QAsymmU8,
1199 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001200 };
1201
1202 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1203 "Reference DepthToSpace: input type not supported");
1204
1205 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1206 "Reference DepthToSpace: output type not supported");
1207
1208 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1209 "Reference DepthToSpace: input and output types are mismatched");
1210
1211 return supported;
1212}
1213
arovir011c7c81b2018-10-08 11:34:28 +01001214bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1215 const TensorInfo& output,
1216 const DepthwiseConvolution2dDescriptor& descriptor,
1217 const TensorInfo& weights,
1218 const Optional<TensorInfo>& biases,
1219 Optional<std::string&> reasonIfUnsupported) const
1220{
Sadik Armagan303980c2020-04-17 12:45:14 +01001221 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001222 bool supported = true;
1223
1224 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001225 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001226 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001227 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001228 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001229 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001230 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001231 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001232 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001233 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001234 };
1235
1236 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1237 "Reference DepthwiseConvolution2d: input is not a supported type.");
1238
1239 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1240 "Reference DepthwiseConvolution2d: output is not a supported type.");
1241
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001242 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1243 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1244
Teresa Charlind8df0262019-11-11 12:28:15 +00001245 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001246 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001247 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001248 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001249 {
1250 DataType::QAsymmS8,
1251 DataType::QAsymmU8,
1252 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001253 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001254
1255 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001256 "Reference DepthwiseConvolution2d: weights type not supported for "
1257 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001258 }
1259 else
1260 {
1261 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1262 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1263
1264 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1265 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1266 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001267
1268 if (biases.has_value())
1269 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001270 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001271 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001272 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001273 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001274 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001275 DataType::Signed32
1276 };
1277 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1278 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1279 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001280
1281 return supported;
1282
arovir011c7c81b2018-10-08 11:34:28 +01001283}
1284
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001285bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1286 const TensorInfo& output,
1287 Optional<std::string&> reasonIfUnsupported) const
1288{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001289 bool supported = true;
1290
Ryan OShea9add1202020-02-07 10:06:33 +00001291 std::array<DataType,4> supportedInputTypes = {
1292 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001293 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001294 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001295 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001296 };
1297
1298 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001299 "Reference for Dequantize layer: input type not supported.");
1300
Derek Lambertid466a542020-01-22 15:37:29 +00001301 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001302 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001303
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001304 std::array<DataType,3> supportedOutputTypes = {
1305 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +00001306 DataType::Float32,
1307 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001308 };
1309
1310 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001311 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001312
1313 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001314 "Reference for Dequantize layer: input/output shapes have different num total "
1315 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001316
1317 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001318}
1319
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001320bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1321 const TensorInfo& scores,
1322 const TensorInfo& anchors,
1323 const TensorInfo& detectionBoxes,
1324 const TensorInfo& detectionClasses,
1325 const TensorInfo& detectionScores,
1326 const TensorInfo& numDetections,
1327 const DetectionPostProcessDescriptor& descriptor,
1328 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001329{
Jan Eilers8eb25602020-03-09 12:13:48 +00001330 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001331
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001332 bool supported = true;
1333
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001334 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001335 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001336 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001337 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001338 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001339 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001340 DataType::QAsymmU8,
1341 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001342 };
1343
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001344 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001345 "Reference DetectionPostProcess: input 0 is not a supported type.");
1346
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001347 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001348 "Reference DetectionPostProcess: input 1 is not a supported type.");
1349
1350 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001351}
1352
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // Dilated depthwise convolution imposes no extra type constraints on the
    // reference backend, so this delegates directly to the plain depthwise check.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1362
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001363bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001364 const TensorInfo& input1,
1365 const TensorInfo& output,
1366 Optional<std::string&> reasonIfUnsupported) const
1367{
Sadik Armagan2999a022019-04-09 14:20:12 +01001368 bool supported = true;
1369
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001370 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001371 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001372 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001373 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001374 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001375 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001376 DataType::QSymmS16,
1377 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001378 };
1379
1380 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1381 "Reference division: input 0 is not a supported type.");
1382
1383 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1384 "Reference division: input 1 is not a supported type.");
1385
1386 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1387 "Reference division: output is not a supported type.");
1388
1389 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1390 "Reference division: input 0 and Input 1 types are mismatched");
1391
1392 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1393 "Reference division: input and output types are mismatched");
1394
1395 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1396 "Reference division: shapes are not suitable for implicit broadcast.");
1397
1398 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001399}
1400
josh minor4a3c6102020-01-06 16:40:46 -06001401bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1402 const TensorInfo& output,
1403 const ElementwiseUnaryDescriptor& descriptor,
1404 Optional<std::string&> reasonIfUnsupported) const
1405{
Jan Eilers8eb25602020-03-09 12:13:48 +00001406 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001407
Sadik Armagan303980c2020-04-17 12:45:14 +01001408 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001409 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001410 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06001411 DataType::Float32,
1412 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001413 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001414 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001415 DataType::QSymmS16,
1416 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001417 };
1418
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001419 std::array<DataType, 1> logicalSupportedTypes =
1420 {
1421 DataType::Boolean
1422 };
1423
josh minor4a3c6102020-01-06 16:40:46 -06001424 bool supported = true;
1425
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001426 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1427 {
1428 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1429 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001430
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001431 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1432 "Reference elementwise unary: output type not supported");
1433 }
1434 else
1435 {
1436 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1437 "Reference elementwise unary: input type not supported");
1438
1439 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1440 "Reference elementwise unary: output type not supported");
1441 }
josh minor4a3c6102020-01-06 16:40:46 -06001442
1443 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1444 "Reference elementwise unary: input and output types not matching");
1445
1446 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1447 "Reference elementwise unary: input and output shapes"
1448 "have different number of total elements");
1449
1450 return supported;
1451}
1452
arovir011c7c81b2018-10-08 11:34:28 +01001453bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1454 const FakeQuantizationDescriptor& descriptor,
1455 Optional<std::string&> reasonIfUnsupported) const
1456{
Jan Eilers8eb25602020-03-09 12:13:48 +00001457 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001458 bool supported = true;
1459
1460 std::array<DataType,1> supportedTypes =
1461 {
1462 DataType::Float32
1463 };
1464
1465 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1466 "Reference fake quantization: input type not supported.");
1467
1468 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001469}
1470
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001471bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1472 const TensorInfo& output,
1473 const FillDescriptor& descriptor,
1474 Optional<std::string&> reasonIfUnsupported) const
1475{
1476 IgnoreUnused(descriptor);
1477 IgnoreUnused(output);
1478
1479 bool supported = true;
1480
Sadik Armagana792a052020-06-23 16:22:23 +01001481 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001482 {
1483 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001484 DataType::Float16,
1485 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001486 };
1487
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001488 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001489 "Reference Fill: input type not supported.");
1490
Teresa Charlin44088502020-07-27 11:27:19 +01001491 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1492 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001493 return supported;
1494}
1495
arovir011c7c81b2018-10-08 11:34:28 +01001496bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1497 const TensorInfo& output,
1498 Optional<std::string&> reasonIfUnsupported) const
1499{
Jan Eilers8eb25602020-03-09 12:13:48 +00001500 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001501 bool supported = true;
1502
Francis Murtaghe8ac1332020-07-30 18:03:40 +01001503 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001504 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001505 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +01001506 DataType::Float32,
Francis Murtaghe8ac1332020-07-30 18:03:40 +01001507 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001508 };
1509
1510 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1511 "Reference Floor: input type not supported.");
1512
1513 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1514 "Reference Floor: output type not supported.");
1515
1516 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001517}
1518
1519bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1520 const TensorInfo& output,
1521 const TensorInfo& weights,
1522 const TensorInfo& biases,
1523 const FullyConnectedDescriptor& descriptor,
1524 Optional<std::string&> reasonIfUnsupported) const
1525{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001526 bool supported = true;
1527
1528 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001529 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001530 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001531 DataType::BFloat16,
1532 DataType::Float32,
1533 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001534 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001535 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001536 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001537 };
1538
1539 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1540 "Reference Fully Connected: input type not supported.");
1541
1542 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1543 "Reference Fully Connected: output type not supported.");
1544
Francis Murtagh46c09d02019-05-28 08:15:28 +01001545 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1546 "Reference Fully Connected: weights type not supported.");
1547
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001548 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1549 if (input.GetDataType() == DataType::BFloat16)
1550 {
1551 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1552 {
1553 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1554 supported = false;
1555 }
1556 }
1557 else
1558 {
1559 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1560 "Reference Fully Connected: input and output types mismatched.");
1561 }
1562
Jan Eilers1f45dc32020-06-15 11:43:03 +01001563 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1564 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001565
Jan Eilers1f45dc32020-06-15 11:43:03 +01001566 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1567 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001568
1569 if (descriptor.m_BiasEnabled)
1570 {
1571 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001572 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001573 supportedBiasTypes =
1574 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001575 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001576 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001577 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001578 DataType::Signed32,
1579 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001580 };
1581
1582 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1583 "Reference Fully Connected: bias type not supported.");
1584
1585 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1586 "Reference Fully Connected: bias and weight types mismatch.");
1587
1588 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1589 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1590
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001591 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1592 "Reference Fully Connected: bias must have 1 dimension.");
1593
Francis Murtagh46c09d02019-05-28 08:15:28 +01001594 }
1595
1596 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001597}
1598
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001599bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1600 const armnn::TensorInfo& input1,
1601 const armnn::TensorInfo& output,
1602 armnn::Optional<std::string&> reasonIfUnsupported) const
1603{
1604 bool supported = true;
1605 std::array<DataType,7> supportedTypes =
1606 {
1607 DataType::BFloat16,
1608 DataType::Float32,
1609 DataType::Float16,
1610 DataType::QAsymmS8,
1611 DataType::QAsymmU8,
1612 DataType::QSymmS16,
1613 DataType::Signed32
1614 };
1615
1616 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1617 "Reference GatherNd: input type not supported");
1618
1619 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1620 "Reference GatherNd: output type not supported");
1621
1622 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1623 "Reference GatherNd: indices (input1) type not supported");
1624
1625 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1626 "Reference GatherNd: input and output types not matching");
1627
1628 return supported;
1629}
1630
narpra014951d842019-01-18 16:53:53 +00001631bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1632 const armnn::TensorInfo& input1,
1633 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001634 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001635 armnn::Optional<std::string&> reasonIfUnsupported) const
1636{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001637 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001638 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001639 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001640 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001641 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001642 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001643 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001644 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001645 DataType::QSymmS16,
1646 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001647 };
1648
Teresa Charlin52664732020-06-29 16:27:03 +01001649 if (descriptor.m_Axis != 0)
1650 {
1651 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1652 supported &= false;
1653 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001654 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1655 "Reference Gather: input type not supported");
1656
1657 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1658 "Reference Gather: output type not supported");
1659
1660 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1661 "Reference Gather: indices (input1) type not supported");
1662
1663 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1664 "Reference Gather: input and output types not matching");
1665
1666 return supported;
narpra014951d842019-01-18 16:53:53 +00001667}
1668
// The reference backend accepts any tensor as a graph input, so this check
// unconditionally succeeds. The parameters are deliberately commented out
// because they are never inspected.
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
1674
Kevin May09ca49c2019-10-09 12:37:34 +01001675bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1676 const TensorInfo& output,
1677 const InstanceNormalizationDescriptor& descriptor,
1678 Optional<std::string&> reasonIfUnsupported) const
1679{
Jan Eilers8eb25602020-03-09 12:13:48 +00001680 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001681 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001682 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001683 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001684 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001685 DataType::Float32,
1686 DataType::Float16
1687 };
1688
1689 bool supported = true;
1690
1691 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1692 "Reference Instance Normalization: input type not supported.");
1693
1694 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1695 "Reference Instance Normalization: output type not supported.");
1696
1697 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1698 "Reference Instance Normalization: input and output types mismatched.");
1699
1700 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1701 "Reference Instance Normalization: input and output shapes have different "
1702 "num total elements.");
1703
1704 return supported;
1705}
1706
arovir011c7c81b2018-10-08 11:34:28 +01001707bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1708 const TensorInfo& output,
1709 const L2NormalizationDescriptor& descriptor,
1710 Optional<std::string&> reasonIfUnsupported) const
1711{
Jan Eilers8eb25602020-03-09 12:13:48 +00001712 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001713 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001714 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001715 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001717 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001718 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001719 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001720 DataType::QAsymmU8,
1721 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001722 };
1723
1724 bool supported = true;
1725
1726 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1727 "Reference L2normalization: input type not supported.");
1728
1729 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1730 "Reference L2normalization: output type not supported.");
1731
1732 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1733 "Reference L2normalization: input and output types mismatched.");
1734
1735 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1736 "Reference L2normalization: input and output shapes have different "
1737 "num total elements.");
1738
1739 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001740}
1741
James Conroyaba90cd2020-11-06 16:28:18 +00001742bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1743 const TensorInfo& input1,
1744 const TensorInfo& output,
1745 const LogicalBinaryDescriptor& descriptor,
1746 Optional<std::string&> reasonIfUnsupported) const
1747{
1748 IgnoreUnused(descriptor);
1749
1750 std::array<DataType, 1> supportedTypes =
1751 {
1752 DataType::Boolean
1753 };
1754
1755 bool supported = true;
1756 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1757 "Reference LogicalBinary: input 0 type not supported");
1758 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1759 "Reference LogicalBinary: input 1 type not supported");
1760
1761 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1762 "Reference LogicalBinary: input and output types do not match");
1763
1764 return supported;
1765}
1766
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001767bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1768 const TensorInfo& output,
1769 const LogSoftmaxDescriptor& descriptor,
1770 Optional<std::string&> reasonIfUnsupported) const
1771{
Jan Eilers8eb25602020-03-09 12:13:48 +00001772 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001773
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001774 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001775 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001776 DataType::BFloat16,
1777 DataType::Float32,
1778 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001779 };
1780
1781 bool supported = true;
1782 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1783 "Reference LogSoftmax: input type not supported");
1784
1785 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1786 "Reference LogSoftmax: output type not supported");
1787
1788 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1789 "Reference LogSoftmax: input and output types do not match");
1790
1791 return supported;
1792}
1793
// Checks whether the reference backend can execute an Lstm layer.
//
// The input must be BFloat16, Float32 or QSymmS16, and every other tensor —
// state inputs/outputs, scratch buffer, output, and all weights/biases carried
// in paramsInfo — must have the same data type as the input. Which parameter
// tensors are checked depends on the descriptor flags (CIFG, peephole,
// projection, layer normalization), mirroring which tensors the workload will
// actually read.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // NOTE(review): both arguments ARE used below (descriptor flags and
    // paramsInfo getters), so these IgnoreUnused calls are redundant no-ops
    // kept for historical reasons.
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);

    bool supported = true;

    std::array<DataType,3> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // Input gate tensors only exist when CIFG (coupled input-forget gate) is
    // disabled.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        // CellToInputWeights is only present with both peephole enabled and
        // CIFG disabled.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        // The projection bias is optional even when projection is enabled.
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
1905
saoste012df12b32018-11-28 16:57:20 +00001906bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1907 const TensorInfo& input1,
1908 const TensorInfo& output,
1909 Optional<std::string&> reasonIfUnsupported) const
1910{
Sadik Armagan2999a022019-04-09 14:20:12 +01001911 bool supported = true;
1912
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001913 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001914 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001915 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001916 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001917 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001918 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001919 DataType::QSymmS16,
1920 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001921 };
1922
1923 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1924 "Reference maximum: input 0 is not a supported type.");
1925
1926 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1927 "Reference maximum: input 1 is not a supported type.");
1928
1929 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1930 "Reference maximum: output is not a supported type.");
1931
1932 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1933 "Reference maximum: input 0 and Input 1 types are mismatched");
1934
1935 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1936 "Reference maximum: input and output types are mismatched");
1937
1938 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1939 "Reference maximum: shapes are not suitable for implicit broadcast.");
1940
1941 return supported;
saoste012df12b32018-11-28 16:57:20 +00001942}
1943
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001944bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1945 const TensorInfo& output,
1946 const MeanDescriptor& descriptor,
1947 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001948{
James Conroy4d1ff582019-06-10 17:06:39 +01001949 bool supported = true;
1950 std::string meanLayerStr = "Mean";
1951 std::string outputTensorStr = "output";
1952
Sadik Armagan303980c2020-04-17 12:45:14 +01001953 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001954 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001955 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001956 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001957 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001958 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001959 DataType::QAsymmU8,
1960 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001961 };
1962
1963 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1964 "Reference Mean: input type not supported.");
1965
1966 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1967 "Reference Mean: input and output types are mismatched");
1968
1969 if (descriptor.m_KeepDims)
1970 {
1971 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1972 reasonIfUnsupported,
1973 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1974 output.GetNumDimensions(),
1975 meanLayerStr, outputTensorStr).data());
1976 }
1977 else if (descriptor.m_Axis.empty())
1978 {
1979 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1980 reasonIfUnsupported,
1981 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1982 meanLayerStr, outputTensorStr).data());
1983 }
1984 else
1985 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001986 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001987
1988 if (outputDim > 0)
1989 {
1990 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1991 reasonIfUnsupported,
1992 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1993 meanLayerStr, outputTensorStr).data());
1994 }
1995 else
1996 {
1997 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1998 reasonIfUnsupported,
1999 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
2000 meanLayerStr, outputTensorStr).data());
2001 }
2002 }
2003
2004 return supported;
narpra0132b90462018-09-13 11:07:48 +01002005}
2006
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002007bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
2008 const TensorInfo &output,
2009 Optional<std::string &> reasonIfUnsupported) const
2010{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002011 bool supported = true;
2012
Sadik Armagan303980c2020-04-17 12:45:14 +01002013 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002014 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002015 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002016 DataType::Float32,
2017 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002018 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002019 DataType::QAsymmU8,
2020 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002021 DataType::Boolean
2022 };
2023
2024 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2025 "Reference MemCopy: input type not supported");
2026
2027 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2028 "Reference MemCopy: output type not supported");
2029
2030 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2031 "Reference MemCopy: input and output types are mismatched");
2032
2033 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002034}
2035
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002036bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2037 const TensorInfo& input1,
2038 const TensorInfo& output,
2039 Optional<std::string&> reasonIfUnsupported) const
2040{
Sadik Armagan2999a022019-04-09 14:20:12 +01002041 bool supported = true;
2042
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002043 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002044 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002045 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002046 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002047 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002048 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002049 DataType::QSymmS16,
2050 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002051 };
2052
2053 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2054 "Reference minimum: input 0 is not a supported type.");
2055
2056 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2057 "Reference minimum: input 1 is not a supported type.");
2058
2059 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2060 "Reference minimum: output is not a supported type.");
2061
2062 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2063 "Reference minimum: input 0 and Input 1 types are mismatched");
2064
2065 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2066 "Reference minimum: input and output types are mismatched");
2067
2068 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2069 "Reference minimum: shapes are not suitable for implicit broadcast.");
2070
2071 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002072}
2073
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002074bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2075 const TensorInfo& input1,
2076 const TensorInfo& output,
2077 Optional<std::string&> reasonIfUnsupported) const
2078{
Sadik Armagan2999a022019-04-09 14:20:12 +01002079 bool supported = true;
2080
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002081 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002082 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002083 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002084 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002085 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002086 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002087 DataType::QSymmS16,
2088 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002089 };
2090
2091 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2092 "Reference multiplication: input 0 is not a supported type.");
2093
2094 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2095 "Reference multiplication: input 1 is not a supported type.");
2096
2097 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2098 "Reference multiplication: output is not a supported type.");
2099
2100 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2101 "Reference multiplication: input 0 and Input 1 types are mismatched");
2102
2103 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2104 "Reference multiplication: input and output types are mismatched");
2105
2106 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2107 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2108
2109 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002110}
2111
2112bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2113 const TensorInfo& output,
2114 const NormalizationDescriptor& descriptor,
2115 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002116{
Jan Eilers8eb25602020-03-09 12:13:48 +00002117 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002118
2119 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002120 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002121 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002122 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002123 DataType::Float16,
2124 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002125 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002126 DataType::QAsymmU8,
2127 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002128 };
2129
2130 bool supported = true;
2131
2132 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2133 "Reference normalization: input type not supported.");
2134
2135 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2136 "Reference normalization: output type not supported.");
2137
2138 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2139 "Reference normalization: input and output shapes have different "
2140 "num total elements.");
2141
2142 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002143}
2144
Derek Lamberti901ea112019-12-10 22:07:09 +00002145bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2146 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002147{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002148 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002149}
2150
2151bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2152 const TensorInfo& output,
2153 const PadDescriptor& descriptor,
2154 Optional<std::string&> reasonIfUnsupported) const
2155{
Jan Eilers8eb25602020-03-09 12:13:48 +00002156 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002157 bool supported = true;
2158
2159 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002160 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002161 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002162 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002163 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002164 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002165 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002166 DataType::QAsymmU8,
2167 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002168 };
2169
2170 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2171 "Reference pad: input is not a supported type.");
2172
2173 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2174 "Reference pad: output is not a supported type.");
2175
2176 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2177 "Reference pad: input and output types are mismatched.");
2178
2179 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002180}
2181
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002182bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2183 const TensorInfo& output,
2184 const PermuteDescriptor& descriptor,
2185 Optional<std::string&> reasonIfUnsupported) const
2186{
Jan Eilers8eb25602020-03-09 12:13:48 +00002187 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002188 bool supported = true;
2189
2190 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002191 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002192 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002193 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002194 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002195 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002196 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002197 DataType::QAsymmU8,
2198 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002199 };
2200
2201 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2202 "Reference permute: input is not a supported type.");
2203
2204 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2205 "Reference permute: output is not a supported type.");
2206
2207 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2208 "Reference permute: input and output types are mismatched.");
2209
2210 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002211}
2212
2213bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2214 const TensorInfo& output,
2215 const Pooling2dDescriptor& descriptor,
2216 Optional<std::string&> reasonIfUnsupported) const
2217{
Jan Eilers8eb25602020-03-09 12:13:48 +00002218 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002219 bool supported = true;
2220
2221 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002222 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002223 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002224 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01002225 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002226 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002227 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002228 DataType::QAsymmU8,
2229 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002230 };
2231
2232 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2233 "Reference poolind2d: input is not a supported type.");
2234
2235 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2236 "Reference poolind2d: output is not a supported type.");
2237
2238 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2239 "Reference poolind2d: input and output types are mismatched.");
2240
2241 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002242}
2243
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002244bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2245 const TensorInfo& output,
2246 const Pooling3dDescriptor& descriptor,
2247 Optional<std::string&> reasonIfUnsupported) const
2248{
2249 IgnoreUnused(descriptor);
2250 bool supported = true;
2251
2252 // Define supported output and inputs types.
2253 std::array<DataType,6> supportedTypes =
2254 {
2255 DataType::BFloat16,
2256 DataType::Float32,
2257 DataType::Float16,
2258 DataType::QAsymmS8,
2259 DataType::QAsymmU8,
2260 DataType::QSymmS16
2261 };
2262
2263 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2264 "Reference poolind3d: input is not a supported type.");
2265
2266 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2267 "Reference poolind3d: output is not a supported type.");
2268
2269 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2270 "Reference poolind3d: input and output types are mismatched.");
2271
2272 return supported;
2273}
2274
2275
James Conroy4f1f8992020-04-29 20:01:10 +01002276bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2277 const TensorInfo& previousOutputIn,
2278 const TensorInfo& previousCellStateIn,
2279 const TensorInfo& outputStateOut,
2280 const TensorInfo& cellStateOut,
2281 const TensorInfo& output,
2282 const QLstmDescriptor& descriptor,
2283 const LstmInputParamsInfo& paramsInfo,
2284 Optional<std::string&> reasonIfUnsupported) const
2285{
2286 IgnoreUnused(input);
2287 IgnoreUnused(previousOutputIn);
2288 IgnoreUnused(previousCellStateIn);
2289 IgnoreUnused(outputStateOut);
2290 IgnoreUnused(cellStateOut);
2291 IgnoreUnused(output);
2292 IgnoreUnused(descriptor);
2293 IgnoreUnused(paramsInfo);
2294
2295 IgnoreUnused(reasonIfUnsupported);
2296
2297 return true;
2298}
2299
Derek Lamberti5f400d62019-03-25 15:41:58 +00002300bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2301 const TensorInfo& output,
2302 Optional<std::string&> reasonIfUnsupported) const
2303{
2304 bool supported = true;
2305
Finn Williamsfd271062019-12-04 14:27:27 +00002306 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002307 std::array<DataType,7> supportedInputTypes = {
2308 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00002309 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002310 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002311 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002312 DataType::QAsymmU8,
2313 DataType::QSymmS8,
2314 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002315 };
2316
2317 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2318 "Reference quantize: input type not supported.");
2319
2320 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002321 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002322 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002323 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002324 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002325 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002326 };
2327 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2328 "Reference quantize: output type not supported.");
2329
2330 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2331 "Reference quantize: input and output shapes have different num total elements.");
2332
2333 return supported;
2334}
2335
Finn Williams2605b232020-06-10 15:53:46 +01002336bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2337 const TensorInfo& output,
2338 Optional<std::string&> reasonIfUnsupported) const
2339{
2340 IgnoreUnused(input);
2341 // Define supported output types.
2342 std::array<DataType,1> supportedOutputTypes =
2343 {
2344 DataType::Signed32,
2345 };
2346
2347 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2348 "Reference rank: input type not supported.");
2349}
2350
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002351bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2352 const TensorInfo& output,
2353 const ReduceDescriptor& descriptor,
2354 Optional<std::string&> reasonIfUnsupported) const
2355{
2356 IgnoreUnused(descriptor);
2357 bool supported = true;
2358 std::array<DataType,7> supportedTypes =
2359 {
2360 DataType::BFloat16,
2361 DataType::Float32,
2362 DataType::Float16,
2363 DataType::QAsymmS8,
2364 DataType::QAsymmU8,
2365 DataType::QSymmS16,
2366 DataType::Signed32
2367 };
2368
2369 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2370 "Reference Reduce: input type not supported");
2371
2372 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2373 "Reference Reduce: output type not supported");
2374
2375 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2376 "Reference Reduce: input and output types not matching");
2377
2378 return supported;
2379}
2380
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002381bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002382 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002383 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002384 Optional<std::string&> reasonIfUnsupported) const
2385{
Jan Eilers8eb25602020-03-09 12:13:48 +00002386 IgnoreUnused(output);
2387 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002388 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002389 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002390 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002391 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002392 DataType::Float32,
2393 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002394 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002395 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002396 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002397 DataType::QSymmS16,
2398 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002399 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002400
Nina Drozd2f2778f2019-05-27 10:37:05 +01002401 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2402 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002403}
2404
Teresa Charlin970f43b2019-07-01 13:51:07 +01002405bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2406 const TensorInfo& output,
2407 const ResizeDescriptor& descriptor,
2408 Optional<std::string&> reasonIfUnsupported) const
2409{
Jan Eilers8eb25602020-03-09 12:13:48 +00002410 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002411 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002412 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002413 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002414 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002415 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002416 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002417 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002418 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002419 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002420 };
2421
2422 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2423 "Reference Resize: input type not supported");
2424
2425 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2426 "Reference Resize: output type not supported");
2427
2428 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2429 "Reference Resize: input and output types not matching");
2430
2431 return supported;
2432}
2433
Keith Davis3ae3f972021-05-21 16:33:48 +01002434bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2435 const TensorInfo& output,
2436 Optional<std::string&> reasonIfUnsupported) const
2437{
2438 IgnoreUnused(input);
2439 bool supported = true;
2440
2441 std::array<DataType, 1> supportedTypes =
2442 {
2443 DataType::Signed32
2444 };
2445
2446 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2447 "Reference Shape: output type not supported");
2448
2449 return supported;
2450}
2451
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002452bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2453 const TensorInfo& output,
2454 const SliceDescriptor& descriptor,
2455 Optional<std::string&> reasonIfUnsupported) const
2456{
Jan Eilers8eb25602020-03-09 12:13:48 +00002457 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002458 bool supported = true;
2459
Sadik Armagan303980c2020-04-17 12:45:14 +01002460 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002461 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002462 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002463 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002464 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002465 DataType::QAsymmU8,
2466 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002467 };
2468
2469 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2470 "Reference Slice: input type not supported");
2471
2472 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2473 "Reference Slice: output type not supported");
2474
2475 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2476 "Reference Slice: input and output types are mismatched");
2477
2478 return supported;
2479}
2480
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002481bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2482 const TensorInfo& output,
2483 const SoftmaxDescriptor& descriptor,
2484 Optional<std::string&> reasonIfUnsupported) const
2485{
Jan Eilers8eb25602020-03-09 12:13:48 +00002486 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002487 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002488 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002489 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002490 DataType::BFloat16,
2491 DataType::Float32,
2492 DataType::Float16,
2493 DataType::QSymmS8,
2494 DataType::QAsymmS8,
2495 DataType::QAsymmU8,
2496 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002497 };
2498
2499 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002500 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002501
2502 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002503 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002504
2505 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002506 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002507
2508 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002509}
2510
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002511bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2512 const TensorInfo& output,
2513 const SpaceToBatchNdDescriptor& descriptor,
2514 Optional<std::string&> reasonIfUnsupported) const
2515{
Jan Eilers8eb25602020-03-09 12:13:48 +00002516 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002517 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002518 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002519 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002520 DataType::BFloat16,
2521 DataType::Float32,
2522 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002523 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002524 DataType::QAsymmU8,
2525 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002526 };
2527
2528 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2529 "Reference SpaceToBatchNd: input type not supported");
2530
2531 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2532 "Reference SpaceToBatchNd: output type not supported");
2533
2534 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2535 "Reference SpaceToBatchNd: input and output types are mismatched");
2536
2537 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002538}
2539
Keith Davisa57eccb2019-06-14 17:33:22 +01002540bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002541 const TensorInfo& output,
2542 const SpaceToDepthDescriptor& descriptor,
2543 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002544{
2545
Jan Eilers8eb25602020-03-09 12:13:48 +00002546 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002547 bool supported = true;
2548
Sadik Armagan303980c2020-04-17 12:45:14 +01002549 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002550 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002551 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002552 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002553 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002554 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002555 DataType::QAsymmU8,
2556 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002557 };
2558
2559 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2560 "Reference SpaceToDepth: input type not supported");
2561
2562 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2563 "Reference SpaceToDepth: output type not supported");
2564
2565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2566 "Reference SpaceToDepth: input and output types are mismatched");
2567
2568 return supported;
2569}
2570
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002571bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002572 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2573 const ViewsDescriptor& descriptor,
2574 Optional<std::string&> reasonIfUnsupported) const
2575{
Jan Eilers8eb25602020-03-09 12:13:48 +00002576 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002577 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002578 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002579 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002580 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002581 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002582 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002583 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002584 DataType::QAsymmU8,
2585 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002586 };
2587
2588 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2589 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002590 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002591 {
2592 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2593 "Reference splitter: input type not supported");
2594
2595 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2596 "Reference splitter: input and output types mismatched.");
2597 }
2598
2599 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002600}
2601
Matthew Jackson81e601c2019-07-11 12:07:09 +01002602bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2603 const TensorInfo& output,
2604 const StackDescriptor& descriptor,
2605 Optional<std::string&> reasonIfUnsupported) const
2606{
Jan Eilers8eb25602020-03-09 12:13:48 +00002607 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002608
2609 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002610 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002611 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002612 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002613 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002614 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002615 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002616 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002617 DataType::QSymmS16,
2618 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002619 };
2620
2621 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2622 "Reference stack: output type not supported");
2623 for (const TensorInfo* input : inputs)
2624 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002625 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002626 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2627 "Reference stack: input type not supported");
2628
2629 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2630 "Reference stack: input and output types mismatched.");
2631 }
2632
2633 return supported;
2634}
2635
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002636bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2637 const TensorInfo& output,
2638 const StridedSliceDescriptor& descriptor,
2639 Optional<std::string&> reasonIfUnsupported) const
2640{
Jan Eilers8eb25602020-03-09 12:13:48 +00002641 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002642 bool supported = true;
2643
Sadik Armagan303980c2020-04-17 12:45:14 +01002644 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002645 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002646 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002647 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002648 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002649 DataType::QAsymmU8,
2650 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002651 };
2652
2653 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2654 "Reference StridedSlice: input type not supported");
2655
2656 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2657 "Reference StridedSlice: output type not supported");
2658
2659 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2660 "Reference StridedSlice: input and output types are mismatched");
2661
2662 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002663}
2664
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002665bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2666 const TensorInfo& input1,
2667 const TensorInfo& output,
2668 Optional<std::string&> reasonIfUnsupported) const
2669{
Sadik Armagan2999a022019-04-09 14:20:12 +01002670 bool supported = true;
2671
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002672 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002673 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002674 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002675 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002676 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002677 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002678 DataType::QSymmS16,
2679 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002680 };
2681
2682 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2683 "Reference subtraction: input 0 is not a supported type.");
2684
2685 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2686 "Reference subtraction: input 1 is not a supported type.");
2687
2688 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2689 "Reference subtraction: output is not a supported type.");
2690
2691 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2692 "Reference subtraction: input 0 and Input 1 types are mismatched");
2693
2694 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2695 "Reference subtraction: input and output types are mismatched");
2696
2697 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2698 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2699
2700 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002701}
2702
Matteo Martincighab9e5252019-06-13 17:27:46 +01002703bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2704 const TensorInfo& alpha,
2705 const TensorInfo& output,
2706 Optional<std::string&> reasonIfUnsupported) const
2707{
2708 bool supported = true;
2709
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002710 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002711 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002712 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002713 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002714 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002715 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002716 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002717 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002718 };
2719
2720 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2721 "PReLU: input is not a supported type.");
2722
2723 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2724 "PReLU: alpha is not a supported type.");
2725
2726 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2727 "PReLU: output is not a supported type.");
2728
2729 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2730 "PReLU: input, alpha and output types are mismatched");
2731
2732 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2733 "PReLU: shapes are not suitable for implicit broadcast");
2734
2735 return supported;
2736}
2737
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002738bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2739 const TensorInfo& output,
2740 const TransposeConvolution2dDescriptor& descriptor,
2741 const TensorInfo& weights,
2742 const Optional<TensorInfo>& biases,
2743 Optional<std::string&> reasonIfUnsupported) const
2744{
Jan Eilers8eb25602020-03-09 12:13:48 +00002745 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002746 bool supported = true;
2747
Sadik Armagan303980c2020-04-17 12:45:14 +01002748 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002749 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002750 DataType::BFloat16,
2751 DataType::Float32,
2752 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002753 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002754 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002755 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002756 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002757 };
2758
2759 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2760 "Reference TransposeConvolution2d: input is not a supported type.");
2761
2762 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2763 "Reference TransposeConvolution2d: output is not a supported type.");
2764
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002765 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2766 "Reference TransposeConvolution2d: input and output types mismatched.");
2767
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002768
2769 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002770 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002771 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002772 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002773 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002774 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002775 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002776 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002777 };
2778
2779 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2780 "Reference TransposeConvolution2d: weights type not supported for "
2781 "quantized input.");
2782 }
2783 else
2784 {
2785 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2786 "Reference TransposeConvolution2d: weights is not a supported type.");
2787
2788 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2789 "Reference TransposeConvolution2d: input and weights types mismatched.");
2790 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002791
2792 if (biases.has_value())
2793 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002794 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002795 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002796 DataType::BFloat16,
2797 DataType::Float32,
2798 DataType::Float16,
2799 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002800 };
2801 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2802 "Reference TransposeConvolution2d: biases is not a supported type.");
2803 }
2804
2805 return supported;
2806}
2807
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002808bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2809 const TensorInfo& output,
2810 const TransposeDescriptor& descriptor,
2811 Optional<std::string&> reasonIfUnsupported) const
2812{
Jan Eilers8eb25602020-03-09 12:13:48 +00002813 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002814 bool supported = true;
2815
2816 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002817 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002818 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002819 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002820 DataType::Float32,
2821 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002822 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002823 DataType::QAsymmU8,
2824 DataType::QSymmS16
2825 };
2826
2827 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2828 "Reference transpose: input is not a supported type.");
2829
2830 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2831 "Reference transpose: output is not a supported type.");
2832
2833 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2834 "Reference transpose: input and output types are mismatched.");
2835
2836 return supported;
2837}
2838
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002839bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2840 const TensorInfo& input,
2841 const TensorInfo& outputStateIn,
2842 const TensorInfo& cellStateIn,
2843 const TensorInfo& output,
2844 const Optional<TensorInfo>& hiddenStateOutput,
2845 const Optional<TensorInfo>& cellStateOutput,
2846 const UnidirectionalSequenceLstmDescriptor& descriptor,
2847 const LstmInputParamsInfo& paramsInfo,
2848 Optional<std::string&> reasonIfUnsupported) const
2849{
2850 IgnoreUnused(descriptor);
2851 IgnoreUnused(paramsInfo);
2852 IgnoreUnused(outputStateIn);
2853 IgnoreUnused(cellStateIn);
2854 bool supported = true;
2855
2856 if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
2857 {
2858 reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
2859 "and cell state output are not supported at the moment.";
2860 }
2861
2862 std::array<DataType, 1> supportedTypes =
2863 {
2864 DataType::Float32
2865 };
2866
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002867 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002868 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002869 DataType::Float32,
2870 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002871 };
2872
2873 // check inputs and outputs
2874 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2875 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2876 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
2877 "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
2878 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
2879 "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
2880
2881 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2882 "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
2883 // check layer parameters
2884 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2885 reasonIfUnsupported,
2886 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2887 "is not a supported type.");
2888 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2889 reasonIfUnsupported,
2890 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2891 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2892 reasonIfUnsupported,
2893 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2894 "is not a supported type.");
2895 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2896 reasonIfUnsupported,
2897 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2898 "is not a supported type.");
2899 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2900 reasonIfUnsupported,
2901 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2902 "is not a supported type.");
2903 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2904 reasonIfUnsupported,
2905 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2906 "is not a supported type.");
2907 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
2908 "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
2909 "are mismatched");
2910 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
2911 "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
2912 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
2913 "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
2914 "are mismatched");
2915 if (!descriptor.m_CifgEnabled)
2916 {
2917 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2918 reasonIfUnsupported,
2919 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2920 "is not a supported type.");
2921 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2922 reasonIfUnsupported,
2923 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2924 "is not a supported type.");
2925 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
2926 "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
2927 "are mismatched");
2928 if (descriptor.m_PeepholeEnabled)
2929 {
2930 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2931 reasonIfUnsupported,
2932 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2933 "is not a supported type.");
2934 }
2935 }
2936 if (descriptor.m_PeepholeEnabled)
2937 {
2938 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2939 reasonIfUnsupported,
2940 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2941 "is not a supported type.");
2942 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2943 reasonIfUnsupported,
2944 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2945 "is not a supported type.");
2946 }
2947 if (descriptor.m_ProjectionEnabled)
2948 {
2949 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2950 reasonIfUnsupported,
2951 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2952 "is not a supported type.");
2953 if (paramsInfo.m_ProjectionBias != nullptr)
2954 {
2955 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2956 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2957 "are mismatched");
2958 }
2959 }
2960 if (descriptor.m_LayerNormEnabled)
2961 {
2962 if (!descriptor.m_CifgEnabled)
2963 {
2964 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2965 reasonIfUnsupported,
2966 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2967 "is not a supported type.");
2968 }
2969 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2970 reasonIfUnsupported,
2971 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2972 "is not a supported type.");
2973 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2974 reasonIfUnsupported,
2975 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2976 "is not a supported type.");
2977 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2978 reasonIfUnsupported,
2979 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2980 "is not a supported type.");
2981 }
2982
2983 return supported;
2984}
2985
arovir011c7c81b2018-10-08 11:34:28 +01002986} // namespace armnn