blob: f5798c886f427a4cdf070ae17e83a3ba71950345 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
namespace
{

// Legacy helper for data-type-based support queries: forwards to
// IsSupportedForDataTypeGeneric with dedicated functors for two data types and
// FalseFunc (unconditional rejection) for the rest.
//
// NOTE(review): the positional meaning of each functor is defined by
// IsSupportedForDataTypeGeneric (declared in LayerSupportCommon.hpp, not
// visible here). Presumably the order is Float16, Float32, QAsymmU8, Signed32,
// Boolean — confirm against that header before relying on it.
//
// @param reasonIfUnsupported  Optional sink for a human-readable failure reason.
// @param dataType             Tensor data type being queried.
// @param floatFuncPtr         Functor invoked for the Float32 slot.
// @param uint8FuncPtr         Functor invoked for the QAsymmU8 slot.
// @param params               Extra arguments forwarded to the selected functor.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
44
namespace
{

// Builds the standard diagnostic string reported when a tensor has the wrong
// number of dimensions for a reference layer.
//
// Fix: the string parameters were taken by non-const reference although the
// function only reads them; const references state the read-only intent and
// additionally allow temporaries/literals to be passed. Callers passing
// lvalue std::strings are unaffected.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name (e.g. "batchToSpaceNd").
// @param tensorName Name of the offending tensor (e.g. "input").
// @return           Formatted diagnostic message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
// Unified support query for the reference (CPU) backend: dispatches on the
// layer type and forwards the packed TensorInfos/descriptor to the matching
// IsXxxSupported() overload.
//
// 'infos' holds the layer's tensor infos in a layer-specific order (inputs
// first, then outputs, then weights/biases where applicable). Convolution-like
// cases validate infos.size() explicitly and throw InvalidArgumentException on
// a mismatch; other cases index 'infos' unchecked and rely on the caller
// packing them correctly. Lstm/QLstm/UnidirectionalSequenceLstm cases call
// lstmParamsInfo.value() — presumably Optional::value() fails if the optional
// is empty, so those layer types require lstmParamsInfo to be set (same for
// quantizedLstmInputParamsInfo and QuantizedLstm) — TODO confirm against
// armnn::Optional.
//
// @return true if the reference backend supports the layer configuration;
//         false (with an optional reason) otherwise.
bool RefLayerSupport::IsLayerSupported(const LayerType& type,
                                       const std::vector<TensorInfo>& infos,
                                       const BaseDescriptor& descriptor,
                                       const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                       const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    switch (type)
    {
        case LayerType::Activation:
            return IsActivationSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Addition:
            return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ArgMinMax:
            return IsArgMinMaxSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::BatchNormalization:
            // infos: {input, output, mean, variance, beta, gamma}
            return IsBatchNormalizationSupported(infos[0],
                                                 infos[1],
                                                 infos[2],
                                                 infos[3],
                                                 infos[4],
                                                 infos[5],
                                                 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
                                                     (&descriptor)),
                                                 reasonIfUnsupported);
        case LayerType::BatchToSpaceNd:
            return IsBatchToSpaceNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Comparison:
            return IsComparisonSupported(infos[0],
                                         infos[1],
                                         infos[2],
                                         *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Concat:
        {
            // All entries except the last are inputs; the last is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < (infos.size() - 1); i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsConcatSupported(inputInfos,
                                     infos[infos.size() - 1],
                                     *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        }
        case LayerType::Constant:
            return IsConstantSupported(infos[0], reasonIfUnsupported);
        case LayerType::ConvertBf16ToFp32:
            return IsConvertBf16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp16ToFp32:
            return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToBf16:
            return IsConvertFp32ToBf16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToFp16:
            return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Convolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 marks "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::DepthToSpace:
            return IsDepthToSpaceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::DepthwiseConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 marks "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       EmptyOptional(),
                                                       reasonIfUnsupported);
            }
            else
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       infos[3],
                                                       reasonIfUnsupported);
            }
        }
        case LayerType::Dequantize:
            return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Division:
            return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ElementwiseUnary:
            return IsElementwiseUnarySupported(infos[0],
                                               infos[1],
                                               *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::Fill:
            return IsFillSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Floor:
            return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::FullyConnected:
            // infos: {input, output, weights, biases}
            return IsFullyConnectedSupported(infos[0],
                                             infos[1],
                                             infos[2],
                                             infos[3],
                                             *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Gather:
            return IsGatherSupported(infos[0],
                                     infos[1],
                                     infos[2],
                                     *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Input:
            return IsInputSupported(infos[0], reasonIfUnsupported);
        case LayerType::InstanceNormalization:
            return IsInstanceNormalizationSupported(infos[0],
                                                    infos[1],
                                                    *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
                                                        (&descriptor)),
                                                    reasonIfUnsupported);
        case LayerType::L2Normalization:
            return IsL2NormalizationSupported(infos[0],
                                              infos[1],
                                              *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
                                              reasonIfUnsupported);
        case LayerType::LogicalBinary:
            return IsLogicalBinarySupported(infos[0],
                                            infos[1],
                                            infos[2],
                                            *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::LogSoftmax:
            return IsLogSoftmaxSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Lstm:
            // infos: 7 tensors; requires lstmParamsInfo to be set.
            return IsLstmSupported(infos[0],
                                   infos[1],
                                   infos[2],
                                   infos[3],
                                   infos[4],
                                   infos[5],
                                   infos[6],
                                   *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
                                   lstmParamsInfo.value(),
                                   reasonIfUnsupported);
        case LayerType::QLstm:
            // infos: 6 tensors; requires lstmParamsInfo to be set.
            return IsQLstmSupported(infos[0],
                                    infos[1],
                                    infos[2],
                                    infos[3],
                                    infos[4],
                                    infos[5],
                                    *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
                                    lstmParamsInfo.value(),
                                    reasonIfUnsupported);
        case LayerType::Maximum:
            return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Mean:
            return IsMeanSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Minimum:
            return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Multiplication:
            return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Normalization:
            return IsNormalizationSupported(infos[0],
                                            infos[1],
                                            *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::Output:
            return IsOutputSupported(infos[0], reasonIfUnsupported);
        case LayerType::Pad:
            return IsPadSupported(infos[0],
                                  infos[1],
                                  *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
                                  reasonIfUnsupported);
        case LayerType::Permute:
            return IsPermuteSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Pooling2d:
            return IsPooling2dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Prelu:
            return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Quantize:
            return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Reshape:
            return IsReshapeSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Resize:
            return IsResizeSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Reduce:
            return IsReduceSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Slice:
            return IsSliceSupported(infos[0],
                                    infos[1],
                                    *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        case LayerType::Softmax:
            return IsSoftmaxSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::SpaceToBatchNd:
            return IsSpaceToBatchNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::SpaceToDepth:
            return IsSpaceToDepthSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Splitter:
        {
            // infos[0] is the input; everything after it is an output view.
            std::vector<TensorInfo> outputInfos;
            for (uint32_t i = 1; i < infos.size(); i++)
            {
                outputInfos.push_back(infos[i]);
            }
            return IsSplitterSupported(infos[0],
                                       {outputInfos.begin(), outputInfos.end()},
                                       *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
                                       reasonIfUnsupported);
        }
        case LayerType::Stack:
        {
            // All entries except the last are inputs; the last is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < infos.size() - 1; i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsStackSupported(inputInfos,
                                    infos[infos.size() - 1],
                                    *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        }
        case LayerType::StridedSlice:
            return IsStridedSliceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Subtraction:
            return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Transpose:
            return IsTransposeSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::TransposeConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 marks "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         EmptyOptional(),
                                                         reasonIfUnsupported);
            }
            else
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         infos[3],
                                                         reasonIfUnsupported);
            }
        }
        case LayerType::Cast:
            return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ChannelShuffle:
            return IsChannelShuffleSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Convolution3d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 marks "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::Debug:
            return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::DetectionPostProcess:
            return IsDetectionPostProcessSupported(infos[0],
                                                   infos[1],
                                                   infos[2],
                                                   infos[3],
                                                   infos[4],
                                                   infos[5],
                                                   infos[6],
                                                   *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
                                                       (&descriptor)),
                                                   reasonIfUnsupported);
        case LayerType::FakeQuantization:
            return IsFakeQuantizationSupported(infos[0],
                                               *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::MemCopy:
            return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Rank:
            return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Shape:
            return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::UnidirectionalSequenceLstm:
        {
            if (infos.size() != 6)
            {
                throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
                                               "should be of format: {input, outputStateIn, cellStateIn, "
                                               "hiddenStateOutputVal, cellStateOutputVal, output}");
            }
            auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));

            // A default-constructed TensorInfo marks an absent (optional) output.
            // NOTE(review): despite its name, isCellStateOutput is true when the
            // cell-state output is ABSENT (mirrors isHiddenStateOutputOptional).
            bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
            bool isCellStateOutput = (infos[5] == TensorInfo());
            if (isHiddenStateOutputOptional && isCellStateOutput)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             EmptyOptional(),
                                                             EmptyOptional(),
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else if (isHiddenStateOutputOptional)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             EmptyOptional(),
                                                             infos[5],
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else if (isCellStateOutput)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             infos[4],
                                                             EmptyOptional(),
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             infos[4],
                                                             infos[5],
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
        }
        case LayerType::Pooling3d:
            return IsPooling3dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Map:
            return true;
        case LayerType::Unmap:
            return true;
        case LayerType::MemImport:
            // Map/Unmap/MemImport/Merge/QuantizedLstm fall back to the
            // generic LayerSupportBase implementations.
            return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Merge:
            return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::QuantizedLstm:
            return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
                                                              infos[1],
                                                              infos[2],
                                                              infos[3],
                                                              infos[4],
                                                              quantizedLstmInputParamsInfo.value(),
                                                              reasonIfUnsupported);
        default:
            // Layers not supported in the reference backend by default:
            // precompiled, standin, switch.
            return false;
    }
}
542
// Checks whether the reference backend can execute an Activation layer with
// the given input/output tensors and activation function.
//
// @param input                Input tensor info.
// @param output               Output tensor info (must match input's type and rank).
// @param descriptor           Selects the activation function to apply.
// @param reasonIfUnsupported  Optional sink for a human-readable failure reason.
// @return true only if every rule below passes.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    // All rules are evaluated (no short-circuit) so every failure is appended
    // to reasonIfUnsupported, not just the first.
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Ad-hoc Rule: the constructor sets m_Res to whether the requested
    // activation function is one the reference workload implements.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    // Any activation function added to the enum later is
                    // rejected until explicitly listed above.
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
610
611bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
612 const TensorInfo& input1,
613 const TensorInfo& output,
614 Optional<std::string&> reasonIfUnsupported) const
615{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000616 bool supported = true;
617
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100618 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000619 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000620 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100621 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000623 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100624 DataType::QSymmS16,
625 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000626 };
627
628 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
629 "Reference addition: input 0 is not a supported type.");
630
631 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
632 "Reference addition: input 1 is not a supported type.");
633
634 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
635 "Reference addition: output is not a supported type.");
636
637 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
638 "Reference addition: input 0 and Input 1 types are mismatched");
639
640 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
641 "Reference addition: input and output types are mismatched");
642
643 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
644 "Reference addition: shapes are not suitable for implicit broadcast.");
645
646 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100647}
648
Nikhil Raj68c2c902019-09-19 11:21:11 +0100649bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
650 const armnn::ArgMinMaxDescriptor &descriptor,
651 armnn::Optional<std::string &> reasonIfUnsupported) const
652{
Jan Eilers8eb25602020-03-09 12:13:48 +0000653 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100654
Mike Kelly1f140f72021-04-06 12:25:55 +0100655 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100656 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000657 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100658 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100659 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100660 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000661 DataType::QAsymmU8,
662 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100663 DataType::Signed32,
664 DataType::Signed64
665 };
666
667 std::array<DataType,2> supportedOutputTypes = {
668 DataType::Signed32,
669 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100670 };
671
672 bool supported = true;
673
Mike Kelly1f140f72021-04-06 12:25:55 +0100674 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100675 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100676 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100677 "Reference ArgMinMax: output type not supported");
678
679 return supported;
680}
681
arovir011c7c81b2018-10-08 11:34:28 +0100682bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
683 const TensorInfo& output,
684 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100685 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100686 const TensorInfo& beta,
687 const TensorInfo& gamma,
688 const BatchNormalizationDescriptor& descriptor,
689 Optional<std::string&> reasonIfUnsupported) const
690{
Jan Eilers8eb25602020-03-09 12:13:48 +0000691 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100692
Sadik Armagan303980c2020-04-17 12:45:14 +0100693 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100694 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000695 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100696 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100697 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100698 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000699 DataType::QAsymmU8,
700 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100701 };
702
703 bool supported = true;
704
705 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
706 "Reference batch normalization: input is not a supported type.");
707
708 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
709 "Reference batch normalization: output is not a supported type.");
710
711 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
712 "Reference batch normalization: input and output types are mismatched");
713
714 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
715 "Reference batch normalization: mean is not a supported type.");
716
717 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
718 "Reference batch normalization: variance is not a supported type.");
719
720 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
721 "Reference batch normalization: beta is not a supported type.");
722
723 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
724 "Reference batch normalization: gamma is not a supported type.");
725
726 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100727}
728
// Checks whether the reference backend can execute a BatchToSpaceNd layer.
// In addition to the usual type rules, both tensors must be 4-dimensional.
//
// NOTE(review): the dimension checks pass
// CreateIncorrectDimensionsErrorMsg(...).data() — a pointer into a temporary
// std::string. This is only safe if CheckSupportRule consumes/copies the
// message during the call (the temporary lives until the end of the full
// expression); confirm CheckSupportRule does not retain the pointer.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    // Descriptor contents (block shape / crops) are not constrained here.
    IgnoreUnused(descriptor);

    bool supported = true;

    // Names used to build the dimension-mismatch diagnostics below.
    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    // Both tensors must be rank-4.
    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}
778
mathad01b392e982021-04-07 12:07:30 +0100779bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
780 const TensorInfo& output,
781 Optional<std::string&> reasonIfUnsupported) const
782{
783 std::array<DataType, 9> supportedInputTypes =
784 {
785 DataType::BFloat16,
786 DataType::Float32,
787 DataType::Float16,
788 DataType::QSymmS8,
789 DataType::QAsymmS8,
790 DataType::QAsymmU8,
791 DataType::QSymmS16,
792 DataType::Signed32
793 };
794
795 bool supported = true;
796 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
797 "Reference cast: input is not a supported type");
798
799
800 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
801 "Reference cast: output is not a supported type");
802
803 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
804 "Reference cast: input and output shapes have different number of total elements");
805
806 return supported;
807}
808
Simon Obute51f67772021-09-03 15:50:13 +0100809bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
810 const TensorInfo& output,
811 const ChannelShuffleDescriptor& descriptor,
812 Optional<std::string&> reasonIfUnsupported) const
813{
814 IgnoreUnused(descriptor);
815 bool supported = true;
816
817 // Define supported output and inputs types.
818 std::array<DataType, 7> supportedTypes =
819 {
820 DataType::BFloat16,
821 DataType::Float32,
822 DataType::Float16,
823 DataType::QAsymmS8,
824 DataType::QAsymmU8,
825 DataType::QSymmS8,
826 DataType::QSymmS16
827 };
828
829 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
830 "Reference ChannelShuffle: input is not a supported type.");
831
832 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
833 "Reference ChannelShuffle: output is not a supported type.");
834
835 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
836 "Reference ChannelShuffle: input and output types are mismatched.");
837
838 return supported;
839}
840
841
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100842bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
843 const TensorInfo& input1,
844 const TensorInfo& output,
845 const ComparisonDescriptor& descriptor,
846 Optional<std::string&> reasonIfUnsupported) const
847{
Jan Eilers8eb25602020-03-09 12:13:48 +0000848 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100849 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100850 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000851 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000852 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100853 DataType::Float32,
854 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100855 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000856 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000857 DataType::QSymmS16,
858 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100859 };
860
861 bool supported = true;
862 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
863 "Reference comparison: input 0 is not a supported type");
864
865 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
866 "Reference comparison: input 0 and Input 1 types are mismatched");
867
868 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
869 "Reference comparison: output is not of type Boolean");
870
871 return supported;
872}
873
Jim Flynn906f9462019-05-10 13:55:21 +0100874bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
875 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000876 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100877 Optional<std::string&> reasonIfUnsupported) const
878{
Jan Eilers8eb25602020-03-09 12:13:48 +0000879 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100880
881 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000882 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100883 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000884 DataType::BFloat16,
885 DataType::Float32,
886 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000887 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100888 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000889 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100890 };
891
892 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
893 "Reference concatenation: output type not supported");
894 for (const TensorInfo* input : inputs)
895 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100896 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100897 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
898 "Reference concatenation: input type not supported");
899
900 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
901 "Reference concatenation: input and output types mismatched.");
902 }
903
904 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100905}
906
arovir011c7c81b2018-10-08 11:34:28 +0100907bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
908 Optional<std::string&> reasonIfUnsupported) const
909{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100910 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100911 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000912 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100913 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100914 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000915 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100916 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000917 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100918 DataType::QSymmS16,
919 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100920 };
921
922 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
923 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100924}
925
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000926bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
927 const TensorInfo& output,
928 Optional<std::string&> reasonIfUnsupported) const
929{
930 bool supported = true;
931
932 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
933 "Reference for ConvertBf16ToFp32 layer: input type not supported");
934
935 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
936 "Reference for ConvertBf16ToFp32 layer: output type not supported");
937
938 return supported;
939}
940
arovir011c7c81b2018-10-08 11:34:28 +0100941bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
942 const TensorInfo& output,
943 Optional<std::string&> reasonIfUnsupported) const
944{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100945 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
946 input.GetDataType(),
947 &TrueFunc<>,
948 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000949 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000950 &FalseFuncI32<>,
951 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100952 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
953 output.GetDataType(),
954 &FalseOutputFuncF16<>,
955 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000956 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000957 &FalseFuncI32<>,
958 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100959}
960
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000961bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
962 const TensorInfo& output,
963 Optional<std::string&> reasonIfUnsupported) const
964{
965 bool supported = true;
966
967 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
968 "Reference for ConvertFp32ToBf16 layer: input type not supported");
969
970 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
971 "Reference for ConvertFp32ToBf16 layer: output type not supported");
972
973 return supported;
974}
975
arovir011c7c81b2018-10-08 11:34:28 +0100976bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
977 const TensorInfo& output,
978 Optional<std::string&> reasonIfUnsupported) const
979{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100980 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
981 input.GetDataType(),
982 &FalseInputFuncF16<>,
983 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000984 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000985 &FalseFuncI32<>,
986 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100987 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
988 output.GetDataType(),
989 &TrueFunc<>,
990 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000991 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000992 &FalseFuncI32<>,
993 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100994}
995
996bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
997 const TensorInfo& output,
998 const Convolution2dDescriptor& descriptor,
999 const TensorInfo& weights,
1000 const Optional<TensorInfo>& biases,
1001 Optional<std::string&> reasonIfUnsupported) const
1002{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001003 bool supported = true;
1004
1005 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001006 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001007 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001008 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001009 DataType::Float32,
1010 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001011 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001012 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001013 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001014 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001015 };
1016
1017 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001018 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001019
1020 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001021 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001022
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001023 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1024 if (input.GetDataType() == DataType::BFloat16)
1025 {
1026 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1027 {
1028 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1029 supported = false;
1030 }
1031 }
1032 else
1033 {
1034 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001035 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001036 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001037
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001038 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001039 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001040 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001041 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001042 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001043 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001044 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001045 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001046 };
1047
1048 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001049 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001050 }
1051 else
1052 {
1053 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001054 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001055
1056 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001057 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001058 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001059
1060 if (biases.has_value())
1061 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001062 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001063 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001064 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001065 DataType::Float32,
1066 DataType::Float16,
1067 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001068 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001069
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001070 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001071 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001072 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001073 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001074
1075 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001076}
1077
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001078bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1079 const TensorInfo& output,
1080 const Convolution3dDescriptor& descriptor,
1081 const TensorInfo& weights,
1082 const Optional<TensorInfo>& biases,
1083 Optional<std::string&> reasonIfUnsupported) const
1084{
1085 bool supported = true;
1086
1087 // Define supported types.
1088 std::array<DataType,7> supportedTypes =
1089 {
1090 DataType::BFloat16,
1091 DataType::Float32,
1092 DataType::Float16,
1093 DataType::QAsymmS8,
1094 DataType::QAsymmU8,
1095 DataType::QSymmS8,
1096 DataType::QSymmS16
1097 };
1098
1099 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1100 "Reference Convolution3d: input is not a supported type.");
1101
1102 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1103 "Reference Convolution3d: output is not a supported type.");
1104
1105 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1106 "Reference Convolution3d: input and output types mismatched.");
1107
1108 const DataType inputType = input.GetDataType();
1109 if (IsQuantized8BitType(inputType))
1110 {
1111 std::array<DataType, 3> supportedWeightTypes =
1112 {
1113 DataType::QAsymmS8,
1114 DataType::QAsymmU8,
1115 DataType::QSymmS8
1116 };
1117
1118 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1119 "Reference Convolution3d: weights type not supported for quantized input.");
1120 }
1121 else
1122 {
1123 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1124 "Reference Convolution3d: weights is not a supported type.");
1125
1126 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1127 "Reference Convolution3d: input and weights types mismatched.");
1128 }
1129
1130 if (biases.has_value())
1131 {
1132 std::array<DataType,4> biasesSupportedTypes =
1133 {
1134 DataType::BFloat16,
1135 DataType::Float32,
1136 DataType::Float16,
1137 DataType::Signed32
1138 };
1139
1140 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1141 "Reference Convolution3d: biases is not a supported type.");
1142 }
1143 IgnoreUnused(descriptor);
1144
1145 return supported;
1146}
1147
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001148bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1149 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001150 Optional<std::string&> reasonIfUnsupported) const
1151{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001152 bool supported = true;
1153
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001154 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001155 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001156 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001157 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001158 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001159 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001160 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001161 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001162 DataType::QSymmS16,
1163 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001164 };
1165
1166 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001167 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001168
1169 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001170 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001171
1172 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001173 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001174
1175 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001176}
1177
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001178bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1179 const TensorInfo& output,
1180 const DepthToSpaceDescriptor& descriptor,
1181 Optional<std::string&> reasonIfUnsupported) const
1182{
Jan Eilers8eb25602020-03-09 12:13:48 +00001183 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001184 bool supported = true;
1185
Sadik Armagan303980c2020-04-17 12:45:14 +01001186 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001187 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001188 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001189 DataType::Float32,
1190 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001191 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001192 DataType::QAsymmU8,
1193 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001194 };
1195
1196 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1197 "Reference DepthToSpace: input type not supported");
1198
1199 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1200 "Reference DepthToSpace: output type not supported");
1201
1202 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1203 "Reference DepthToSpace: input and output types are mismatched");
1204
1205 return supported;
1206}
1207
arovir011c7c81b2018-10-08 11:34:28 +01001208bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1209 const TensorInfo& output,
1210 const DepthwiseConvolution2dDescriptor& descriptor,
1211 const TensorInfo& weights,
1212 const Optional<TensorInfo>& biases,
1213 Optional<std::string&> reasonIfUnsupported) const
1214{
Sadik Armagan303980c2020-04-17 12:45:14 +01001215 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001216 bool supported = true;
1217
1218 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001219 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001220 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001221 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001222 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001223 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001224 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001225 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001226 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001227 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001228 };
1229
1230 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1231 "Reference DepthwiseConvolution2d: input is not a supported type.");
1232
1233 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1234 "Reference DepthwiseConvolution2d: output is not a supported type.");
1235
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001236 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1237 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1238
Teresa Charlind8df0262019-11-11 12:28:15 +00001239 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001240 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001241 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001242 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001243 {
1244 DataType::QAsymmS8,
1245 DataType::QAsymmU8,
1246 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001247 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001248
1249 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001250 "Reference DepthwiseConvolution2d: weights type not supported for "
1251 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001252 }
1253 else
1254 {
1255 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1256 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1257
1258 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1259 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1260 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001261
1262 if (biases.has_value())
1263 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001264 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001265 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001266 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001267 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001268 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001269 DataType::Signed32
1270 };
1271 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1272 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1273 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001274
1275 return supported;
1276
arovir011c7c81b2018-10-08 11:34:28 +01001277}
1278
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001279bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1280 const TensorInfo& output,
1281 Optional<std::string&> reasonIfUnsupported) const
1282{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001283 bool supported = true;
1284
Ryan OShea9add1202020-02-07 10:06:33 +00001285 std::array<DataType,4> supportedInputTypes = {
1286 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001287 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001288 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001289 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001290 };
1291
1292 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001293 "Reference for Dequantize layer: input type not supported.");
1294
Derek Lambertid466a542020-01-22 15:37:29 +00001295 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001296 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001297
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001298 std::array<DataType,3> supportedOutputTypes = {
1299 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +00001300 DataType::Float32,
1301 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001302 };
1303
1304 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001305 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001306
1307 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001308 "Reference for Dequantize layer: input/output shapes have different num total "
1309 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001310
1311 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001312}
1313
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001314bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1315 const TensorInfo& scores,
1316 const TensorInfo& anchors,
1317 const TensorInfo& detectionBoxes,
1318 const TensorInfo& detectionClasses,
1319 const TensorInfo& detectionScores,
1320 const TensorInfo& numDetections,
1321 const DetectionPostProcessDescriptor& descriptor,
1322 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001323{
Jan Eilers8eb25602020-03-09 12:13:48 +00001324 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001325
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001326 bool supported = true;
1327
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001328 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001329 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001330 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001331 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001332 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001333 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001334 DataType::QAsymmU8,
1335 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001336 };
1337
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001338 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001339 "Reference DetectionPostProcess: input 0 is not a supported type.");
1340
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001341 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001342 "Reference DetectionPostProcess: input 1 is not a supported type.");
1343
1344 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001345}
1346
// Support query for dilated depthwise convolution. The reference backend
// applies exactly the same type rules as regular depthwise convolution
// (the dilation parameters live in the shared descriptor), so this simply
// delegates to IsDepthwiseConvolutionSupported.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1356
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001357bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001358 const TensorInfo& input1,
1359 const TensorInfo& output,
1360 Optional<std::string&> reasonIfUnsupported) const
1361{
Sadik Armagan2999a022019-04-09 14:20:12 +01001362 bool supported = true;
1363
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001364 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001365 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001366 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001367 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001368 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001369 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001370 DataType::QSymmS16,
1371 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001372 };
1373
1374 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1375 "Reference division: input 0 is not a supported type.");
1376
1377 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1378 "Reference division: input 1 is not a supported type.");
1379
1380 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1381 "Reference division: output is not a supported type.");
1382
1383 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1384 "Reference division: input 0 and Input 1 types are mismatched");
1385
1386 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1387 "Reference division: input and output types are mismatched");
1388
1389 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1390 "Reference division: shapes are not suitable for implicit broadcast.");
1391
1392 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001393}
1394
josh minor4a3c6102020-01-06 16:40:46 -06001395bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1396 const TensorInfo& output,
1397 const ElementwiseUnaryDescriptor& descriptor,
1398 Optional<std::string&> reasonIfUnsupported) const
1399{
Jan Eilers8eb25602020-03-09 12:13:48 +00001400 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001401
Sadik Armagan303980c2020-04-17 12:45:14 +01001402 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001403 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001404 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06001405 DataType::Float32,
1406 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001407 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001408 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001409 DataType::QSymmS16,
1410 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001411 };
1412
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001413 std::array<DataType, 1> logicalSupportedTypes =
1414 {
1415 DataType::Boolean
1416 };
1417
josh minor4a3c6102020-01-06 16:40:46 -06001418 bool supported = true;
1419
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001420 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1421 {
1422 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1423 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001424
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001425 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1426 "Reference elementwise unary: output type not supported");
1427 }
1428 else
1429 {
1430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431 "Reference elementwise unary: input type not supported");
1432
1433 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1434 "Reference elementwise unary: output type not supported");
1435 }
josh minor4a3c6102020-01-06 16:40:46 -06001436
1437 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1438 "Reference elementwise unary: input and output types not matching");
1439
1440 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1441 "Reference elementwise unary: input and output shapes"
1442 "have different number of total elements");
1443
1444 return supported;
1445}
1446
arovir011c7c81b2018-10-08 11:34:28 +01001447bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1448 const FakeQuantizationDescriptor& descriptor,
1449 Optional<std::string&> reasonIfUnsupported) const
1450{
Jan Eilers8eb25602020-03-09 12:13:48 +00001451 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001452 bool supported = true;
1453
1454 std::array<DataType,1> supportedTypes =
1455 {
1456 DataType::Float32
1457 };
1458
1459 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1460 "Reference fake quantization: input type not supported.");
1461
1462 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001463}
1464
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001465bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1466 const TensorInfo& output,
1467 const FillDescriptor& descriptor,
1468 Optional<std::string&> reasonIfUnsupported) const
1469{
1470 IgnoreUnused(descriptor);
1471 IgnoreUnused(output);
1472
1473 bool supported = true;
1474
Sadik Armagana792a052020-06-23 16:22:23 +01001475 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001476 {
1477 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001478 DataType::Float16,
1479 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001480 };
1481
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001482 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001483 "Reference Fill: input type not supported.");
1484
Teresa Charlin44088502020-07-27 11:27:19 +01001485 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1486 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001487 return supported;
1488}
1489
arovir011c7c81b2018-10-08 11:34:28 +01001490bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1491 const TensorInfo& output,
1492 Optional<std::string&> reasonIfUnsupported) const
1493{
Jan Eilers8eb25602020-03-09 12:13:48 +00001494 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001495 bool supported = true;
1496
Francis Murtaghe8ac1332020-07-30 18:03:40 +01001497 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001498 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001499 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +01001500 DataType::Float32,
Francis Murtaghe8ac1332020-07-30 18:03:40 +01001501 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001502 };
1503
1504 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1505 "Reference Floor: input type not supported.");
1506
1507 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1508 "Reference Floor: output type not supported.");
1509
1510 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001511}
1512
1513bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1514 const TensorInfo& output,
1515 const TensorInfo& weights,
1516 const TensorInfo& biases,
1517 const FullyConnectedDescriptor& descriptor,
1518 Optional<std::string&> reasonIfUnsupported) const
1519{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001520 bool supported = true;
1521
1522 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001523 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001524 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001525 DataType::BFloat16,
1526 DataType::Float32,
1527 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001528 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001529 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001530 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001531 };
1532
1533 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1534 "Reference Fully Connected: input type not supported.");
1535
1536 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1537 "Reference Fully Connected: output type not supported.");
1538
Francis Murtagh46c09d02019-05-28 08:15:28 +01001539 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1540 "Reference Fully Connected: weights type not supported.");
1541
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001542 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1543 if (input.GetDataType() == DataType::BFloat16)
1544 {
1545 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1546 {
1547 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1548 supported = false;
1549 }
1550 }
1551 else
1552 {
1553 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1554 "Reference Fully Connected: input and output types mismatched.");
1555 }
1556
Jan Eilers1f45dc32020-06-15 11:43:03 +01001557 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1558 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001559
Jan Eilers1f45dc32020-06-15 11:43:03 +01001560 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1561 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001562
1563 if (descriptor.m_BiasEnabled)
1564 {
1565 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001566 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001567 supportedBiasTypes =
1568 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001569 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001570 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001571 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001572 DataType::Signed32,
1573 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001574 };
1575
1576 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1577 "Reference Fully Connected: bias type not supported.");
1578
1579 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1580 "Reference Fully Connected: bias and weight types mismatch.");
1581
1582 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1583 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1584
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001585 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1586 "Reference Fully Connected: bias must have 1 dimension.");
1587
Francis Murtagh46c09d02019-05-28 08:15:28 +01001588 }
1589
1590 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001591}
1592
narpra014951d842019-01-18 16:53:53 +00001593bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1594 const armnn::TensorInfo& input1,
1595 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001596 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001597 armnn::Optional<std::string&> reasonIfUnsupported) const
1598{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001599 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001600 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001601 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001602 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001603 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001604 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001605 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001606 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001607 DataType::QSymmS16,
1608 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001609 };
1610
Teresa Charlin52664732020-06-29 16:27:03 +01001611 if (descriptor.m_Axis != 0)
1612 {
1613 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1614 supported &= false;
1615 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001616 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1617 "Reference Gather: input type not supported");
1618
1619 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1620 "Reference Gather: output type not supported");
1621
1622 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1623 "Reference Gather: indices (input1) type not supported");
1624
1625 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1626 "Reference Gather: input and output types not matching");
1627
1628 return supported;
narpra014951d842019-01-18 16:53:53 +00001629}
1630
Derek Lamberti901ea112019-12-10 22:07:09 +00001631bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1632 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001633{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001634 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001635}
1636
Kevin May09ca49c2019-10-09 12:37:34 +01001637bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1638 const TensorInfo& output,
1639 const InstanceNormalizationDescriptor& descriptor,
1640 Optional<std::string&> reasonIfUnsupported) const
1641{
Jan Eilers8eb25602020-03-09 12:13:48 +00001642 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001643 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001644 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001645 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001646 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001647 DataType::Float32,
1648 DataType::Float16
1649 };
1650
1651 bool supported = true;
1652
1653 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1654 "Reference Instance Normalization: input type not supported.");
1655
1656 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1657 "Reference Instance Normalization: output type not supported.");
1658
1659 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1660 "Reference Instance Normalization: input and output types mismatched.");
1661
1662 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1663 "Reference Instance Normalization: input and output shapes have different "
1664 "num total elements.");
1665
1666 return supported;
1667}
1668
arovir011c7c81b2018-10-08 11:34:28 +01001669bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1670 const TensorInfo& output,
1671 const L2NormalizationDescriptor& descriptor,
1672 Optional<std::string&> reasonIfUnsupported) const
1673{
Jan Eilers8eb25602020-03-09 12:13:48 +00001674 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001675 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001676 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001677 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001678 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001679 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001680 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001681 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001682 DataType::QAsymmU8,
1683 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001684 };
1685
1686 bool supported = true;
1687
1688 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1689 "Reference L2normalization: input type not supported.");
1690
1691 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1692 "Reference L2normalization: output type not supported.");
1693
1694 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1695 "Reference L2normalization: input and output types mismatched.");
1696
1697 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1698 "Reference L2normalization: input and output shapes have different "
1699 "num total elements.");
1700
1701 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001702}
1703
James Conroyaba90cd2020-11-06 16:28:18 +00001704bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1705 const TensorInfo& input1,
1706 const TensorInfo& output,
1707 const LogicalBinaryDescriptor& descriptor,
1708 Optional<std::string&> reasonIfUnsupported) const
1709{
1710 IgnoreUnused(descriptor);
1711
1712 std::array<DataType, 1> supportedTypes =
1713 {
1714 DataType::Boolean
1715 };
1716
1717 bool supported = true;
1718 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1719 "Reference LogicalBinary: input 0 type not supported");
1720 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1721 "Reference LogicalBinary: input 1 type not supported");
1722
1723 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1724 "Reference LogicalBinary: input and output types do not match");
1725
1726 return supported;
1727}
1728
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001729bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1730 const TensorInfo& output,
1731 const LogSoftmaxDescriptor& descriptor,
1732 Optional<std::string&> reasonIfUnsupported) const
1733{
Jan Eilers8eb25602020-03-09 12:13:48 +00001734 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001735
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001736 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001737 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001738 DataType::BFloat16,
1739 DataType::Float32,
1740 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001741 };
1742
1743 bool supported = true;
1744 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1745 "Reference LogSoftmax: input type not supported");
1746
1747 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1748 "Reference LogSoftmax: output type not supported");
1749
1750 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1751 "Reference LogSoftmax: input and output types do not match");
1752
1753 return supported;
1754}
1755
arovir011c7c81b2018-10-08 11:34:28 +01001756bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1757 const TensorInfo& outputStateIn,
1758 const TensorInfo& cellStateIn,
1759 const TensorInfo& scratchBuffer,
1760 const TensorInfo& outputStateOut,
1761 const TensorInfo& cellStateOut,
1762 const TensorInfo& output,
1763 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001764 const LstmInputParamsInfo& paramsInfo,
1765 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001766{
Jan Eilers8eb25602020-03-09 12:13:48 +00001767 IgnoreUnused(descriptor);
1768 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001769
1770 bool supported = true;
1771
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001772 std::array<DataType,3> supportedTypes = {
1773 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001774 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001775 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001776 };
1777
Jan Eilersd01a83c2019-07-03 18:20:40 +01001778 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001779 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1780 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001781 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1782 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001783 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1784 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001785 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1786 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001787 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1788 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001789 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1790 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001791
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001792 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1793 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001794 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001795 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001796 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001797 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001798 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001799 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001800 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001801 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001802 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001803 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001804 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001805 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001806 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001808 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001809 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001810 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001812 "Reference Lstm: input and OutputGateBias types are mismatched");
1813 if (!descriptor.m_CifgEnabled)
1814 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001815 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001816 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001817 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001818 reasonIfUnsupported,
1819 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001820 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001821 "Reference Lstm: input and InputGateBias types are mismatched");
1822 if (descriptor.m_PeepholeEnabled)
1823 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001824 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001825 reasonIfUnsupported,
1826 "Reference Lstm: input and CellToInputWeights types are mismatched");
1827 }
1828 }
1829 if (descriptor.m_PeepholeEnabled)
1830 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001831 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001832 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001833 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001834 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1835 }
1836 if (descriptor.m_ProjectionEnabled)
1837 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001838 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001839 "Reference Lstm: input and mProjectionWeights types are mismatched");
1840 if (paramsInfo.m_ProjectionBias != nullptr)
1841 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001842 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001843 "Reference Lstm: input and ProjectionBias types are mismatched");
1844 }
1845 }
1846 if (descriptor.m_LayerNormEnabled)
1847 {
1848 if (!descriptor.m_CifgEnabled)
1849 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001850 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001851 reasonIfUnsupported,
1852 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1853 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001854 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001855 reasonIfUnsupported,
1856 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001857 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001858 reasonIfUnsupported,
1859 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001860 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001861 reasonIfUnsupported,
1862 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1863 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001864
1865 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001866}
1867
saoste012df12b32018-11-28 16:57:20 +00001868bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1869 const TensorInfo& input1,
1870 const TensorInfo& output,
1871 Optional<std::string&> reasonIfUnsupported) const
1872{
Sadik Armagan2999a022019-04-09 14:20:12 +01001873 bool supported = true;
1874
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001875 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001876 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001877 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001878 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001879 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001880 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001881 DataType::QSymmS16,
1882 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001883 };
1884
1885 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1886 "Reference maximum: input 0 is not a supported type.");
1887
1888 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1889 "Reference maximum: input 1 is not a supported type.");
1890
1891 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1892 "Reference maximum: output is not a supported type.");
1893
1894 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1895 "Reference maximum: input 0 and Input 1 types are mismatched");
1896
1897 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1898 "Reference maximum: input and output types are mismatched");
1899
1900 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1901 "Reference maximum: shapes are not suitable for implicit broadcast.");
1902
1903 return supported;
saoste012df12b32018-11-28 16:57:20 +00001904}
1905
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001906bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1907 const TensorInfo& output,
1908 const MeanDescriptor& descriptor,
1909 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001910{
James Conroy4d1ff582019-06-10 17:06:39 +01001911 bool supported = true;
1912 std::string meanLayerStr = "Mean";
1913 std::string outputTensorStr = "output";
1914
Sadik Armagan303980c2020-04-17 12:45:14 +01001915 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001916 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001917 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001918 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001919 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001920 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001921 DataType::QAsymmU8,
1922 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001923 };
1924
1925 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1926 "Reference Mean: input type not supported.");
1927
1928 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1929 "Reference Mean: input and output types are mismatched");
1930
1931 if (descriptor.m_KeepDims)
1932 {
1933 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1934 reasonIfUnsupported,
1935 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1936 output.GetNumDimensions(),
1937 meanLayerStr, outputTensorStr).data());
1938 }
1939 else if (descriptor.m_Axis.empty())
1940 {
1941 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1942 reasonIfUnsupported,
1943 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1944 meanLayerStr, outputTensorStr).data());
1945 }
1946 else
1947 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001948 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001949
1950 if (outputDim > 0)
1951 {
1952 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1953 reasonIfUnsupported,
1954 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1955 meanLayerStr, outputTensorStr).data());
1956 }
1957 else
1958 {
1959 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1960 reasonIfUnsupported,
1961 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1962 meanLayerStr, outputTensorStr).data());
1963 }
1964 }
1965
1966 return supported;
narpra0132b90462018-09-13 11:07:48 +01001967}
1968
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001969bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1970 const TensorInfo &output,
1971 Optional<std::string &> reasonIfUnsupported) const
1972{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001973 bool supported = true;
1974
Sadik Armagan303980c2020-04-17 12:45:14 +01001975 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001976 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001977 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001978 DataType::Float32,
1979 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001980 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001981 DataType::QAsymmU8,
1982 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001983 DataType::Boolean
1984 };
1985
1986 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1987 "Reference MemCopy: input type not supported");
1988
1989 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1990 "Reference MemCopy: output type not supported");
1991
1992 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1993 "Reference MemCopy: input and output types are mismatched");
1994
1995 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001996}
1997
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001998bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1999 const TensorInfo& input1,
2000 const TensorInfo& output,
2001 Optional<std::string&> reasonIfUnsupported) const
2002{
Sadik Armagan2999a022019-04-09 14:20:12 +01002003 bool supported = true;
2004
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002005 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002006 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002007 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002008 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002009 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002010 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002011 DataType::QSymmS16,
2012 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002013 };
2014
2015 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2016 "Reference minimum: input 0 is not a supported type.");
2017
2018 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2019 "Reference minimum: input 1 is not a supported type.");
2020
2021 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2022 "Reference minimum: output is not a supported type.");
2023
2024 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2025 "Reference minimum: input 0 and Input 1 types are mismatched");
2026
2027 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2028 "Reference minimum: input and output types are mismatched");
2029
2030 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2031 "Reference minimum: shapes are not suitable for implicit broadcast.");
2032
2033 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002034}
2035
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002036bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2037 const TensorInfo& input1,
2038 const TensorInfo& output,
2039 Optional<std::string&> reasonIfUnsupported) const
2040{
Sadik Armagan2999a022019-04-09 14:20:12 +01002041 bool supported = true;
2042
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002043 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002044 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002045 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002046 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002047 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002048 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002049 DataType::QSymmS16,
2050 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002051 };
2052
2053 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2054 "Reference multiplication: input 0 is not a supported type.");
2055
2056 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2057 "Reference multiplication: input 1 is not a supported type.");
2058
2059 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2060 "Reference multiplication: output is not a supported type.");
2061
2062 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2063 "Reference multiplication: input 0 and Input 1 types are mismatched");
2064
2065 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2066 "Reference multiplication: input and output types are mismatched");
2067
2068 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2069 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2070
2071 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002072}
2073
2074bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2075 const TensorInfo& output,
2076 const NormalizationDescriptor& descriptor,
2077 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002078{
Jan Eilers8eb25602020-03-09 12:13:48 +00002079 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002080
2081 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002082 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002083 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002084 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002085 DataType::Float16,
2086 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002087 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002088 DataType::QAsymmU8,
2089 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002090 };
2091
2092 bool supported = true;
2093
2094 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2095 "Reference normalization: input type not supported.");
2096
2097 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2098 "Reference normalization: output type not supported.");
2099
2100 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2101 "Reference normalization: input and output shapes have different "
2102 "num total elements.");
2103
2104 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002105}
2106
Derek Lamberti901ea112019-12-10 22:07:09 +00002107bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2108 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002109{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002110 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002111}
2112
2113bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2114 const TensorInfo& output,
2115 const PadDescriptor& descriptor,
2116 Optional<std::string&> reasonIfUnsupported) const
2117{
Jan Eilers8eb25602020-03-09 12:13:48 +00002118 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002119 bool supported = true;
2120
2121 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002122 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002123 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002124 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002125 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002126 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002127 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002128 DataType::QAsymmU8,
2129 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002130 };
2131
2132 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2133 "Reference pad: input is not a supported type.");
2134
2135 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2136 "Reference pad: output is not a supported type.");
2137
2138 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2139 "Reference pad: input and output types are mismatched.");
2140
2141 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002142}
2143
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002144bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2145 const TensorInfo& output,
2146 const PermuteDescriptor& descriptor,
2147 Optional<std::string&> reasonIfUnsupported) const
2148{
Jan Eilers8eb25602020-03-09 12:13:48 +00002149 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002150 bool supported = true;
2151
2152 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002153 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002154 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002155 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002156 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002157 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002158 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002159 DataType::QAsymmU8,
2160 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002161 };
2162
2163 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2164 "Reference permute: input is not a supported type.");
2165
2166 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2167 "Reference permute: output is not a supported type.");
2168
2169 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2170 "Reference permute: input and output types are mismatched.");
2171
2172 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002173}
2174
2175bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2176 const TensorInfo& output,
2177 const Pooling2dDescriptor& descriptor,
2178 Optional<std::string&> reasonIfUnsupported) const
2179{
Jan Eilers8eb25602020-03-09 12:13:48 +00002180 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002181 bool supported = true;
2182
2183 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002184 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002185 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002186 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01002187 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002188 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002189 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002190 DataType::QAsymmU8,
2191 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002192 };
2193
2194 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2195 "Reference poolind2d: input is not a supported type.");
2196
2197 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2198 "Reference poolind2d: output is not a supported type.");
2199
2200 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2201 "Reference poolind2d: input and output types are mismatched.");
2202
2203 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002204}
2205
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002206bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2207 const TensorInfo& output,
2208 const Pooling3dDescriptor& descriptor,
2209 Optional<std::string&> reasonIfUnsupported) const
2210{
2211 IgnoreUnused(descriptor);
2212 bool supported = true;
2213
2214 // Define supported output and inputs types.
2215 std::array<DataType,6> supportedTypes =
2216 {
2217 DataType::BFloat16,
2218 DataType::Float32,
2219 DataType::Float16,
2220 DataType::QAsymmS8,
2221 DataType::QAsymmU8,
2222 DataType::QSymmS16
2223 };
2224
2225 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2226 "Reference poolind3d: input is not a supported type.");
2227
2228 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2229 "Reference poolind3d: output is not a supported type.");
2230
2231 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2232 "Reference poolind3d: input and output types are mismatched.");
2233
2234 return supported;
2235}
2236
2237
James Conroy4f1f8992020-04-29 20:01:10 +01002238bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2239 const TensorInfo& previousOutputIn,
2240 const TensorInfo& previousCellStateIn,
2241 const TensorInfo& outputStateOut,
2242 const TensorInfo& cellStateOut,
2243 const TensorInfo& output,
2244 const QLstmDescriptor& descriptor,
2245 const LstmInputParamsInfo& paramsInfo,
2246 Optional<std::string&> reasonIfUnsupported) const
2247{
2248 IgnoreUnused(input);
2249 IgnoreUnused(previousOutputIn);
2250 IgnoreUnused(previousCellStateIn);
2251 IgnoreUnused(outputStateOut);
2252 IgnoreUnused(cellStateOut);
2253 IgnoreUnused(output);
2254 IgnoreUnused(descriptor);
2255 IgnoreUnused(paramsInfo);
2256
2257 IgnoreUnused(reasonIfUnsupported);
2258
2259 return true;
2260}
2261
Derek Lamberti5f400d62019-03-25 15:41:58 +00002262bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2263 const TensorInfo& output,
2264 Optional<std::string&> reasonIfUnsupported) const
2265{
2266 bool supported = true;
2267
Finn Williamsfd271062019-12-04 14:27:27 +00002268 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002269 std::array<DataType,7> supportedInputTypes = {
2270 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00002271 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002272 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002273 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002274 DataType::QAsymmU8,
2275 DataType::QSymmS8,
2276 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002277 };
2278
2279 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2280 "Reference quantize: input type not supported.");
2281
2282 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002283 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002284 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002285 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002286 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002287 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002288 };
2289 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2290 "Reference quantize: output type not supported.");
2291
2292 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2293 "Reference quantize: input and output shapes have different num total elements.");
2294
2295 return supported;
2296}
2297
Finn Williams2605b232020-06-10 15:53:46 +01002298bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2299 const TensorInfo& output,
2300 Optional<std::string&> reasonIfUnsupported) const
2301{
2302 IgnoreUnused(input);
2303 // Define supported output types.
2304 std::array<DataType,1> supportedOutputTypes =
2305 {
2306 DataType::Signed32,
2307 };
2308
2309 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2310 "Reference rank: input type not supported.");
2311}
2312
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002313bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2314 const TensorInfo& output,
2315 const ReduceDescriptor& descriptor,
2316 Optional<std::string&> reasonIfUnsupported) const
2317{
2318 IgnoreUnused(descriptor);
2319 bool supported = true;
2320 std::array<DataType,7> supportedTypes =
2321 {
2322 DataType::BFloat16,
2323 DataType::Float32,
2324 DataType::Float16,
2325 DataType::QAsymmS8,
2326 DataType::QAsymmU8,
2327 DataType::QSymmS16,
2328 DataType::Signed32
2329 };
2330
2331 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2332 "Reference Reduce: input type not supported");
2333
2334 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2335 "Reference Reduce: output type not supported");
2336
2337 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2338 "Reference Reduce: input and output types not matching");
2339
2340 return supported;
2341}
2342
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002343bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002344 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002345 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002346 Optional<std::string&> reasonIfUnsupported) const
2347{
Jan Eilers8eb25602020-03-09 12:13:48 +00002348 IgnoreUnused(output);
2349 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002350 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002351 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002352 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002353 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002354 DataType::Float32,
2355 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002356 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002357 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002358 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002359 DataType::QSymmS16,
2360 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002361 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002362
Nina Drozd2f2778f2019-05-27 10:37:05 +01002363 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2364 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002365}
2366
Teresa Charlin970f43b2019-07-01 13:51:07 +01002367bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2368 const TensorInfo& output,
2369 const ResizeDescriptor& descriptor,
2370 Optional<std::string&> reasonIfUnsupported) const
2371{
Jan Eilers8eb25602020-03-09 12:13:48 +00002372 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002373 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002374 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002375 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002376 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002377 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002378 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002379 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002380 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002381 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002382 };
2383
2384 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2385 "Reference Resize: input type not supported");
2386
2387 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2388 "Reference Resize: output type not supported");
2389
2390 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2391 "Reference Resize: input and output types not matching");
2392
2393 return supported;
2394}
2395
Keith Davis3ae3f972021-05-21 16:33:48 +01002396bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2397 const TensorInfo& output,
2398 Optional<std::string&> reasonIfUnsupported) const
2399{
2400 IgnoreUnused(input);
2401 bool supported = true;
2402
2403 std::array<DataType, 1> supportedTypes =
2404 {
2405 DataType::Signed32
2406 };
2407
2408 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2409 "Reference Shape: output type not supported");
2410
2411 return supported;
2412}
2413
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002414bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2415 const TensorInfo& output,
2416 const SliceDescriptor& descriptor,
2417 Optional<std::string&> reasonIfUnsupported) const
2418{
Jan Eilers8eb25602020-03-09 12:13:48 +00002419 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002420 bool supported = true;
2421
Sadik Armagan303980c2020-04-17 12:45:14 +01002422 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002423 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002424 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002425 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002426 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002427 DataType::QAsymmU8,
2428 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002429 };
2430
2431 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2432 "Reference Slice: input type not supported");
2433
2434 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2435 "Reference Slice: output type not supported");
2436
2437 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2438 "Reference Slice: input and output types are mismatched");
2439
2440 return supported;
2441}
2442
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002443bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2444 const TensorInfo& output,
2445 const SoftmaxDescriptor& descriptor,
2446 Optional<std::string&> reasonIfUnsupported) const
2447{
Jan Eilers8eb25602020-03-09 12:13:48 +00002448 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002449 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002450 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002451 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002452 DataType::BFloat16,
2453 DataType::Float32,
2454 DataType::Float16,
2455 DataType::QSymmS8,
2456 DataType::QAsymmS8,
2457 DataType::QAsymmU8,
2458 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002459 };
2460
2461 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002462 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002463
2464 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002465 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002466
2467 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002468 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002469
2470 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002471}
2472
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002473bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2474 const TensorInfo& output,
2475 const SpaceToBatchNdDescriptor& descriptor,
2476 Optional<std::string&> reasonIfUnsupported) const
2477{
Jan Eilers8eb25602020-03-09 12:13:48 +00002478 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002479 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002480 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002481 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002482 DataType::BFloat16,
2483 DataType::Float32,
2484 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002485 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002486 DataType::QAsymmU8,
2487 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002488 };
2489
2490 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2491 "Reference SpaceToBatchNd: input type not supported");
2492
2493 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2494 "Reference SpaceToBatchNd: output type not supported");
2495
2496 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2497 "Reference SpaceToBatchNd: input and output types are mismatched");
2498
2499 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002500}
2501
Keith Davisa57eccb2019-06-14 17:33:22 +01002502bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002503 const TensorInfo& output,
2504 const SpaceToDepthDescriptor& descriptor,
2505 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002506{
2507
Jan Eilers8eb25602020-03-09 12:13:48 +00002508 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002509 bool supported = true;
2510
Sadik Armagan303980c2020-04-17 12:45:14 +01002511 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002512 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002513 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002514 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002515 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002516 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002517 DataType::QAsymmU8,
2518 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002519 };
2520
2521 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2522 "Reference SpaceToDepth: input type not supported");
2523
2524 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2525 "Reference SpaceToDepth: output type not supported");
2526
2527 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2528 "Reference SpaceToDepth: input and output types are mismatched");
2529
2530 return supported;
2531}
2532
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002533bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002534 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2535 const ViewsDescriptor& descriptor,
2536 Optional<std::string&> reasonIfUnsupported) const
2537{
Jan Eilers8eb25602020-03-09 12:13:48 +00002538 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002539 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002540 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002541 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002542 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002543 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002544 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002545 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002546 DataType::QAsymmU8,
2547 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002548 };
2549
2550 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2551 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002552 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002553 {
2554 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2555 "Reference splitter: input type not supported");
2556
2557 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2558 "Reference splitter: input and output types mismatched.");
2559 }
2560
2561 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002562}
2563
Matthew Jackson81e601c2019-07-11 12:07:09 +01002564bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2565 const TensorInfo& output,
2566 const StackDescriptor& descriptor,
2567 Optional<std::string&> reasonIfUnsupported) const
2568{
Jan Eilers8eb25602020-03-09 12:13:48 +00002569 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002570
2571 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002572 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002573 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002574 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002575 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002576 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002577 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002578 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002579 DataType::QSymmS16,
2580 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002581 };
2582
2583 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2584 "Reference stack: output type not supported");
2585 for (const TensorInfo* input : inputs)
2586 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002587 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002588 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2589 "Reference stack: input type not supported");
2590
2591 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2592 "Reference stack: input and output types mismatched.");
2593 }
2594
2595 return supported;
2596}
2597
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002598bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2599 const TensorInfo& output,
2600 const StridedSliceDescriptor& descriptor,
2601 Optional<std::string&> reasonIfUnsupported) const
2602{
Jan Eilers8eb25602020-03-09 12:13:48 +00002603 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002604 bool supported = true;
2605
Sadik Armagan303980c2020-04-17 12:45:14 +01002606 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002607 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002608 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002609 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002610 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002611 DataType::QAsymmU8,
2612 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002613 };
2614
2615 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2616 "Reference StridedSlice: input type not supported");
2617
2618 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2619 "Reference StridedSlice: output type not supported");
2620
2621 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2622 "Reference StridedSlice: input and output types are mismatched");
2623
2624 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002625}
2626
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002627bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2628 const TensorInfo& input1,
2629 const TensorInfo& output,
2630 Optional<std::string&> reasonIfUnsupported) const
2631{
Sadik Armagan2999a022019-04-09 14:20:12 +01002632 bool supported = true;
2633
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002634 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002635 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002636 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002637 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002638 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002639 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002640 DataType::QSymmS16,
2641 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002642 };
2643
2644 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2645 "Reference subtraction: input 0 is not a supported type.");
2646
2647 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2648 "Reference subtraction: input 1 is not a supported type.");
2649
2650 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2651 "Reference subtraction: output is not a supported type.");
2652
2653 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2654 "Reference subtraction: input 0 and Input 1 types are mismatched");
2655
2656 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2657 "Reference subtraction: input and output types are mismatched");
2658
2659 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2660 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2661
2662 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002663}
2664
Matteo Martincighab9e5252019-06-13 17:27:46 +01002665bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2666 const TensorInfo& alpha,
2667 const TensorInfo& output,
2668 Optional<std::string&> reasonIfUnsupported) const
2669{
2670 bool supported = true;
2671
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002672 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002673 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002674 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002675 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002676 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002677 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002678 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002679 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002680 };
2681
2682 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2683 "PReLU: input is not a supported type.");
2684
2685 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2686 "PReLU: alpha is not a supported type.");
2687
2688 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2689 "PReLU: output is not a supported type.");
2690
2691 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2692 "PReLU: input, alpha and output types are mismatched");
2693
2694 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2695 "PReLU: shapes are not suitable for implicit broadcast");
2696
2697 return supported;
2698}
2699
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002700bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2701 const TensorInfo& output,
2702 const TransposeConvolution2dDescriptor& descriptor,
2703 const TensorInfo& weights,
2704 const Optional<TensorInfo>& biases,
2705 Optional<std::string&> reasonIfUnsupported) const
2706{
Jan Eilers8eb25602020-03-09 12:13:48 +00002707 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002708 bool supported = true;
2709
Sadik Armagan303980c2020-04-17 12:45:14 +01002710 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002711 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002712 DataType::BFloat16,
2713 DataType::Float32,
2714 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002715 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002716 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002717 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002718 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002719 };
2720
2721 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2722 "Reference TransposeConvolution2d: input is not a supported type.");
2723
2724 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2725 "Reference TransposeConvolution2d: output is not a supported type.");
2726
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002727 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2728 "Reference TransposeConvolution2d: input and output types mismatched.");
2729
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002730
2731 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002732 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002733 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002734 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002735 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002736 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002737 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002738 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002739 };
2740
2741 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2742 "Reference TransposeConvolution2d: weights type not supported for "
2743 "quantized input.");
2744 }
2745 else
2746 {
2747 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2748 "Reference TransposeConvolution2d: weights is not a supported type.");
2749
2750 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2751 "Reference TransposeConvolution2d: input and weights types mismatched.");
2752 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002753
2754 if (biases.has_value())
2755 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002756 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002757 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002758 DataType::BFloat16,
2759 DataType::Float32,
2760 DataType::Float16,
2761 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002762 };
2763 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2764 "Reference TransposeConvolution2d: biases is not a supported type.");
2765 }
2766
2767 return supported;
2768}
2769
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002770bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2771 const TensorInfo& output,
2772 const TransposeDescriptor& descriptor,
2773 Optional<std::string&> reasonIfUnsupported) const
2774{
Jan Eilers8eb25602020-03-09 12:13:48 +00002775 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002776 bool supported = true;
2777
2778 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002779 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002780 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002781 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002782 DataType::Float32,
2783 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002784 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002785 DataType::QAsymmU8,
2786 DataType::QSymmS16
2787 };
2788
2789 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2790 "Reference transpose: input is not a supported type.");
2791
2792 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2793 "Reference transpose: output is not a supported type.");
2794
2795 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2796 "Reference transpose: input and output types are mismatched.");
2797
2798 return supported;
2799}
2800
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002801bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2802 const TensorInfo& input,
2803 const TensorInfo& outputStateIn,
2804 const TensorInfo& cellStateIn,
2805 const TensorInfo& output,
2806 const Optional<TensorInfo>& hiddenStateOutput,
2807 const Optional<TensorInfo>& cellStateOutput,
2808 const UnidirectionalSequenceLstmDescriptor& descriptor,
2809 const LstmInputParamsInfo& paramsInfo,
2810 Optional<std::string&> reasonIfUnsupported) const
2811{
2812 IgnoreUnused(descriptor);
2813 IgnoreUnused(paramsInfo);
2814 IgnoreUnused(outputStateIn);
2815 IgnoreUnused(cellStateIn);
2816 bool supported = true;
2817
2818 if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
2819 {
2820 reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
2821 "and cell state output are not supported at the moment.";
2822 }
2823
2824 std::array<DataType, 1> supportedTypes =
2825 {
2826 DataType::Float32
2827 };
2828
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002829 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002830 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002831 DataType::Float32,
2832 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002833 };
2834
2835 // check inputs and outputs
2836 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2837 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2838 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
2839 "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
2840 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
2841 "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
2842
2843 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2844 "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
2845 // check layer parameters
2846 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2847 reasonIfUnsupported,
2848 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2849 "is not a supported type.");
2850 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2851 reasonIfUnsupported,
2852 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2853 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2854 reasonIfUnsupported,
2855 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2856 "is not a supported type.");
2857 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2858 reasonIfUnsupported,
2859 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2860 "is not a supported type.");
2861 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2862 reasonIfUnsupported,
2863 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2864 "is not a supported type.");
2865 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2866 reasonIfUnsupported,
2867 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2868 "is not a supported type.");
2869 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
2870 "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
2871 "are mismatched");
2872 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
2873 "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
2874 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
2875 "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
2876 "are mismatched");
2877 if (!descriptor.m_CifgEnabled)
2878 {
2879 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2880 reasonIfUnsupported,
2881 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2882 "is not a supported type.");
2883 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2884 reasonIfUnsupported,
2885 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2886 "is not a supported type.");
2887 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
2888 "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
2889 "are mismatched");
2890 if (descriptor.m_PeepholeEnabled)
2891 {
2892 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2893 reasonIfUnsupported,
2894 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2895 "is not a supported type.");
2896 }
2897 }
2898 if (descriptor.m_PeepholeEnabled)
2899 {
2900 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2901 reasonIfUnsupported,
2902 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2903 "is not a supported type.");
2904 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2905 reasonIfUnsupported,
2906 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2907 "is not a supported type.");
2908 }
2909 if (descriptor.m_ProjectionEnabled)
2910 {
2911 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2912 reasonIfUnsupported,
2913 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2914 "is not a supported type.");
2915 if (paramsInfo.m_ProjectionBias != nullptr)
2916 {
2917 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2918 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2919 "are mismatched");
2920 }
2921 }
2922 if (descriptor.m_LayerNormEnabled)
2923 {
2924 if (!descriptor.m_CifgEnabled)
2925 {
2926 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2927 reasonIfUnsupported,
2928 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2929 "is not a supported type.");
2930 }
2931 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2932 reasonIfUnsupported,
2933 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2934 "is not a supported type.");
2935 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2936 reasonIfUnsupported,
2937 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2938 "is not a supported type.");
2939 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2940 reasonIfUnsupported,
2941 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2942 "is not a supported type.");
2943 }
2944
2945 return supported;
2946}
2947
arovir011c7c81b2018-10-08 11:34:28 +01002948} // namespace armnn