blob: b55adfa958998a99e885a77762e30eec26b704a0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
namespace
{

// Adapts a per-data-type support query to IsSupportedForDataTypeGeneric by
// supplying explicit predicates for Float32 and QAsymmU8 and rejecting every
// other data type via FalseFunc.
//
// NOTE(review): the exact meaning of each FalseFunc slot (which DataType it
// rejects) is defined by IsSupportedForDataTypeGeneric's parameter order in
// LayerSupportCommon.hpp — confirm there before reordering these arguments.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
44
namespace
{

// Builds the diagnostic emitted when a tensor has the wrong rank, e.g.
// "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions
//  instead, for the 'input' tensor."
//
// The string parameters are now taken by const reference: the function never
// mutates them, and the previous non-const std::string& signature both
// mis-stated intent and rejected temporaries. Existing call sites (which pass
// named std::string lvalues) are unaffected.
//
// @param expected   Rank the layer requires.
// @param actual     Rank the tensor actually has.
// @param layerStr   Human-readable layer name inserted into the message.
// @param tensorName Name of the offending tensor ("input", "output", ...).
// @return The composed error message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
/// Single entry point for querying reference-backend support of any layer.
///
/// @param type       Layer kind being queried; selects the switch case below.
/// @param infos      Flattened tensor infos. Ordering and count depend on the
///                   layer type (e.g. Convolution2d expects
///                   {input, output, weights, biases}); each case indexes the
///                   layout it requires and some cases throw
///                   InvalidArgumentException on a wrong count.
/// @param descriptor Base-class reference, downcast per case to the matching
///                   concrete descriptor type.
/// @param lstmParamsInfo  Must hold a value for Lstm, QLstm and
///                        UnidirectionalSequenceLstm (value() is called).
/// @param quantizedLstmInputParamsInfo  Must hold a value for QuantizedLstm.
/// @param reasonIfUnsupported Optional out-string describing why a
///                            configuration is rejected.
/// @return true if the reference backend supports the configuration.
bool RefLayerSupport::IsLayerSupported(const LayerType& type,
                                       const std::vector<TensorInfo>& infos,
                                       const BaseDescriptor& descriptor,
                                       const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                       const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    switch (type)
    {
        case LayerType::Activation:
            return IsActivationSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Addition:
            return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ArgMinMax:
            return IsArgMinMaxSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::BatchNormalization:
            return IsBatchNormalizationSupported(infos[0],
                                                 infos[1],
                                                 infos[2],
                                                 infos[3],
                                                 infos[4],
                                                 infos[5],
                                                 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
                                                     (&descriptor)),
                                                 reasonIfUnsupported);
        case LayerType::BatchToSpaceNd:
            return IsBatchToSpaceNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Comparison:
            return IsComparisonSupported(infos[0],
                                         infos[1],
                                         infos[2],
                                         *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Concat:
        {
            // All infos except the last are inputs; the last one is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < (infos.size() - 1); i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsConcatSupported(inputInfos,
                                     infos[infos.size() - 1],
                                     *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        }
        case LayerType::Constant:
            return IsConstantSupported(infos[0], reasonIfUnsupported);
        case LayerType::ConvertBf16ToFp32:
            return IsConvertBf16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp16ToFp32:
            return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToBf16:
            return IsConvertFp32ToBf16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToFp16:
            return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Convolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the biases slot marks the
            // (optional) bias as absent.
            if (infos[3] == TensorInfo())
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::DepthToSpace:
            return IsDepthToSpaceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::DepthwiseConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
            // Default-constructed biases info => bias is absent (optional).
            if (infos[3] == TensorInfo())
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       EmptyOptional(),
                                                       reasonIfUnsupported);
            }
            else
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       infos[3],
                                                       reasonIfUnsupported);
            }
        }
        case LayerType::Dequantize:
            return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Division:
            return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ElementwiseUnary:
            return IsElementwiseUnarySupported(infos[0],
                                               infos[1],
                                               *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::Fill:
            return IsFillSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Floor:
            return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::FullyConnected:
            return IsFullyConnectedSupported(infos[0],
                                             infos[1],
                                             infos[2],
                                             infos[3],
                                             *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Gather:
            return IsGatherSupported(infos[0],
                                     infos[1],
                                     infos[2],
                                     *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Input:
            return IsInputSupported(infos[0], reasonIfUnsupported);
        case LayerType::InstanceNormalization:
            return IsInstanceNormalizationSupported(infos[0],
                                                    infos[1],
                                                    *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
                                                        (&descriptor)),
                                                    reasonIfUnsupported);
        case LayerType::L2Normalization:
            return IsL2NormalizationSupported(infos[0],
                                              infos[1],
                                              *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
                                              reasonIfUnsupported);
        case LayerType::LogicalBinary:
            return IsLogicalBinarySupported(infos[0],
                                            infos[1],
                                            infos[2],
                                            *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::LogSoftmax:
            return IsLogSoftmaxSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Lstm:
            // Requires lstmParamsInfo to hold a value.
            return IsLstmSupported(infos[0],
                                   infos[1],
                                   infos[2],
                                   infos[3],
                                   infos[4],
                                   infos[5],
                                   infos[6],
                                   *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
                                   lstmParamsInfo.value(),
                                   reasonIfUnsupported);
        case LayerType::QLstm:
            // Requires lstmParamsInfo to hold a value.
            return IsQLstmSupported(infos[0],
                                    infos[1],
                                    infos[2],
                                    infos[3],
                                    infos[4],
                                    infos[5],
                                    *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
                                    lstmParamsInfo.value(),
                                    reasonIfUnsupported);
        case LayerType::Maximum:
            return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Mean:
            return IsMeanSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Minimum:
            return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Multiplication:
            return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Normalization:
            return IsNormalizationSupported(infos[0],
                                            infos[1],
                                            *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::Output:
            return IsOutputSupported(infos[0], reasonIfUnsupported);
        case LayerType::Pad:
            return IsPadSupported(infos[0],
                                  infos[1],
                                  *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
                                  reasonIfUnsupported);
        case LayerType::Permute:
            return IsPermuteSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Pooling2d:
            return IsPooling2dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Prelu:
            return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Quantize:
            return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Reshape:
            return IsReshapeSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Resize:
            return IsResizeSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Reduce:
            return IsReduceSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Slice:
            return IsSliceSupported(infos[0],
                                    infos[1],
                                    *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        case LayerType::Softmax:
            return IsSoftmaxSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::SpaceToBatchNd:
            return IsSpaceToBatchNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::SpaceToDepth:
            return IsSpaceToDepthSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Splitter:
        {
            // infos[0] is the input; everything after it is an output view.
            std::vector<TensorInfo> outputInfos;
            for (uint32_t i = 1; i < infos.size(); i++)
            {
                outputInfos.push_back(infos[i]);
            }
            return IsSplitterSupported(infos[0],
                                       {outputInfos.begin(), outputInfos.end()},
                                       *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
                                       reasonIfUnsupported);
        }
        case LayerType::Stack:
        {
            // All infos except the last are inputs; the last one is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < infos.size() - 1; i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsStackSupported(inputInfos,
                                    infos[infos.size() - 1],
                                    *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        }
        case LayerType::StridedSlice:
            return IsStridedSliceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Subtraction:
            return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Transpose:
            return IsTransposeSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::TransposeConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
            // Default-constructed biases info => bias is absent (optional).
            if (infos[3] == TensorInfo())
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         EmptyOptional(),
                                                         reasonIfUnsupported);
            }
            else
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         infos[3],
                                                         reasonIfUnsupported);
            }
        }
        case LayerType::Cast:
            return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ChannelShuffle:
            return IsChannelShuffleSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Convolution3d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
            // Default-constructed biases info => bias is absent (optional).
            if (infos[3] == TensorInfo())
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::Debug:
            return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::DetectionPostProcess:
            return IsDetectionPostProcessSupported(infos[0],
                                                   infos[1],
                                                   infos[2],
                                                   infos[3],
                                                   infos[4],
                                                   infos[5],
                                                   infos[6],
                                                   *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
                                                       (&descriptor)),
                                                   reasonIfUnsupported);
        case LayerType::FakeQuantization:
            return IsFakeQuantizationSupported(infos[0],
                                               *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::MemCopy:
            return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Rank:
            return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Shape:
            return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::UnidirectionalSequenceLstm:
        {
            if (infos.size() != 6)
            {
                throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
                                               "should be of format: {input, outputStateIn, cellStateIn, "
                                               "hiddenStateOutputVal, cellStateOutputVal, output}");
            }
            auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));

            // A default-constructed TensorInfo marks an optional output as
            // absent; the four branches below map each absent slot to
            // EmptyOptional().
            bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
            // NOTE(review): despite the name, 'isCellStateOutput' is true when
            // the cell-state output is ABSENT (mirrors the hidden-state flag).
            bool isCellStateOutput = (infos[5] == TensorInfo());
            if (isHiddenStateOutputOptional && isCellStateOutput)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             EmptyOptional(),
                                                             EmptyOptional(),
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else if (isHiddenStateOutputOptional)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             EmptyOptional(),
                                                             infos[5],
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else if (isCellStateOutput)
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             infos[4],
                                                             EmptyOptional(),
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
            else
            {
                return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                             infos[1],
                                                             infos[2],
                                                             infos[3],
                                                             infos[4],
                                                             infos[5],
                                                             desc,
                                                             lstmParamsInfo.value(),
                                                             reasonIfUnsupported);
            }
        }
        case LayerType::Pooling3d:
            return IsPooling3dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Map:
            return true;
        case LayerType::Unmap:
            return true;
        case LayerType::MemImport:
            return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Merge:
            return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::QuantizedLstm:
            // Requires quantizedLstmInputParamsInfo to hold a value.
            return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
                                                              infos[1],
                                                              infos[2],
                                                              infos[3],
                                                              infos[4],
                                                              quantizedLstmInputParamsInfo.value(),
                                                              reasonIfUnsupported);
        default:
            // layers not supported in the reference backend by default:
            // precompiled, standin, switch
            return false;
    }
}
542
arovir011c7c81b2018-10-08 11:34:28 +0100543bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
544 const TensorInfo& output,
545 const ActivationDescriptor& descriptor,
546 Optional<std::string&> reasonIfUnsupported) const
547{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000548 bool supported = true;
549
550 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000551 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000552 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000553 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100554 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000555 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000556 DataType::QAsymmU8,
557 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000558 };
559
560 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
561 "Reference activation: input type not supported.");
562
563 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
564 "Reference activation: output type not supported.");
565
566 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
567 "Reference activation: input and output types mismatched.");
568
569 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
570 "Reference activation: input and output shapes are of different rank.");
571
572
573 struct ActivationFunctionSupported : public Rule
574 {
575 ActivationFunctionSupported(const ActivationDescriptor& desc)
576 {
577 switch(desc.m_Function)
578 {
579 case ActivationFunction::Abs:
580 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000581 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000582 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000583 case ActivationFunction::LeakyReLu:
584 case ActivationFunction::Linear:
585 case ActivationFunction::ReLu:
586 case ActivationFunction::Sigmoid:
587 case ActivationFunction::SoftReLu:
588 case ActivationFunction::Sqrt:
589 case ActivationFunction::Square:
590 case ActivationFunction::TanH:
591 {
592 m_Res = true;
593 break;
594 }
595 default:
596 {
597 m_Res = false;
598 break;
599 }
600 }
601 }
602 };
603
604 // Function is supported
605 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
606 "Reference activation: function not supported.");
607
608 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100609}
610
611bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
612 const TensorInfo& input1,
613 const TensorInfo& output,
614 Optional<std::string&> reasonIfUnsupported) const
615{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000616 bool supported = true;
617
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100618 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000619 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000620 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100621 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000623 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100624 DataType::QSymmS16,
625 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000626 };
627
628 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
629 "Reference addition: input 0 is not a supported type.");
630
631 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
632 "Reference addition: input 1 is not a supported type.");
633
634 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
635 "Reference addition: output is not a supported type.");
636
637 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
638 "Reference addition: input 0 and Input 1 types are mismatched");
639
640 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
641 "Reference addition: input and output types are mismatched");
642
643 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
644 "Reference addition: shapes are not suitable for implicit broadcast.");
645
646 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100647}
648
Nikhil Raj68c2c902019-09-19 11:21:11 +0100649bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
650 const armnn::ArgMinMaxDescriptor &descriptor,
651 armnn::Optional<std::string &> reasonIfUnsupported) const
652{
Jan Eilers8eb25602020-03-09 12:13:48 +0000653 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100654
Mike Kelly1f140f72021-04-06 12:25:55 +0100655 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100656 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000657 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100658 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100659 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100660 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000661 DataType::QAsymmU8,
662 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100663 DataType::Signed32,
664 DataType::Signed64
665 };
666
667 std::array<DataType,2> supportedOutputTypes = {
668 DataType::Signed32,
669 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100670 };
671
672 bool supported = true;
673
Mike Kelly1f140f72021-04-06 12:25:55 +0100674 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100675 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100676 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100677 "Reference ArgMinMax: output type not supported");
678
679 return supported;
680}
681
arovir011c7c81b2018-10-08 11:34:28 +0100682bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
683 const TensorInfo& output,
684 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100685 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100686 const TensorInfo& beta,
687 const TensorInfo& gamma,
688 const BatchNormalizationDescriptor& descriptor,
689 Optional<std::string&> reasonIfUnsupported) const
690{
Jan Eilers8eb25602020-03-09 12:13:48 +0000691 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100692
Sadik Armagan303980c2020-04-17 12:45:14 +0100693 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100694 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000695 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100696 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100697 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100698 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000699 DataType::QAsymmU8,
700 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100701 };
702
703 bool supported = true;
704
705 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
706 "Reference batch normalization: input is not a supported type.");
707
708 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
709 "Reference batch normalization: output is not a supported type.");
710
711 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
712 "Reference batch normalization: input and output types are mismatched");
713
714 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
715 "Reference batch normalization: mean is not a supported type.");
716
717 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
718 "Reference batch normalization: variance is not a supported type.");
719
720 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
721 "Reference batch normalization: beta is not a supported type.");
722
723 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
724 "Reference batch normalization: gamma is not a supported type.");
725
726 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100727}
728
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000729bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
730 const TensorInfo& output,
731 const BatchToSpaceNdDescriptor& descriptor,
732 Optional<std::string&> reasonIfUnsupported) const
733{
Jan Eilers8eb25602020-03-09 12:13:48 +0000734 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100735
736 bool supported = true;
737
738 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
739 std::string inputTensorStr = "input";
740 std::string outputTensorStr = "output";
741
742 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100743 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100744 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000745 DataType::BFloat16,
746 DataType::Float32,
747 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100748 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000749 DataType::QAsymmU8,
750 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100751 };
752
753 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
754 "Reference BatchToSpaceNd: input type not supported.");
755
756 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
757 "Reference BatchToSpaceNd: output type not supported.");
758
759 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
760 "Reference BatchToSpaceNd: input and output types mismatched.");
761
762 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
763 reasonIfUnsupported,
764 CreateIncorrectDimensionsErrorMsg(4,
765 output.GetNumDimensions(),
766 batchToSpaceNdLayerStr,
767 outputTensorStr).data());
768
769 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
770 reasonIfUnsupported,
771 CreateIncorrectDimensionsErrorMsg(4,
772 input.GetNumDimensions(),
773 batchToSpaceNdLayerStr,
774 inputTensorStr).data());
775
776 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000777}
778
mathad01b392e982021-04-07 12:07:30 +0100779bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
780 const TensorInfo& output,
781 Optional<std::string&> reasonIfUnsupported) const
782{
783 std::array<DataType, 9> supportedInputTypes =
784 {
785 DataType::BFloat16,
786 DataType::Float32,
787 DataType::Float16,
788 DataType::QSymmS8,
789 DataType::QAsymmS8,
790 DataType::QAsymmU8,
791 DataType::QSymmS16,
792 DataType::Signed32
793 };
794
795 bool supported = true;
796 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
797 "Reference cast: input is not a supported type");
798
799
800 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
801 "Reference cast: output is not a supported type");
802
803 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
804 "Reference cast: input and output shapes have different number of total elements");
805
806 return supported;
807}
808
Simon Obute51f67772021-09-03 15:50:13 +0100809bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
810 const TensorInfo& output,
811 const ChannelShuffleDescriptor& descriptor,
812 Optional<std::string&> reasonIfUnsupported) const
813{
814 IgnoreUnused(descriptor);
815 bool supported = true;
816
817 // Define supported output and inputs types.
818 std::array<DataType, 7> supportedTypes =
819 {
820 DataType::BFloat16,
821 DataType::Float32,
822 DataType::Float16,
823 DataType::QAsymmS8,
824 DataType::QAsymmU8,
825 DataType::QSymmS8,
826 DataType::QSymmS16
827 };
828
829 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
830 "Reference ChannelShuffle: input is not a supported type.");
831
832 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
833 "Reference ChannelShuffle: output is not a supported type.");
834
835 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
836 "Reference ChannelShuffle: input and output types are mismatched.");
837
838 return supported;
839}
840
841
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100842bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
843 const TensorInfo& input1,
844 const TensorInfo& output,
845 const ComparisonDescriptor& descriptor,
846 Optional<std::string&> reasonIfUnsupported) const
847{
Jan Eilers8eb25602020-03-09 12:13:48 +0000848 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100849 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100850 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000851 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000852 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100853 DataType::Float32,
854 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100855 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000856 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000857 DataType::QSymmS16,
858 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100859 };
860
861 bool supported = true;
862 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
863 "Reference comparison: input 0 is not a supported type");
864
865 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
866 "Reference comparison: input 0 and Input 1 types are mismatched");
867
868 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
869 "Reference comparison: output is not of type Boolean");
870
871 return supported;
872}
873
Jim Flynn906f9462019-05-10 13:55:21 +0100874bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
875 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000876 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100877 Optional<std::string&> reasonIfUnsupported) const
878{
Jan Eilers8eb25602020-03-09 12:13:48 +0000879 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100880
881 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000882 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100883 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000884 DataType::BFloat16,
885 DataType::Float32,
886 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000887 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100888 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000889 DataType::QSymmS16,
890 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100891 };
892
893 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
894 "Reference concatenation: output type not supported");
895 for (const TensorInfo* input : inputs)
896 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100897 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100898 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
899 "Reference concatenation: input type not supported");
900
901 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
902 "Reference concatenation: input and output types mismatched.");
903 }
904
905 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100906}
907
arovir011c7c81b2018-10-08 11:34:28 +0100908bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
909 Optional<std::string&> reasonIfUnsupported) const
910{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100911 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100912 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000913 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100914 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100915 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000916 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100917 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000918 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100919 DataType::QSymmS16,
920 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100921 };
922
923 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
924 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100925}
926
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000927bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
928 const TensorInfo& output,
929 Optional<std::string&> reasonIfUnsupported) const
930{
931 bool supported = true;
932
933 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
934 "Reference for ConvertBf16ToFp32 layer: input type not supported");
935
936 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
937 "Reference for ConvertBf16ToFp32 layer: output type not supported");
938
939 return supported;
940}
941
// A ConvertFp16ToFp32 layer is supported only when the input is Float16 and
// the output is Float32. IsSupportedForDataTypeGeneric dispatches on the data
// type via positional predicates — judging by the helper names, the slot order
// is (Float16, Float32, Uint8, Int32, Boolean); TODO confirm against
// LayerSupportCommon.hpp.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,          // Float16 input: accepted
                                          &FalseInputFuncF32<>, // Float32 input: rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output: rejected with reason
                                          &TrueFunc<>,           // Float32 output: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
961
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000962bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
963 const TensorInfo& output,
964 Optional<std::string&> reasonIfUnsupported) const
965{
966 bool supported = true;
967
968 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
969 "Reference for ConvertFp32ToBf16 layer: input type not supported");
970
971 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
972 "Reference for ConvertFp32ToBf16 layer: output type not supported");
973
974 return supported;
975}
976
// A ConvertFp32ToFp16 layer is supported only when the input is Float32 and
// the output is Float16. IsSupportedForDataTypeGeneric dispatches on the data
// type via positional predicates — judging by the helper names, the slot order
// is (Float16, Float32, Uint8, Int32, Boolean); TODO confirm against
// LayerSupportCommon.hpp.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>, // Float16 input: rejected with reason
                                          &TrueFunc<>,          // Float32 input: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output: accepted
                                          &FalseOutputFuncF32<>, // Float32 output: rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
996
997bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
998 const TensorInfo& output,
999 const Convolution2dDescriptor& descriptor,
1000 const TensorInfo& weights,
1001 const Optional<TensorInfo>& biases,
1002 Optional<std::string&> reasonIfUnsupported) const
1003{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001004 bool supported = true;
1005
1006 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001007 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001008 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001009 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001010 DataType::Float32,
1011 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001012 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001013 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001014 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001015 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001016 };
1017
1018 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001019 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001020
1021 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001022 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001023
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001024 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1025 if (input.GetDataType() == DataType::BFloat16)
1026 {
1027 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1028 {
1029 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1030 supported = false;
1031 }
1032 }
1033 else
1034 {
1035 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001036 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001037 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001038
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001039 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001040 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001041 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001042 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001043 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001044 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001045 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001046 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001047 };
1048
1049 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001050 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001051 }
1052 else
1053 {
1054 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001055 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001056
1057 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001058 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001059 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001060
1061 if (biases.has_value())
1062 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001063 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001064 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001065 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001066 DataType::Float32,
1067 DataType::Float16,
1068 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001069 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001070
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001071 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001072 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001073 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001074 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001075
1076 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001077}
1078
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001079bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1080 const TensorInfo& output,
1081 const Convolution3dDescriptor& descriptor,
1082 const TensorInfo& weights,
1083 const Optional<TensorInfo>& biases,
1084 Optional<std::string&> reasonIfUnsupported) const
1085{
1086 bool supported = true;
1087
1088 // Define supported types.
1089 std::array<DataType,7> supportedTypes =
1090 {
1091 DataType::BFloat16,
1092 DataType::Float32,
1093 DataType::Float16,
1094 DataType::QAsymmS8,
1095 DataType::QAsymmU8,
1096 DataType::QSymmS8,
1097 DataType::QSymmS16
1098 };
1099
1100 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1101 "Reference Convolution3d: input is not a supported type.");
1102
1103 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1104 "Reference Convolution3d: output is not a supported type.");
1105
1106 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1107 "Reference Convolution3d: input and output types mismatched.");
1108
1109 const DataType inputType = input.GetDataType();
1110 if (IsQuantized8BitType(inputType))
1111 {
1112 std::array<DataType, 3> supportedWeightTypes =
1113 {
1114 DataType::QAsymmS8,
1115 DataType::QAsymmU8,
1116 DataType::QSymmS8
1117 };
1118
1119 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1120 "Reference Convolution3d: weights type not supported for quantized input.");
1121 }
1122 else
1123 {
1124 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1125 "Reference Convolution3d: weights is not a supported type.");
1126
1127 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1128 "Reference Convolution3d: input and weights types mismatched.");
1129 }
1130
1131 if (biases.has_value())
1132 {
1133 std::array<DataType,4> biasesSupportedTypes =
1134 {
1135 DataType::BFloat16,
1136 DataType::Float32,
1137 DataType::Float16,
1138 DataType::Signed32
1139 };
1140
1141 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1142 "Reference Convolution3d: biases is not a supported type.");
1143 }
1144 IgnoreUnused(descriptor);
1145
1146 return supported;
1147}
1148
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001149bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1150 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001151 Optional<std::string&> reasonIfUnsupported) const
1152{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001153 bool supported = true;
1154
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001155 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001156 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001157 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001158 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001159 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001160 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001161 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001162 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001163 DataType::QSymmS16,
1164 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001165 };
1166
1167 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001168 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001169
1170 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001171 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001172
1173 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001174 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001175
1176 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001177}
1178
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001179bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1180 const TensorInfo& output,
1181 const DepthToSpaceDescriptor& descriptor,
1182 Optional<std::string&> reasonIfUnsupported) const
1183{
Jan Eilers8eb25602020-03-09 12:13:48 +00001184 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001185 bool supported = true;
1186
Sadik Armagan303980c2020-04-17 12:45:14 +01001187 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001188 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001189 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001190 DataType::Float32,
1191 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001192 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001193 DataType::QAsymmU8,
1194 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001195 };
1196
1197 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1198 "Reference DepthToSpace: input type not supported");
1199
1200 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1201 "Reference DepthToSpace: output type not supported");
1202
1203 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1204 "Reference DepthToSpace: input and output types are mismatched");
1205
1206 return supported;
1207}
1208
arovir011c7c81b2018-10-08 11:34:28 +01001209bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1210 const TensorInfo& output,
1211 const DepthwiseConvolution2dDescriptor& descriptor,
1212 const TensorInfo& weights,
1213 const Optional<TensorInfo>& biases,
1214 Optional<std::string&> reasonIfUnsupported) const
1215{
Sadik Armagan303980c2020-04-17 12:45:14 +01001216 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001217 bool supported = true;
1218
1219 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001220 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001221 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001222 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001223 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001224 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001225 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001226 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001227 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001228 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001229 };
1230
1231 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1232 "Reference DepthwiseConvolution2d: input is not a supported type.");
1233
1234 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1235 "Reference DepthwiseConvolution2d: output is not a supported type.");
1236
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001237 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1238 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1239
Teresa Charlind8df0262019-11-11 12:28:15 +00001240 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001241 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001242 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001243 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001244 {
1245 DataType::QAsymmS8,
1246 DataType::QAsymmU8,
1247 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001248 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001249
1250 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001251 "Reference DepthwiseConvolution2d: weights type not supported for "
1252 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001253 }
1254 else
1255 {
1256 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1257 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1258
1259 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1260 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1261 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001262
1263 if (biases.has_value())
1264 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001265 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001266 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001267 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001268 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001269 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001270 DataType::Signed32
1271 };
1272 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1273 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1274 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001275
1276 return supported;
1277
arovir011c7c81b2018-10-08 11:34:28 +01001278}
1279
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001280bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1281 const TensorInfo& output,
1282 Optional<std::string&> reasonIfUnsupported) const
1283{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001284 bool supported = true;
1285
Ryan OShea9add1202020-02-07 10:06:33 +00001286 std::array<DataType,4> supportedInputTypes = {
1287 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001288 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001289 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001290 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001291 };
1292
1293 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001294 "Reference for Dequantize layer: input type not supported.");
1295
Derek Lambertid466a542020-01-22 15:37:29 +00001296 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001297 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001298
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001299 std::array<DataType,3> supportedOutputTypes = {
1300 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +00001301 DataType::Float32,
1302 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001303 };
1304
1305 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001306 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001307
1308 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001309 "Reference for Dequantize layer: input/output shapes have different num total "
1310 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001311
1312 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001313}
1314
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001315bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1316 const TensorInfo& scores,
1317 const TensorInfo& anchors,
1318 const TensorInfo& detectionBoxes,
1319 const TensorInfo& detectionClasses,
1320 const TensorInfo& detectionScores,
1321 const TensorInfo& numDetections,
1322 const DetectionPostProcessDescriptor& descriptor,
1323 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001324{
Jan Eilers8eb25602020-03-09 12:13:48 +00001325 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001326
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001327 bool supported = true;
1328
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001329 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001330 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001331 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001332 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001333 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001334 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001335 DataType::QAsymmU8,
1336 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001337 };
1338
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001339 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001340 "Reference DetectionPostProcess: input 0 is not a supported type.");
1341
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001342 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001343 "Reference DetectionPostProcess: input 1 is not a supported type.");
1344
1345 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001346}
1347
Pablo Tellof0bd6832019-04-26 17:58:13 +01001348bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1349 const TensorInfo& output,
1350 const DepthwiseConvolution2dDescriptor& descriptor,
1351 const TensorInfo& weights,
1352 const Optional<TensorInfo>& biases,
1353 Optional<std::string&> reasonIfUnsupported) const
1354{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001355 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +01001356}
1357
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001358bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001359 const TensorInfo& input1,
1360 const TensorInfo& output,
1361 Optional<std::string&> reasonIfUnsupported) const
1362{
Sadik Armagan2999a022019-04-09 14:20:12 +01001363 bool supported = true;
1364
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001365 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001366 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001367 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001368 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001369 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001370 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001371 DataType::QSymmS16,
1372 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001373 };
1374
1375 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1376 "Reference division: input 0 is not a supported type.");
1377
1378 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1379 "Reference division: input 1 is not a supported type.");
1380
1381 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1382 "Reference division: output is not a supported type.");
1383
1384 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1385 "Reference division: input 0 and Input 1 types are mismatched");
1386
1387 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1388 "Reference division: input and output types are mismatched");
1389
1390 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1391 "Reference division: shapes are not suitable for implicit broadcast.");
1392
1393 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001394}
1395
josh minor4a3c6102020-01-06 16:40:46 -06001396bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1397 const TensorInfo& output,
1398 const ElementwiseUnaryDescriptor& descriptor,
1399 Optional<std::string&> reasonIfUnsupported) const
1400{
Jan Eilers8eb25602020-03-09 12:13:48 +00001401 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001402
Sadik Armagan303980c2020-04-17 12:45:14 +01001403 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001404 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001405 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06001406 DataType::Float32,
1407 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001408 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001409 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001410 DataType::QSymmS16,
1411 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001412 };
1413
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001414 std::array<DataType, 1> logicalSupportedTypes =
1415 {
1416 DataType::Boolean
1417 };
1418
josh minor4a3c6102020-01-06 16:40:46 -06001419 bool supported = true;
1420
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001421 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1422 {
1423 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1424 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001425
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001426 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1427 "Reference elementwise unary: output type not supported");
1428 }
1429 else
1430 {
1431 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1432 "Reference elementwise unary: input type not supported");
1433
1434 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1435 "Reference elementwise unary: output type not supported");
1436 }
josh minor4a3c6102020-01-06 16:40:46 -06001437
1438 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1439 "Reference elementwise unary: input and output types not matching");
1440
1441 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1442 "Reference elementwise unary: input and output shapes"
1443 "have different number of total elements");
1444
1445 return supported;
1446}
1447
arovir011c7c81b2018-10-08 11:34:28 +01001448bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1449 const FakeQuantizationDescriptor& descriptor,
1450 Optional<std::string&> reasonIfUnsupported) const
1451{
Jan Eilers8eb25602020-03-09 12:13:48 +00001452 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001453 bool supported = true;
1454
1455 std::array<DataType,1> supportedTypes =
1456 {
1457 DataType::Float32
1458 };
1459
1460 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1461 "Reference fake quantization: input type not supported.");
1462
1463 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001464}
1465
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001466bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1467 const TensorInfo& output,
1468 const FillDescriptor& descriptor,
1469 Optional<std::string&> reasonIfUnsupported) const
1470{
1471 IgnoreUnused(descriptor);
1472 IgnoreUnused(output);
1473
1474 bool supported = true;
1475
Sadik Armagana792a052020-06-23 16:22:23 +01001476 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001477 {
1478 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001479 DataType::Float16,
1480 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001481 };
1482
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001483 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001484 "Reference Fill: input type not supported.");
1485
Teresa Charlin44088502020-07-27 11:27:19 +01001486 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001488 return supported;
1489}
1490
arovir011c7c81b2018-10-08 11:34:28 +01001491bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1492 const TensorInfo& output,
1493 Optional<std::string&> reasonIfUnsupported) const
1494{
Jan Eilers8eb25602020-03-09 12:13:48 +00001495 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001496 bool supported = true;
1497
Francis Murtaghe8ac1332020-07-30 18:03:40 +01001498 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001499 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001500 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +01001501 DataType::Float32,
Francis Murtaghe8ac1332020-07-30 18:03:40 +01001502 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001503 };
1504
1505 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1506 "Reference Floor: input type not supported.");
1507
1508 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1509 "Reference Floor: output type not supported.");
1510
1511 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001512}
1513
1514bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1515 const TensorInfo& output,
1516 const TensorInfo& weights,
1517 const TensorInfo& biases,
1518 const FullyConnectedDescriptor& descriptor,
1519 Optional<std::string&> reasonIfUnsupported) const
1520{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001521 bool supported = true;
1522
1523 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001524 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001525 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001526 DataType::BFloat16,
1527 DataType::Float32,
1528 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001529 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001530 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001531 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001532 };
1533
1534 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1535 "Reference Fully Connected: input type not supported.");
1536
1537 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1538 "Reference Fully Connected: output type not supported.");
1539
Francis Murtagh46c09d02019-05-28 08:15:28 +01001540 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1541 "Reference Fully Connected: weights type not supported.");
1542
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001543 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1544 if (input.GetDataType() == DataType::BFloat16)
1545 {
1546 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1547 {
1548 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1549 supported = false;
1550 }
1551 }
1552 else
1553 {
1554 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1555 "Reference Fully Connected: input and output types mismatched.");
1556 }
1557
Jan Eilers1f45dc32020-06-15 11:43:03 +01001558 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1559 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001560
Jan Eilers1f45dc32020-06-15 11:43:03 +01001561 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1562 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001563
1564 if (descriptor.m_BiasEnabled)
1565 {
1566 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001567 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001568 supportedBiasTypes =
1569 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001570 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001571 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001572 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001573 DataType::Signed32,
1574 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001575 };
1576
1577 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1578 "Reference Fully Connected: bias type not supported.");
1579
1580 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1581 "Reference Fully Connected: bias and weight types mismatch.");
1582
1583 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1584 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1585
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001586 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1587 "Reference Fully Connected: bias must have 1 dimension.");
1588
Francis Murtagh46c09d02019-05-28 08:15:28 +01001589 }
1590
1591 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001592}
1593
narpra014951d842019-01-18 16:53:53 +00001594bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1595 const armnn::TensorInfo& input1,
1596 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001597 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001598 armnn::Optional<std::string&> reasonIfUnsupported) const
1599{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001600 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001601 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001602 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001603 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001604 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001605 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001606 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001607 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001608 DataType::QSymmS16,
1609 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001610 };
1611
Teresa Charlin52664732020-06-29 16:27:03 +01001612 if (descriptor.m_Axis != 0)
1613 {
1614 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1615 supported &= false;
1616 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001617 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1618 "Reference Gather: input type not supported");
1619
1620 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1621 "Reference Gather: output type not supported");
1622
1623 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1624 "Reference Gather: indices (input1) type not supported");
1625
1626 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1627 "Reference Gather: input and output types not matching");
1628
1629 return supported;
narpra014951d842019-01-18 16:53:53 +00001630}
1631
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // Input layers are pass-through placeholders with no computation, so the
    // reference backend accepts them unconditionally.
    return true;
}
1637
Kevin May09ca49c2019-10-09 12:37:34 +01001638bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1639 const TensorInfo& output,
1640 const InstanceNormalizationDescriptor& descriptor,
1641 Optional<std::string&> reasonIfUnsupported) const
1642{
Jan Eilers8eb25602020-03-09 12:13:48 +00001643 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001644 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001645 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001647 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001648 DataType::Float32,
1649 DataType::Float16
1650 };
1651
1652 bool supported = true;
1653
1654 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1655 "Reference Instance Normalization: input type not supported.");
1656
1657 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1658 "Reference Instance Normalization: output type not supported.");
1659
1660 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1661 "Reference Instance Normalization: input and output types mismatched.");
1662
1663 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1664 "Reference Instance Normalization: input and output shapes have different "
1665 "num total elements.");
1666
1667 return supported;
1668}
1669
arovir011c7c81b2018-10-08 11:34:28 +01001670bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1671 const TensorInfo& output,
1672 const L2NormalizationDescriptor& descriptor,
1673 Optional<std::string&> reasonIfUnsupported) const
1674{
Jan Eilers8eb25602020-03-09 12:13:48 +00001675 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001676 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001677 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001678 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001679 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001680 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001681 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001682 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001683 DataType::QAsymmU8,
1684 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001685 };
1686
1687 bool supported = true;
1688
1689 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1690 "Reference L2normalization: input type not supported.");
1691
1692 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1693 "Reference L2normalization: output type not supported.");
1694
1695 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1696 "Reference L2normalization: input and output types mismatched.");
1697
1698 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1699 "Reference L2normalization: input and output shapes have different "
1700 "num total elements.");
1701
1702 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001703}
1704
James Conroyaba90cd2020-11-06 16:28:18 +00001705bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1706 const TensorInfo& input1,
1707 const TensorInfo& output,
1708 const LogicalBinaryDescriptor& descriptor,
1709 Optional<std::string&> reasonIfUnsupported) const
1710{
1711 IgnoreUnused(descriptor);
1712
1713 std::array<DataType, 1> supportedTypes =
1714 {
1715 DataType::Boolean
1716 };
1717
1718 bool supported = true;
1719 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1720 "Reference LogicalBinary: input 0 type not supported");
1721 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1722 "Reference LogicalBinary: input 1 type not supported");
1723
1724 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1725 "Reference LogicalBinary: input and output types do not match");
1726
1727 return supported;
1728}
1729
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001730bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1731 const TensorInfo& output,
1732 const LogSoftmaxDescriptor& descriptor,
1733 Optional<std::string&> reasonIfUnsupported) const
1734{
Jan Eilers8eb25602020-03-09 12:13:48 +00001735 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001736
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001737 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001738 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001739 DataType::BFloat16,
1740 DataType::Float32,
1741 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001742 };
1743
1744 bool supported = true;
1745 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1746 "Reference LogSoftmax: input type not supported");
1747
1748 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1749 "Reference LogSoftmax: output type not supported");
1750
1751 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1752 "Reference LogSoftmax: input and output types do not match");
1753
1754 return supported;
1755}
1756
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // NOTE(review): both descriptor and paramsInfo ARE read below (m_CifgEnabled,
    // GetInputToForgetWeights(), ...), so these IgnoreUnused calls are redundant
    // no-ops kept for historical reasons — consider removing.
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);

    bool supported = true;

    // Data types the reference LSTM workload implements. Every tensor and weight
    // checked below is required to match the input's type exactly.
    std::array<DataType,3> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    // Mandatory weights/biases: present in every LSTM configuration.
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // Input-gate parameters only exist when CIFG (coupled input-forget gate) is OFF.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        // Cell-to-input peephole weights only exist without CIFG AND with peephole.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    // Peephole connections add cell-to-forget and cell-to-output weights.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    // Projection layer parameters; the projection bias is optional even when
    // projection is enabled, hence the null check.
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    // Layer normalization weights; the input-gate norm weights again only exist
    // when CIFG is disabled.
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
1868
saoste012df12b32018-11-28 16:57:20 +00001869bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1870 const TensorInfo& input1,
1871 const TensorInfo& output,
1872 Optional<std::string&> reasonIfUnsupported) const
1873{
Sadik Armagan2999a022019-04-09 14:20:12 +01001874 bool supported = true;
1875
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001876 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001877 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001878 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001879 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001880 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001881 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001882 DataType::QSymmS16,
1883 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001884 };
1885
1886 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1887 "Reference maximum: input 0 is not a supported type.");
1888
1889 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1890 "Reference maximum: input 1 is not a supported type.");
1891
1892 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1893 "Reference maximum: output is not a supported type.");
1894
1895 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1896 "Reference maximum: input 0 and Input 1 types are mismatched");
1897
1898 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1899 "Reference maximum: input and output types are mismatched");
1900
1901 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1902 "Reference maximum: shapes are not suitable for implicit broadcast.");
1903
1904 return supported;
saoste012df12b32018-11-28 16:57:20 +00001905}
1906
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    // Strings fed into CreateIncorrectDimensionsErrorMsg for the error text.
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    // Data types the reference Mean workload implements.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    // Validate the output rank implied by the descriptor:
    // - KeepDims: reduced axes are kept as size-1 dims, so ranks must match.
    if (descriptor.m_KeepDims)
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    // - No axes given: everything is reduced, leaving a rank-1 result.
    else if (descriptor.m_Axis.empty())
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    // - Explicit axes without KeepDims: each reduced axis drops one dimension,
    //   but the output rank never goes below 1.
    else
    {
        auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1969
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001970bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1971 const TensorInfo &output,
1972 Optional<std::string &> reasonIfUnsupported) const
1973{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001974 bool supported = true;
1975
Sadik Armagan303980c2020-04-17 12:45:14 +01001976 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001977 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001978 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001979 DataType::Float32,
1980 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001981 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001982 DataType::QAsymmU8,
1983 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001984 DataType::Boolean
1985 };
1986
1987 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1988 "Reference MemCopy: input type not supported");
1989
1990 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1991 "Reference MemCopy: output type not supported");
1992
1993 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1994 "Reference MemCopy: input and output types are mismatched");
1995
1996 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001997}
1998
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001999bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2000 const TensorInfo& input1,
2001 const TensorInfo& output,
2002 Optional<std::string&> reasonIfUnsupported) const
2003{
Sadik Armagan2999a022019-04-09 14:20:12 +01002004 bool supported = true;
2005
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002006 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002007 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002008 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002009 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002010 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002011 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002012 DataType::QSymmS16,
2013 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002014 };
2015
2016 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2017 "Reference minimum: input 0 is not a supported type.");
2018
2019 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2020 "Reference minimum: input 1 is not a supported type.");
2021
2022 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2023 "Reference minimum: output is not a supported type.");
2024
2025 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2026 "Reference minimum: input 0 and Input 1 types are mismatched");
2027
2028 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2029 "Reference minimum: input and output types are mismatched");
2030
2031 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2032 "Reference minimum: shapes are not suitable for implicit broadcast.");
2033
2034 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002035}
2036
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002037bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2038 const TensorInfo& input1,
2039 const TensorInfo& output,
2040 Optional<std::string&> reasonIfUnsupported) const
2041{
Sadik Armagan2999a022019-04-09 14:20:12 +01002042 bool supported = true;
2043
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002044 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002045 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002046 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002047 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002048 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002049 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002050 DataType::QSymmS16,
2051 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002052 };
2053
2054 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2055 "Reference multiplication: input 0 is not a supported type.");
2056
2057 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2058 "Reference multiplication: input 1 is not a supported type.");
2059
2060 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2061 "Reference multiplication: output is not a supported type.");
2062
2063 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2064 "Reference multiplication: input 0 and Input 1 types are mismatched");
2065
2066 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2067 "Reference multiplication: input and output types are mismatched");
2068
2069 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2070 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2071
2072 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002073}
2074
2075bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2076 const TensorInfo& output,
2077 const NormalizationDescriptor& descriptor,
2078 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002079{
Jan Eilers8eb25602020-03-09 12:13:48 +00002080 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002081
2082 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002083 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002084 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002085 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002086 DataType::Float16,
2087 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002088 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002089 DataType::QAsymmU8,
2090 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002091 };
2092
2093 bool supported = true;
2094
2095 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2096 "Reference normalization: input type not supported.");
2097
2098 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2099 "Reference normalization: output type not supported.");
2100
2101 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2102 "Reference normalization: input and output shapes have different "
2103 "num total elements.");
2104
2105 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002106}
2107
Derek Lamberti901ea112019-12-10 22:07:09 +00002108bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2109 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002110{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002111 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002112}
2113
2114bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2115 const TensorInfo& output,
2116 const PadDescriptor& descriptor,
2117 Optional<std::string&> reasonIfUnsupported) const
2118{
Jan Eilers8eb25602020-03-09 12:13:48 +00002119 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002120 bool supported = true;
2121
2122 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002123 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002124 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002125 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002126 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002127 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002128 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002129 DataType::QAsymmU8,
2130 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002131 };
2132
2133 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2134 "Reference pad: input is not a supported type.");
2135
2136 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2137 "Reference pad: output is not a supported type.");
2138
2139 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2140 "Reference pad: input and output types are mismatched.");
2141
2142 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002143}
2144
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002145bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2146 const TensorInfo& output,
2147 const PermuteDescriptor& descriptor,
2148 Optional<std::string&> reasonIfUnsupported) const
2149{
Jan Eilers8eb25602020-03-09 12:13:48 +00002150 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002151 bool supported = true;
2152
2153 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002154 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002155 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002156 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002157 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002158 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002159 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002160 DataType::QAsymmU8,
2161 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002162 };
2163
2164 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2165 "Reference permute: input is not a supported type.");
2166
2167 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2168 "Reference permute: output is not a supported type.");
2169
2170 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2171 "Reference permute: input and output types are mismatched.");
2172
2173 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002174}
2175
2176bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2177 const TensorInfo& output,
2178 const Pooling2dDescriptor& descriptor,
2179 Optional<std::string&> reasonIfUnsupported) const
2180{
Jan Eilers8eb25602020-03-09 12:13:48 +00002181 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002182 bool supported = true;
2183
2184 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002185 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002186 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002187 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01002188 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002189 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002190 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002191 DataType::QAsymmU8,
2192 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002193 };
2194
2195 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2196 "Reference poolind2d: input is not a supported type.");
2197
2198 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2199 "Reference poolind2d: output is not a supported type.");
2200
2201 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2202 "Reference poolind2d: input and output types are mismatched.");
2203
2204 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002205}
2206
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002207bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2208 const TensorInfo& output,
2209 const Pooling3dDescriptor& descriptor,
2210 Optional<std::string&> reasonIfUnsupported) const
2211{
2212 IgnoreUnused(descriptor);
2213 bool supported = true;
2214
2215 // Define supported output and inputs types.
2216 std::array<DataType,6> supportedTypes =
2217 {
2218 DataType::BFloat16,
2219 DataType::Float32,
2220 DataType::Float16,
2221 DataType::QAsymmS8,
2222 DataType::QAsymmU8,
2223 DataType::QSymmS16
2224 };
2225
2226 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2227 "Reference poolind3d: input is not a supported type.");
2228
2229 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2230 "Reference poolind3d: output is not a supported type.");
2231
2232 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2233 "Reference poolind3d: input and output types are mismatched.");
2234
2235 return supported;
2236}
2237
2238
James Conroy4f1f8992020-04-29 20:01:10 +01002239bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2240 const TensorInfo& previousOutputIn,
2241 const TensorInfo& previousCellStateIn,
2242 const TensorInfo& outputStateOut,
2243 const TensorInfo& cellStateOut,
2244 const TensorInfo& output,
2245 const QLstmDescriptor& descriptor,
2246 const LstmInputParamsInfo& paramsInfo,
2247 Optional<std::string&> reasonIfUnsupported) const
2248{
2249 IgnoreUnused(input);
2250 IgnoreUnused(previousOutputIn);
2251 IgnoreUnused(previousCellStateIn);
2252 IgnoreUnused(outputStateOut);
2253 IgnoreUnused(cellStateOut);
2254 IgnoreUnused(output);
2255 IgnoreUnused(descriptor);
2256 IgnoreUnused(paramsInfo);
2257
2258 IgnoreUnused(reasonIfUnsupported);
2259
2260 return true;
2261}
2262
Derek Lamberti5f400d62019-03-25 15:41:58 +00002263bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2264 const TensorInfo& output,
2265 Optional<std::string&> reasonIfUnsupported) const
2266{
2267 bool supported = true;
2268
Finn Williamsfd271062019-12-04 14:27:27 +00002269 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002270 std::array<DataType,7> supportedInputTypes = {
2271 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00002272 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002273 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002274 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002275 DataType::QAsymmU8,
2276 DataType::QSymmS8,
2277 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002278 };
2279
2280 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2281 "Reference quantize: input type not supported.");
2282
2283 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002284 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002285 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002286 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002287 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002288 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002289 };
2290 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2291 "Reference quantize: output type not supported.");
2292
2293 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2294 "Reference quantize: input and output shapes have different num total elements.");
2295
2296 return supported;
2297}
2298
Finn Williams2605b232020-06-10 15:53:46 +01002299bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2300 const TensorInfo& output,
2301 Optional<std::string&> reasonIfUnsupported) const
2302{
2303 IgnoreUnused(input);
2304 // Define supported output types.
2305 std::array<DataType,1> supportedOutputTypes =
2306 {
2307 DataType::Signed32,
2308 };
2309
2310 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2311 "Reference rank: input type not supported.");
2312}
2313
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002314bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2315 const TensorInfo& output,
2316 const ReduceDescriptor& descriptor,
2317 Optional<std::string&> reasonIfUnsupported) const
2318{
2319 IgnoreUnused(descriptor);
2320 bool supported = true;
2321 std::array<DataType,7> supportedTypes =
2322 {
2323 DataType::BFloat16,
2324 DataType::Float32,
2325 DataType::Float16,
2326 DataType::QAsymmS8,
2327 DataType::QAsymmU8,
2328 DataType::QSymmS16,
2329 DataType::Signed32
2330 };
2331
2332 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2333 "Reference Reduce: input type not supported");
2334
2335 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2336 "Reference Reduce: output type not supported");
2337
2338 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2339 "Reference Reduce: input and output types not matching");
2340
2341 return supported;
2342}
2343
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002344bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002345 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002346 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002347 Optional<std::string&> reasonIfUnsupported) const
2348{
Jan Eilers8eb25602020-03-09 12:13:48 +00002349 IgnoreUnused(output);
2350 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002351 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002352 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002353 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002354 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002355 DataType::Float32,
2356 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002357 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002358 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002359 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002360 DataType::QSymmS16,
2361 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002362 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002363
Nina Drozd2f2778f2019-05-27 10:37:05 +01002364 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2365 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002366}
2367
Teresa Charlin970f43b2019-07-01 13:51:07 +01002368bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2369 const TensorInfo& output,
2370 const ResizeDescriptor& descriptor,
2371 Optional<std::string&> reasonIfUnsupported) const
2372{
Jan Eilers8eb25602020-03-09 12:13:48 +00002373 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002374 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002375 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002376 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002377 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002378 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002379 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002380 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002381 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002382 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002383 };
2384
2385 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2386 "Reference Resize: input type not supported");
2387
2388 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2389 "Reference Resize: output type not supported");
2390
2391 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2392 "Reference Resize: input and output types not matching");
2393
2394 return supported;
2395}
2396
Keith Davis3ae3f972021-05-21 16:33:48 +01002397bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2398 const TensorInfo& output,
2399 Optional<std::string&> reasonIfUnsupported) const
2400{
2401 IgnoreUnused(input);
2402 bool supported = true;
2403
2404 std::array<DataType, 1> supportedTypes =
2405 {
2406 DataType::Signed32
2407 };
2408
2409 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2410 "Reference Shape: output type not supported");
2411
2412 return supported;
2413}
2414
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002415bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2416 const TensorInfo& output,
2417 const SliceDescriptor& descriptor,
2418 Optional<std::string&> reasonIfUnsupported) const
2419{
Jan Eilers8eb25602020-03-09 12:13:48 +00002420 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002421 bool supported = true;
2422
Sadik Armagan303980c2020-04-17 12:45:14 +01002423 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002424 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002425 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002426 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002427 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002428 DataType::QAsymmU8,
2429 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002430 };
2431
2432 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2433 "Reference Slice: input type not supported");
2434
2435 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2436 "Reference Slice: output type not supported");
2437
2438 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2439 "Reference Slice: input and output types are mismatched");
2440
2441 return supported;
2442}
2443
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002444bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2445 const TensorInfo& output,
2446 const SoftmaxDescriptor& descriptor,
2447 Optional<std::string&> reasonIfUnsupported) const
2448{
Jan Eilers8eb25602020-03-09 12:13:48 +00002449 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002450 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002451 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002452 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002453 DataType::BFloat16,
2454 DataType::Float32,
2455 DataType::Float16,
2456 DataType::QSymmS8,
2457 DataType::QAsymmS8,
2458 DataType::QAsymmU8,
2459 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002460 };
2461
2462 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002463 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002464
2465 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002466 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002467
2468 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002469 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002470
2471 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002472}
2473
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002474bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2475 const TensorInfo& output,
2476 const SpaceToBatchNdDescriptor& descriptor,
2477 Optional<std::string&> reasonIfUnsupported) const
2478{
Jan Eilers8eb25602020-03-09 12:13:48 +00002479 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002480 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002481 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002482 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002483 DataType::BFloat16,
2484 DataType::Float32,
2485 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002486 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002487 DataType::QAsymmU8,
2488 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002489 };
2490
2491 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2492 "Reference SpaceToBatchNd: input type not supported");
2493
2494 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2495 "Reference SpaceToBatchNd: output type not supported");
2496
2497 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2498 "Reference SpaceToBatchNd: input and output types are mismatched");
2499
2500 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002501}
2502
Keith Davisa57eccb2019-06-14 17:33:22 +01002503bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002504 const TensorInfo& output,
2505 const SpaceToDepthDescriptor& descriptor,
2506 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002507{
2508
Jan Eilers8eb25602020-03-09 12:13:48 +00002509 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002510 bool supported = true;
2511
Sadik Armagan303980c2020-04-17 12:45:14 +01002512 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002513 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002514 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002515 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002516 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002517 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002518 DataType::QAsymmU8,
2519 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002520 };
2521
2522 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2523 "Reference SpaceToDepth: input type not supported");
2524
2525 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2526 "Reference SpaceToDepth: output type not supported");
2527
2528 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2529 "Reference SpaceToDepth: input and output types are mismatched");
2530
2531 return supported;
2532}
2533
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002534bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002535 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2536 const ViewsDescriptor& descriptor,
2537 Optional<std::string&> reasonIfUnsupported) const
2538{
Jan Eilers8eb25602020-03-09 12:13:48 +00002539 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002540 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002541 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002542 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002543 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002544 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002545 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002546 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002547 DataType::QAsymmU8,
2548 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002549 };
2550
2551 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2552 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002553 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002554 {
2555 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2556 "Reference splitter: input type not supported");
2557
2558 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2559 "Reference splitter: input and output types mismatched.");
2560 }
2561
2562 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002563}
2564
Matthew Jackson81e601c2019-07-11 12:07:09 +01002565bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2566 const TensorInfo& output,
2567 const StackDescriptor& descriptor,
2568 Optional<std::string&> reasonIfUnsupported) const
2569{
Jan Eilers8eb25602020-03-09 12:13:48 +00002570 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002571
2572 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002573 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002574 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002575 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002576 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002577 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002578 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002579 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002580 DataType::QSymmS16,
2581 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002582 };
2583
2584 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2585 "Reference stack: output type not supported");
2586 for (const TensorInfo* input : inputs)
2587 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002588 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002589 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2590 "Reference stack: input type not supported");
2591
2592 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2593 "Reference stack: input and output types mismatched.");
2594 }
2595
2596 return supported;
2597}
2598
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002599bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2600 const TensorInfo& output,
2601 const StridedSliceDescriptor& descriptor,
2602 Optional<std::string&> reasonIfUnsupported) const
2603{
Jan Eilers8eb25602020-03-09 12:13:48 +00002604 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002605 bool supported = true;
2606
Sadik Armagan303980c2020-04-17 12:45:14 +01002607 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002608 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002609 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002610 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002611 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002612 DataType::QAsymmU8,
2613 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002614 };
2615
2616 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2617 "Reference StridedSlice: input type not supported");
2618
2619 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2620 "Reference StridedSlice: output type not supported");
2621
2622 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2623 "Reference StridedSlice: input and output types are mismatched");
2624
2625 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002626}
2627
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002628bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2629 const TensorInfo& input1,
2630 const TensorInfo& output,
2631 Optional<std::string&> reasonIfUnsupported) const
2632{
Sadik Armagan2999a022019-04-09 14:20:12 +01002633 bool supported = true;
2634
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002635 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002636 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002637 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002638 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002639 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002640 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002641 DataType::QSymmS16,
2642 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002643 };
2644
2645 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2646 "Reference subtraction: input 0 is not a supported type.");
2647
2648 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2649 "Reference subtraction: input 1 is not a supported type.");
2650
2651 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2652 "Reference subtraction: output is not a supported type.");
2653
2654 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2655 "Reference subtraction: input 0 and Input 1 types are mismatched");
2656
2657 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2658 "Reference subtraction: input and output types are mismatched");
2659
2660 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2661 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2662
2663 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002664}
2665
Matteo Martincighab9e5252019-06-13 17:27:46 +01002666bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2667 const TensorInfo& alpha,
2668 const TensorInfo& output,
2669 Optional<std::string&> reasonIfUnsupported) const
2670{
2671 bool supported = true;
2672
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002673 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002674 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002675 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002676 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002677 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002678 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002679 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002680 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002681 };
2682
2683 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2684 "PReLU: input is not a supported type.");
2685
2686 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2687 "PReLU: alpha is not a supported type.");
2688
2689 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2690 "PReLU: output is not a supported type.");
2691
2692 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2693 "PReLU: input, alpha and output types are mismatched");
2694
2695 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2696 "PReLU: shapes are not suitable for implicit broadcast");
2697
2698 return supported;
2699}
2700
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002701bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2702 const TensorInfo& output,
2703 const TransposeConvolution2dDescriptor& descriptor,
2704 const TensorInfo& weights,
2705 const Optional<TensorInfo>& biases,
2706 Optional<std::string&> reasonIfUnsupported) const
2707{
Jan Eilers8eb25602020-03-09 12:13:48 +00002708 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002709 bool supported = true;
2710
Sadik Armagan303980c2020-04-17 12:45:14 +01002711 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002712 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002713 DataType::BFloat16,
2714 DataType::Float32,
2715 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002716 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002717 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002718 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002719 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002720 };
2721
2722 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2723 "Reference TransposeConvolution2d: input is not a supported type.");
2724
2725 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2726 "Reference TransposeConvolution2d: output is not a supported type.");
2727
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002728 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2729 "Reference TransposeConvolution2d: input and output types mismatched.");
2730
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002731
2732 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002733 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002734 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002735 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002736 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002737 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002738 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002739 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002740 };
2741
2742 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2743 "Reference TransposeConvolution2d: weights type not supported for "
2744 "quantized input.");
2745 }
2746 else
2747 {
2748 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2749 "Reference TransposeConvolution2d: weights is not a supported type.");
2750
2751 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2752 "Reference TransposeConvolution2d: input and weights types mismatched.");
2753 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002754
2755 if (biases.has_value())
2756 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002757 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002758 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002759 DataType::BFloat16,
2760 DataType::Float32,
2761 DataType::Float16,
2762 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002763 };
2764 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2765 "Reference TransposeConvolution2d: biases is not a supported type.");
2766 }
2767
2768 return supported;
2769}
2770
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002771bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2772 const TensorInfo& output,
2773 const TransposeDescriptor& descriptor,
2774 Optional<std::string&> reasonIfUnsupported) const
2775{
Jan Eilers8eb25602020-03-09 12:13:48 +00002776 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002777 bool supported = true;
2778
2779 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002780 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002781 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002782 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002783 DataType::Float32,
2784 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002785 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002786 DataType::QAsymmU8,
2787 DataType::QSymmS16
2788 };
2789
2790 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2791 "Reference transpose: input is not a supported type.");
2792
2793 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2794 "Reference transpose: output is not a supported type.");
2795
2796 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2797 "Reference transpose: input and output types are mismatched.");
2798
2799 return supported;
2800}
2801
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002802bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2803 const TensorInfo& input,
2804 const TensorInfo& outputStateIn,
2805 const TensorInfo& cellStateIn,
2806 const TensorInfo& output,
2807 const Optional<TensorInfo>& hiddenStateOutput,
2808 const Optional<TensorInfo>& cellStateOutput,
2809 const UnidirectionalSequenceLstmDescriptor& descriptor,
2810 const LstmInputParamsInfo& paramsInfo,
2811 Optional<std::string&> reasonIfUnsupported) const
2812{
2813 IgnoreUnused(descriptor);
2814 IgnoreUnused(paramsInfo);
2815 IgnoreUnused(outputStateIn);
2816 IgnoreUnused(cellStateIn);
2817 bool supported = true;
2818
2819 if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
2820 {
2821 reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
2822 "and cell state output are not supported at the moment.";
2823 }
2824
2825 std::array<DataType, 1> supportedTypes =
2826 {
2827 DataType::Float32
2828 };
2829
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002830 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002831 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002832 DataType::Float32,
2833 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002834 };
2835
2836 // check inputs and outputs
2837 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2838 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2839 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
2840 "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
2841 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
2842 "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
2843
2844 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2845 "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
2846 // check layer parameters
2847 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2848 reasonIfUnsupported,
2849 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2850 "is not a supported type.");
2851 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2852 reasonIfUnsupported,
2853 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2854 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2855 reasonIfUnsupported,
2856 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2857 "is not a supported type.");
2858 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2859 reasonIfUnsupported,
2860 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2861 "is not a supported type.");
2862 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2863 reasonIfUnsupported,
2864 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2865 "is not a supported type.");
2866 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2867 reasonIfUnsupported,
2868 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2869 "is not a supported type.");
2870 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
2871 "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
2872 "are mismatched");
2873 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
2874 "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
2875 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
2876 "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
2877 "are mismatched");
2878 if (!descriptor.m_CifgEnabled)
2879 {
2880 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2881 reasonIfUnsupported,
2882 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2883 "is not a supported type.");
2884 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2885 reasonIfUnsupported,
2886 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2887 "is not a supported type.");
2888 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
2889 "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
2890 "are mismatched");
2891 if (descriptor.m_PeepholeEnabled)
2892 {
2893 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2894 reasonIfUnsupported,
2895 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2896 "is not a supported type.");
2897 }
2898 }
2899 if (descriptor.m_PeepholeEnabled)
2900 {
2901 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2902 reasonIfUnsupported,
2903 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2904 "is not a supported type.");
2905 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2906 reasonIfUnsupported,
2907 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2908 "is not a supported type.");
2909 }
2910 if (descriptor.m_ProjectionEnabled)
2911 {
2912 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2913 reasonIfUnsupported,
2914 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2915 "is not a supported type.");
2916 if (paramsInfo.m_ProjectionBias != nullptr)
2917 {
2918 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2919 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2920 "are mismatched");
2921 }
2922 }
2923 if (descriptor.m_LayerNormEnabled)
2924 {
2925 if (!descriptor.m_CifgEnabled)
2926 {
2927 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2928 reasonIfUnsupported,
2929 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2930 "is not a supported type.");
2931 }
2932 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2933 reasonIfUnsupported,
2934 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2935 "is not a supported type.");
2936 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2937 reasonIfUnsupported,
2938 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2939 "is not a supported type.");
2940 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2941 reasonIfUnsupported,
2942 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2943 "is not a supported type.");
2944 }
2945
2946 return supported;
2947}
2948
arovir011c7c81b2018-10-08 11:34:28 +01002949} // namespace armnn