blob: 669c91d628451a164649eb562ddeca4f26d6c9e0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
namespace
{

/// Dispatches a layer-support query on @p dataType for the reference backend.
/// Only Float32 and QAsymmU8 queries are routed to the supplied function
/// pointers; every other data-type slot of IsSupportedForDataTypeGeneric is
/// filled with FalseFunc, i.e. reported as unsupported.
/// NOTE(review): the positional meaning of each FalseFunc slot is defined by
/// IsSupportedForDataTypeGeneric in LayerSupportCommon.hpp — confirm there
/// before reordering.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
44
namespace
{

/// Builds the standard "wrong tensor rank" diagnostic used by the reference
/// backend's layer-support checks.
/// @param expected   The number of dimensions the layer requires.
/// @param actual     The number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name inserted into the message.
/// @param tensorName Human-readable tensor name inserted into the message.
/// @return The formatted error message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,   // const: inputs are read-only
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
103 case LayerType::Comparison:
104 return IsComparisonSupported(infos[0],
105 infos[1],
106 infos[2],
107 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108 reasonIfUnsupported);
109 case LayerType::Concat:
110 {
111 std::vector<const TensorInfo*> inputInfos;
112 for (uint32_t i = 0; i < (infos.size() - 1); i++)
113 {
114 inputInfos.push_back(&infos[i]);
115 }
116 return IsConcatSupported(inputInfos,
117 infos[infos.size() - 1],
118 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119 reasonIfUnsupported);
120 }
121 case LayerType::Constant:
122 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000123 case LayerType::ConvertFp16ToFp32:
124 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000125 case LayerType::ConvertFp32ToFp16:
126 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127 case LayerType::Convolution2d:
128 {
129 if (infos.size() != 4)
130 {
131 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132 "TensorInfos should be of format: {input, output, weights, biases}.");
133 }
134
135 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136 if (infos[3] == TensorInfo())
137 {
138 return IsConvolution2dSupported(infos[0],
139 infos[1],
140 desc,
141 infos[2],
142 EmptyOptional(),
143 reasonIfUnsupported);
144 }
145 else
146 {
147 return IsConvolution2dSupported(infos[0],
148 infos[1],
149 desc,
150 infos[2],
151 infos[3],
152 reasonIfUnsupported);
153 }
154 }
155 case LayerType::DepthToSpace:
156 return IsDepthToSpaceSupported(infos[0],
157 infos[1],
158 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159 reasonIfUnsupported);
160 case LayerType::DepthwiseConvolution2d:
161 {
162 if (infos.size() != 4)
163 {
164 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165 "TensorInfos should be of format: {input, output, weights, biases}.");
166 }
167
168 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169 if (infos[3] == TensorInfo())
170 {
171 return IsDepthwiseConvolutionSupported(infos[0],
172 infos[1],
173 desc,
174 infos[2],
175 EmptyOptional(),
176 reasonIfUnsupported);
177 }
178 else
179 {
180 return IsDepthwiseConvolutionSupported(infos[0],
181 infos[1],
182 desc,
183 infos[2],
184 infos[3],
185 reasonIfUnsupported);
186 }
187 }
188 case LayerType::Dequantize:
189 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190 case LayerType::Division:
191 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
192 case LayerType::ElementwiseUnary:
193 return IsElementwiseUnarySupported(infos[0],
194 infos[1],
195 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
196 reasonIfUnsupported);
197 case LayerType::Fill:
198 return IsFillSupported(infos[0],
199 infos[1],
200 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
201 reasonIfUnsupported);
202 case LayerType::Floor:
203 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
204 case LayerType::FullyConnected:
205 return IsFullyConnectedSupported(infos[0],
206 infos[1],
207 infos[2],
208 infos[3],
209 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
210 reasonIfUnsupported);
211 case LayerType::Gather:
212 return IsGatherSupported(infos[0],
213 infos[1],
214 infos[2],
215 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
216 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100217 case LayerType::GatherNd:
218 return IsGatherNdSupported(infos[0],
219 infos[1],
220 infos[2],
221 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000222 case LayerType::Input:
223 return IsInputSupported(infos[0], reasonIfUnsupported);
224 case LayerType::InstanceNormalization:
225 return IsInstanceNormalizationSupported(infos[0],
226 infos[1],
227 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
228 (&descriptor)),
229 reasonIfUnsupported);
230 case LayerType::L2Normalization:
231 return IsL2NormalizationSupported(infos[0],
232 infos[1],
233 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
234 reasonIfUnsupported);
235 case LayerType::LogicalBinary:
236 return IsLogicalBinarySupported(infos[0],
237 infos[1],
238 infos[2],
239 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
240 reasonIfUnsupported);
241 case LayerType::LogSoftmax:
242 return IsLogSoftmaxSupported(infos[0],
243 infos[1],
244 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Lstm:
247 return IsLstmSupported(infos[0],
248 infos[1],
249 infos[2],
250 infos[3],
251 infos[4],
252 infos[5],
253 infos[6],
254 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
255 lstmParamsInfo.value(),
256 reasonIfUnsupported);
257 case LayerType::QLstm:
258 return IsQLstmSupported(infos[0],
259 infos[1],
260 infos[2],
261 infos[3],
262 infos[4],
263 infos[5],
264 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
265 lstmParamsInfo.value(),
266 reasonIfUnsupported);
267 case LayerType::Maximum:
268 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
269 case LayerType::Mean:
270 return IsMeanSupported(infos[0],
271 infos[1],
272 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
273 reasonIfUnsupported);
274 case LayerType::Minimum:
275 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
276 case LayerType::Multiplication:
277 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
278 case LayerType::Normalization:
279 return IsNormalizationSupported(infos[0],
280 infos[1],
281 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
282 reasonIfUnsupported);
283 case LayerType::Output:
284 return IsOutputSupported(infos[0], reasonIfUnsupported);
285 case LayerType::Pad:
286 return IsPadSupported(infos[0],
287 infos[1],
288 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
289 reasonIfUnsupported);
290 case LayerType::Permute:
291 return IsPermuteSupported(infos[0],
292 infos[1],
293 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
294 reasonIfUnsupported);
295 case LayerType::Pooling2d:
296 return IsPooling2dSupported(infos[0],
297 infos[1],
298 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
299 reasonIfUnsupported);
300 case LayerType::Prelu:
301 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
302 case LayerType::Quantize:
303 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
304 case LayerType::Reshape:
305 return IsReshapeSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Resize:
310 return IsResizeSupported(infos[0],
311 infos[1],
312 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
313 reasonIfUnsupported);
314 case LayerType::Reduce:
315 return IsReduceSupported(infos[0],
316 infos[1],
317 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
318 reasonIfUnsupported);
319 case LayerType::Slice:
320 return IsSliceSupported(infos[0],
321 infos[1],
322 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
323 reasonIfUnsupported);
324 case LayerType::Softmax:
325 return IsSoftmaxSupported(infos[0],
326 infos[1],
327 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
328 reasonIfUnsupported);
329 case LayerType::SpaceToBatchNd:
330 return IsSpaceToBatchNdSupported(infos[0],
331 infos[1],
332 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
333 reasonIfUnsupported);
334 case LayerType::SpaceToDepth:
335 return IsSpaceToDepthSupported(infos[0],
336 infos[1],
337 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
338 reasonIfUnsupported);
339 case LayerType::Splitter:
340 {
341 std::vector<TensorInfo> outputInfos;
342 for (uint32_t i = 1; i < infos.size(); i++)
343 {
344 outputInfos.push_back(infos[i]);
345 }
346 return IsSplitterSupported(infos[0],
347 {outputInfos.begin(), outputInfos.end()},
348 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
349 reasonIfUnsupported);
350 }
351 case LayerType::Stack:
352 {
353 std::vector<const TensorInfo*> inputInfos;
354 for (uint32_t i = 0; i < infos.size() - 1; i++)
355 {
356 inputInfos.push_back(&infos[i]);
357 }
358 return IsStackSupported(inputInfos,
359 infos[infos.size() - 1],
360 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
361 reasonIfUnsupported);
362 }
363 case LayerType::StridedSlice:
364 return IsStridedSliceSupported(infos[0],
365 infos[1],
366 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
367 reasonIfUnsupported);
368 case LayerType::Subtraction:
369 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
370 case LayerType::Transpose:
371 return IsTransposeSupported(infos[0],
372 infos[1],
373 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
374 reasonIfUnsupported);
375 case LayerType::TransposeConvolution2d:
376 {
377 if (infos.size() != 4)
378 {
379 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
380 "TensorInfos should be of format: {input, output, weights, biases}.");
381 }
382
383 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
384 if (infos[3] == TensorInfo())
385 {
386 return IsTransposeConvolution2dSupported(infos[0],
387 infos[1],
388 desc,
389 infos[2],
390 EmptyOptional(),
391 reasonIfUnsupported);
392 }
393 else
394 {
395 return IsTransposeConvolution2dSupported(infos[0],
396 infos[1],
397 desc,
398 infos[2],
399 infos[3],
400 reasonIfUnsupported);
401 }
402 }
403 case LayerType::Cast:
404 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
405 case LayerType::ChannelShuffle:
406 return IsChannelShuffleSupported(infos[0],
407 infos[1],
408 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
409 reasonIfUnsupported);
410 case LayerType::Convolution3d:
411 {
412 if (infos.size() != 4)
413 {
414 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
415 "TensorInfos should be of format: {input, output, weights, biases}.");
416 }
417
418 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
419 if (infos[3] == TensorInfo())
420 {
421 return IsConvolution3dSupported(infos[0],
422 infos[1],
423 desc,
424 infos[2],
425 EmptyOptional(),
426 reasonIfUnsupported);
427 }
428 else
429 {
430 return IsConvolution3dSupported(infos[0],
431 infos[1],
432 desc,
433 infos[2],
434 infos[3],
435 reasonIfUnsupported);
436 }
437 }
438 case LayerType::Debug:
439 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
440 case LayerType::DetectionPostProcess:
441 return IsDetectionPostProcessSupported(infos[0],
442 infos[1],
443 infos[2],
444 infos[3],
445 infos[4],
446 infos[5],
447 infos[6],
448 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
449 (&descriptor)),
450 reasonIfUnsupported);
451 case LayerType::FakeQuantization:
452 return IsFakeQuantizationSupported(infos[0],
453 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
454 reasonIfUnsupported);
455 case LayerType::MemCopy:
456 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
457 case LayerType::Rank:
458 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
459 case LayerType::Shape:
460 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
461 case LayerType::UnidirectionalSequenceLstm:
462 {
463 if (infos.size() != 6)
464 {
465 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
466 "should be of format: {input, outputStateIn, cellStateIn, "
467 "hiddenStateOutputVal, cellStateOutputVal, output}");
468 }
469 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100470 return IsUnidirectionalSequenceLstmSupported(infos[0],
471 infos[1],
472 infos[2],
473 infos[3],
474 infos[4],
475 infos[5],
476 desc,
477 lstmParamsInfo.value(),
478 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000479 }
480 case LayerType::Pooling3d:
481 return IsPooling3dSupported(infos[0],
482 infos[1],
483 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
484 reasonIfUnsupported);
485 case LayerType::Map:
486 return true;
487 case LayerType::Unmap:
488 return true;
489 case LayerType::MemImport:
490 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
491 case LayerType::Merge:
492 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
493 case LayerType::QuantizedLstm:
494 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
495 infos[1],
496 infos[2],
497 infos[3],
498 infos[4],
499 quantizedLstmInputParamsInfo.value(),
500 reasonIfUnsupported);
501 default:
502 // layers not supported in neon by default:
503 // precompiled, standin, switch
504 return false;
505 }
506}
507
arovir011c7c81b2018-10-08 11:34:28 +0100508bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
509 const TensorInfo& output,
510 const ActivationDescriptor& descriptor,
511 Optional<std::string&> reasonIfUnsupported) const
512{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000513 bool supported = true;
514
515 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000516 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000517 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100518 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000519 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000520 DataType::QAsymmU8,
521 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000522 };
523
524 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
525 "Reference activation: input type not supported.");
526
527 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
528 "Reference activation: output type not supported.");
529
530 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
531 "Reference activation: input and output types mismatched.");
532
533 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
534 "Reference activation: input and output shapes are of different rank.");
535
536
537 struct ActivationFunctionSupported : public Rule
538 {
539 ActivationFunctionSupported(const ActivationDescriptor& desc)
540 {
541 switch(desc.m_Function)
542 {
543 case ActivationFunction::Abs:
544 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000545 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000546 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000547 case ActivationFunction::LeakyReLu:
548 case ActivationFunction::Linear:
549 case ActivationFunction::ReLu:
550 case ActivationFunction::Sigmoid:
551 case ActivationFunction::SoftReLu:
552 case ActivationFunction::Sqrt:
553 case ActivationFunction::Square:
554 case ActivationFunction::TanH:
555 {
556 m_Res = true;
557 break;
558 }
559 default:
560 {
561 m_Res = false;
562 break;
563 }
564 }
565 }
566 };
567
568 // Function is supported
569 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
570 "Reference activation: function not supported.");
571
572 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100573}
574
575bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
576 const TensorInfo& input1,
577 const TensorInfo& output,
578 Optional<std::string&> reasonIfUnsupported) const
579{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000580 bool supported = true;
581
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100582 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000583 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100584 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000585 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000586 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100587 DataType::QSymmS16,
588 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000589 };
590
591 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
592 "Reference addition: input 0 is not a supported type.");
593
594 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
595 "Reference addition: input 1 is not a supported type.");
596
597 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
598 "Reference addition: output is not a supported type.");
599
600 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
601 "Reference addition: input 0 and Input 1 types are mismatched");
602
603 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
604 "Reference addition: input and output types are mismatched");
605
606 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
607 "Reference addition: shapes are not suitable for implicit broadcast.");
608
609 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100610}
611
Nikhil Raj68c2c902019-09-19 11:21:11 +0100612bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
613 const armnn::ArgMinMaxDescriptor &descriptor,
614 armnn::Optional<std::string &> reasonIfUnsupported) const
615{
Jan Eilers8eb25602020-03-09 12:13:48 +0000616 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100617
Mike Kelly1f140f72021-04-06 12:25:55 +0100618 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100619 {
Teresa Charline300b362020-05-25 10:01:03 +0100620 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100621 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000623 DataType::QAsymmU8,
624 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100625 DataType::Signed32,
626 DataType::Signed64
627 };
628
629 std::array<DataType,2> supportedOutputTypes = {
630 DataType::Signed32,
631 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100632 };
633
634 bool supported = true;
635
Mike Kelly1f140f72021-04-06 12:25:55 +0100636 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100637 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100638 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100639 "Reference ArgMinMax: output type not supported");
640
641 return supported;
642}
643
Samuel Yap6b478092022-07-06 15:36:03 +0100644bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
645 const TensorInfo& inputY,
646 const TensorInfo& output,
647 const BatchMatMulDescriptor& descriptor,
648 Optional<std::string &> reasonIfUnsupported) const
649{
650 IgnoreUnused(descriptor);
651
652 std::array<DataType, 6> supportedTypes =
653 {
Samuel Yap6b478092022-07-06 15:36:03 +0100654 DataType::Float16,
655 DataType::Float32,
656 DataType::QAsymmS8,
657 DataType::QAsymmU8,
658 DataType::QSymmS16
659 };
660
661 bool supported = true;
662
663 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
664 "Reference batch matrix multiplication: input X is not a supported type");
665
666 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
667 "Reference batch matrix multiplication: input Y is not a supported type");
668
669 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
670 "Reference batch matrix multiplication: output is not a supported type");
671
672 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
673 "Reference batch matrix multiplication: input X and input Y types are mismatched");
674
675 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
676 "Reference batch matrix multiplication: inputs and output types are mismatched");
677
678 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
679 reasonIfUnsupported,
680 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
681
682 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
683 reasonIfUnsupported,
684 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
685
686 return supported;
687}
688
arovir011c7c81b2018-10-08 11:34:28 +0100689bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
690 const TensorInfo& output,
691 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100692 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100693 const TensorInfo& beta,
694 const TensorInfo& gamma,
695 const BatchNormalizationDescriptor& descriptor,
696 Optional<std::string&> reasonIfUnsupported) const
697{
Jan Eilers8eb25602020-03-09 12:13:48 +0000698 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100699
Sadik Armagan303980c2020-04-17 12:45:14 +0100700 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100701 {
702 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100703 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100704 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000705 DataType::QAsymmU8,
706 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100707 };
708
709 bool supported = true;
710
711 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
712 "Reference batch normalization: input is not a supported type.");
713
714 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
715 "Reference batch normalization: output is not a supported type.");
716
717 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
718 "Reference batch normalization: input and output types are mismatched");
719
720 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
721 "Reference batch normalization: mean is not a supported type.");
722
723 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
724 "Reference batch normalization: variance is not a supported type.");
725
726 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
727 "Reference batch normalization: beta is not a supported type.");
728
729 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
730 "Reference batch normalization: gamma is not a supported type.");
731
732 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100733}
734
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000735bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
736 const TensorInfo& output,
737 const BatchToSpaceNdDescriptor& descriptor,
738 Optional<std::string&> reasonIfUnsupported) const
739{
Jan Eilers8eb25602020-03-09 12:13:48 +0000740 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100741
742 bool supported = true;
743
744 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
745 std::string inputTensorStr = "input";
746 std::string outputTensorStr = "output";
747
748 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100749 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100750 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000751 DataType::Float32,
752 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100753 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000754 DataType::QAsymmU8,
755 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100756 };
757
758 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
759 "Reference BatchToSpaceNd: input type not supported.");
760
761 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
762 "Reference BatchToSpaceNd: output type not supported.");
763
764 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
765 "Reference BatchToSpaceNd: input and output types mismatched.");
766
767 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
768 reasonIfUnsupported,
769 CreateIncorrectDimensionsErrorMsg(4,
770 output.GetNumDimensions(),
771 batchToSpaceNdLayerStr,
772 outputTensorStr).data());
773
774 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
775 reasonIfUnsupported,
776 CreateIncorrectDimensionsErrorMsg(4,
777 input.GetNumDimensions(),
778 batchToSpaceNdLayerStr,
779 inputTensorStr).data());
780
781 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000782}
783
mathad01b392e982021-04-07 12:07:30 +0100784bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
785 const TensorInfo& output,
786 Optional<std::string&> reasonIfUnsupported) const
787{
788 std::array<DataType, 9> supportedInputTypes =
789 {
mathad01b392e982021-04-07 12:07:30 +0100790 DataType::Float32,
791 DataType::Float16,
792 DataType::QSymmS8,
793 DataType::QAsymmS8,
794 DataType::QAsymmU8,
795 DataType::QSymmS16,
796 DataType::Signed32
797 };
798
799 bool supported = true;
800 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
801 "Reference cast: input is not a supported type");
802
803
804 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
805 "Reference cast: output is not a supported type");
806
807 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
808 "Reference cast: input and output shapes have different number of total elements");
809
810 return supported;
811}
812
Simon Obute51f67772021-09-03 15:50:13 +0100813bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
814 const TensorInfo& output,
815 const ChannelShuffleDescriptor& descriptor,
816 Optional<std::string&> reasonIfUnsupported) const
817{
818 IgnoreUnused(descriptor);
819 bool supported = true;
820
821 // Define supported output and inputs types.
822 std::array<DataType, 7> supportedTypes =
823 {
Simon Obute51f67772021-09-03 15:50:13 +0100824 DataType::Float32,
825 DataType::Float16,
826 DataType::QAsymmS8,
827 DataType::QAsymmU8,
828 DataType::QSymmS8,
829 DataType::QSymmS16
830 };
831
832 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
833 "Reference ChannelShuffle: input is not a supported type.");
834
835 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
836 "Reference ChannelShuffle: output is not a supported type.");
837
838 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
839 "Reference ChannelShuffle: input and output types are mismatched.");
840
841 return supported;
842}
843
844
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100845bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
846 const TensorInfo& input1,
847 const TensorInfo& output,
848 const ComparisonDescriptor& descriptor,
849 Optional<std::string&> reasonIfUnsupported) const
850{
Jan Eilers8eb25602020-03-09 12:13:48 +0000851 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100852 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100853 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000854 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100855 DataType::Float32,
856 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100857 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000858 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000859 DataType::QSymmS16,
860 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100861 };
862
863 bool supported = true;
864 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
865 "Reference comparison: input 0 is not a supported type");
866
867 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
868 "Reference comparison: input 0 and Input 1 types are mismatched");
869
870 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
871 "Reference comparison: output is not of type Boolean");
872
873 return supported;
874}
875
Jim Flynn906f9462019-05-10 13:55:21 +0100876bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
877 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000878 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100879 Optional<std::string&> reasonIfUnsupported) const
880{
Jan Eilers8eb25602020-03-09 12:13:48 +0000881 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100882
883 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000884 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100885 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000886 DataType::Float32,
887 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000888 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100889 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000890 DataType::QSymmS16,
891 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100892 };
893
894 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
895 "Reference concatenation: output type not supported");
896 for (const TensorInfo* input : inputs)
897 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100898 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100899 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
900 "Reference concatenation: input type not supported");
901
902 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
903 "Reference concatenation: input and output types mismatched.");
904 }
905
906 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100907}
908
arovir011c7c81b2018-10-08 11:34:28 +0100909bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
910 Optional<std::string&> reasonIfUnsupported) const
911{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100912 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100913 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100914 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100915 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000916 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100917 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000918 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100919 DataType::QSymmS16,
920 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100921 };
922
923 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
924 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100925}
926
927bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
928 const TensorInfo& output,
929 Optional<std::string&> reasonIfUnsupported) const
930{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100931 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
932 input.GetDataType(),
933 &TrueFunc<>,
934 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000935 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000936 &FalseFuncI32<>,
937 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100938 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
939 output.GetDataType(),
940 &FalseOutputFuncF16<>,
941 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000942 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000943 &FalseFuncI32<>,
944 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100945}
946
947bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
948 const TensorInfo& output,
949 Optional<std::string&> reasonIfUnsupported) const
950{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100951 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
952 input.GetDataType(),
953 &FalseInputFuncF16<>,
954 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000955 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000956 &FalseFuncI32<>,
957 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100958 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
959 output.GetDataType(),
960 &TrueFunc<>,
961 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000962 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000963 &FalseFuncI32<>,
964 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100965}
966
967bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
968 const TensorInfo& output,
969 const Convolution2dDescriptor& descriptor,
970 const TensorInfo& weights,
971 const Optional<TensorInfo>& biases,
972 Optional<std::string&> reasonIfUnsupported) const
973{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100974 bool supported = true;
975
976 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000977 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000978 {
979 DataType::Float32,
980 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000981 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100982 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000983 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000984 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100985 };
986
987 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000988 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100989
990 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000991 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100992
Ryan OShea31441592022-11-07 16:20:48 +0000993 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
994 "Reference Convolution2d: input and output types mismatched.");
995
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100996
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000997 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000998 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000999 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001000 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001001 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001002 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001003 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001004 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001005 };
1006
1007 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001008 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001009 }
1010 else
1011 {
1012 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001013 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001014
1015 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001016 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001017 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001018
1019 if (biases.has_value())
1020 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001021 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001022 {
1023 DataType::Float32,
1024 DataType::Float16,
1025 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001026 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001027
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001028 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001029 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001030 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001031 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001032
1033 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001034}
1035
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001036bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1037 const TensorInfo& output,
1038 const Convolution3dDescriptor& descriptor,
1039 const TensorInfo& weights,
1040 const Optional<TensorInfo>& biases,
1041 Optional<std::string&> reasonIfUnsupported) const
1042{
1043 bool supported = true;
1044
1045 // Define supported types.
1046 std::array<DataType,7> supportedTypes =
1047 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001048 DataType::Float32,
1049 DataType::Float16,
1050 DataType::QAsymmS8,
1051 DataType::QAsymmU8,
1052 DataType::QSymmS8,
1053 DataType::QSymmS16
1054 };
1055
1056 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1057 "Reference Convolution3d: input is not a supported type.");
1058
1059 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1060 "Reference Convolution3d: output is not a supported type.");
1061
1062 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1063 "Reference Convolution3d: input and output types mismatched.");
1064
1065 const DataType inputType = input.GetDataType();
1066 if (IsQuantized8BitType(inputType))
1067 {
1068 std::array<DataType, 3> supportedWeightTypes =
1069 {
1070 DataType::QAsymmS8,
1071 DataType::QAsymmU8,
1072 DataType::QSymmS8
1073 };
1074
1075 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1076 "Reference Convolution3d: weights type not supported for quantized input.");
1077 }
1078 else
1079 {
1080 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1081 "Reference Convolution3d: weights is not a supported type.");
1082
1083 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1084 "Reference Convolution3d: input and weights types mismatched.");
1085 }
1086
1087 if (biases.has_value())
1088 {
1089 std::array<DataType,4> biasesSupportedTypes =
1090 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001091 DataType::Float32,
1092 DataType::Float16,
1093 DataType::Signed32
1094 };
1095
1096 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1097 "Reference Convolution3d: biases is not a supported type.");
1098 }
1099 IgnoreUnused(descriptor);
1100
1101 return supported;
1102}
1103
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001104bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1105 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001106 Optional<std::string&> reasonIfUnsupported) const
1107{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001108 bool supported = true;
1109
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001110 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001111 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001112 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001113 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001114 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001115 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001116 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001117 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001118 DataType::QSymmS16,
1119 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001120 };
1121
1122 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001123 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001124
1125 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001126 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001127
1128 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001129 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001130
1131 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001132}
1133
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001134bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1135 const TensorInfo& output,
1136 const DepthToSpaceDescriptor& descriptor,
1137 Optional<std::string&> reasonIfUnsupported) const
1138{
Jan Eilers8eb25602020-03-09 12:13:48 +00001139 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001140 bool supported = true;
1141
Sadik Armagan303980c2020-04-17 12:45:14 +01001142 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001143 {
1144 DataType::Float32,
1145 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001146 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001147 DataType::QAsymmU8,
1148 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001149 };
1150
1151 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1152 "Reference DepthToSpace: input type not supported");
1153
1154 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1155 "Reference DepthToSpace: output type not supported");
1156
1157 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1158 "Reference DepthToSpace: input and output types are mismatched");
1159
1160 return supported;
1161}
1162
arovir011c7c81b2018-10-08 11:34:28 +01001163bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1164 const TensorInfo& output,
1165 const DepthwiseConvolution2dDescriptor& descriptor,
1166 const TensorInfo& weights,
1167 const Optional<TensorInfo>& biases,
1168 Optional<std::string&> reasonIfUnsupported) const
1169{
Sadik Armagan303980c2020-04-17 12:45:14 +01001170 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001171 bool supported = true;
1172
1173 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001174 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001175 {
1176 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001177 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001178 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001179 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001180 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001181 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001182 };
1183
1184 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1185 "Reference DepthwiseConvolution2d: input is not a supported type.");
1186
1187 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1188 "Reference DepthwiseConvolution2d: output is not a supported type.");
1189
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001190 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1191 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1192
Teresa Charlind8df0262019-11-11 12:28:15 +00001193 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001194 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001195 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001196 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001197 {
1198 DataType::QAsymmS8,
1199 DataType::QAsymmU8,
1200 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001201 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001202
1203 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001204 "Reference DepthwiseConvolution2d: weights type not supported for "
1205 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001206 }
1207 else
1208 {
1209 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1210 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1211
1212 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1213 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1214 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001215
1216 if (biases.has_value())
1217 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001218 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001219 {
1220 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001221 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001222 DataType::Signed32
1223 };
1224 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1225 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1226 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001227
1228 return supported;
1229
arovir011c7c81b2018-10-08 11:34:28 +01001230}
1231
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001232bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1233 const TensorInfo& output,
1234 Optional<std::string&> reasonIfUnsupported) const
1235{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001236 bool supported = true;
1237
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001238 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001239 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001240 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001241 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001242 DataType::QSymmS16,
1243 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001244 };
1245
1246 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001247 "Reference for Dequantize layer: input type not supported.");
1248
Derek Lambertid466a542020-01-22 15:37:29 +00001249 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001250 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001251
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001252 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001253 DataType::Float32,
1254 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001255 };
1256
1257 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001258 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001259
1260 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001261 "Reference for Dequantize layer: input/output shapes have different num total "
1262 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001263
1264 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001265}
1266
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001267bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1268 const TensorInfo& scores,
1269 const TensorInfo& anchors,
1270 const TensorInfo& detectionBoxes,
1271 const TensorInfo& detectionClasses,
1272 const TensorInfo& detectionScores,
1273 const TensorInfo& numDetections,
1274 const DetectionPostProcessDescriptor& descriptor,
1275 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001276{
Jan Eilers8eb25602020-03-09 12:13:48 +00001277 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001278
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001279 bool supported = true;
1280
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001281 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001282 {
1283 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001284 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001285 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001286 DataType::QAsymmU8,
1287 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001288 };
1289
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001290 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001291 "Reference DetectionPostProcess: input 0 is not a supported type.");
1292
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001293 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001294 "Reference DetectionPostProcess: input 1 is not a supported type.");
1295
1296 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001297}
1298
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // Delegates directly to IsDepthwiseConvolutionSupported: the reference
    // backend applies the same support criteria whether or not dilation is used.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1308
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001309bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001310 const TensorInfo& input1,
1311 const TensorInfo& output,
1312 Optional<std::string&> reasonIfUnsupported) const
1313{
Sadik Armagan2999a022019-04-09 14:20:12 +01001314 bool supported = true;
1315
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001316 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001317 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001318 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001319 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001320 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001321 DataType::QSymmS16,
1322 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001323 };
1324
1325 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1326 "Reference division: input 0 is not a supported type.");
1327
1328 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1329 "Reference division: input 1 is not a supported type.");
1330
1331 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1332 "Reference division: output is not a supported type.");
1333
1334 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1335 "Reference division: input 0 and Input 1 types are mismatched");
1336
1337 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1338 "Reference division: input and output types are mismatched");
1339
1340 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1341 "Reference division: shapes are not suitable for implicit broadcast.");
1342
1343 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001344}
1345
josh minor4a3c6102020-01-06 16:40:46 -06001346bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1347 const TensorInfo& output,
1348 const ElementwiseUnaryDescriptor& descriptor,
1349 Optional<std::string&> reasonIfUnsupported) const
1350{
Jan Eilers8eb25602020-03-09 12:13:48 +00001351 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001352
Sadik Armagan303980c2020-04-17 12:45:14 +01001353 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001354 {
1355 DataType::Float32,
1356 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001357 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001358 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001359 DataType::QSymmS16,
1360 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001361 };
1362
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001363 std::array<DataType, 1> logicalSupportedTypes =
1364 {
1365 DataType::Boolean
1366 };
1367
josh minor4a3c6102020-01-06 16:40:46 -06001368 bool supported = true;
1369
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001370 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1371 {
1372 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1373 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001374
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001375 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1376 "Reference elementwise unary: output type not supported");
1377 }
1378 else
1379 {
1380 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1381 "Reference elementwise unary: input type not supported");
1382
1383 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1384 "Reference elementwise unary: output type not supported");
1385 }
josh minor4a3c6102020-01-06 16:40:46 -06001386
1387 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1388 "Reference elementwise unary: input and output types not matching");
1389
1390 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1391 "Reference elementwise unary: input and output shapes"
1392 "have different number of total elements");
1393
1394 return supported;
1395}
1396
arovir011c7c81b2018-10-08 11:34:28 +01001397bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1398 const FakeQuantizationDescriptor& descriptor,
1399 Optional<std::string&> reasonIfUnsupported) const
1400{
Jan Eilers8eb25602020-03-09 12:13:48 +00001401 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001402 bool supported = true;
1403
1404 std::array<DataType,1> supportedTypes =
1405 {
1406 DataType::Float32
1407 };
1408
1409 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1410 "Reference fake quantization: input type not supported.");
1411
1412 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001413}
1414
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001415bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1416 const TensorInfo& output,
1417 const FillDescriptor& descriptor,
1418 Optional<std::string&> reasonIfUnsupported) const
1419{
1420 IgnoreUnused(descriptor);
1421 IgnoreUnused(output);
1422
1423 bool supported = true;
1424
Sadik Armagana792a052020-06-23 16:22:23 +01001425 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001426 {
1427 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001428 DataType::Float16,
1429 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001430 };
1431
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001432 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001433 "Reference Fill: input type not supported.");
1434
Teresa Charlin44088502020-07-27 11:27:19 +01001435 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1436 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001437 return supported;
1438}
1439
arovir011c7c81b2018-10-08 11:34:28 +01001440bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1441 const TensorInfo& output,
1442 Optional<std::string&> reasonIfUnsupported) const
1443{
Jan Eilers8eb25602020-03-09 12:13:48 +00001444 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001445 bool supported = true;
1446
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001447 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001448 {
James Conroyb40d7102019-06-04 12:32:09 +01001449 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001450 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001451 };
1452
1453 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1454 "Reference Floor: input type not supported.");
1455
1456 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1457 "Reference Floor: output type not supported.");
1458
1459 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001460}
1461
1462bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1463 const TensorInfo& output,
1464 const TensorInfo& weights,
1465 const TensorInfo& biases,
1466 const FullyConnectedDescriptor& descriptor,
1467 Optional<std::string&> reasonIfUnsupported) const
1468{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001469 bool supported = true;
1470
1471 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001472 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001473 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001474 DataType::Float32,
1475 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001476 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001477 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001478 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001479 };
1480
1481 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1482 "Reference Fully Connected: input type not supported.");
1483
1484 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1485 "Reference Fully Connected: output type not supported.");
1486
Francis Murtagh46c09d02019-05-28 08:15:28 +01001487 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1488 "Reference Fully Connected: weights type not supported.");
1489
Ryan OShea31441592022-11-07 16:20:48 +00001490 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1491 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001492
Jan Eilers1f45dc32020-06-15 11:43:03 +01001493 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1494 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001495
Jan Eilers1f45dc32020-06-15 11:43:03 +01001496 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1497 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001498
1499 if (descriptor.m_BiasEnabled)
1500 {
1501 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001502 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001503 supportedBiasTypes =
1504 {
1505 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001506 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001507 DataType::Signed32,
1508 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001509 };
1510
1511 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1512 "Reference Fully Connected: bias type not supported.");
1513
1514 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1515 "Reference Fully Connected: bias and weight types mismatch.");
1516
1517 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1518 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1519
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001520 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1521 "Reference Fully Connected: bias must have 1 dimension.");
1522
Francis Murtagh46c09d02019-05-28 08:15:28 +01001523 }
1524
1525 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001526}
1527
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001528bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1529 const armnn::TensorInfo& input1,
1530 const armnn::TensorInfo& output,
1531 armnn::Optional<std::string&> reasonIfUnsupported) const
1532{
1533 bool supported = true;
1534 std::array<DataType,7> supportedTypes =
1535 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001536 DataType::Float32,
1537 DataType::Float16,
1538 DataType::QAsymmS8,
1539 DataType::QAsymmU8,
1540 DataType::QSymmS16,
1541 DataType::Signed32
1542 };
1543
1544 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1545 "Reference GatherNd: input type not supported");
1546
1547 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1548 "Reference GatherNd: output type not supported");
1549
1550 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1551 "Reference GatherNd: indices (input1) type not supported");
1552
1553 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1554 "Reference GatherNd: input and output types not matching");
1555
1556 return supported;
1557}
1558
narpra014951d842019-01-18 16:53:53 +00001559bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1560 const armnn::TensorInfo& input1,
1561 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001562 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001563 armnn::Optional<std::string&> reasonIfUnsupported) const
1564{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001565 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001566 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001567 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001568 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001569 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001570 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001571 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001572 DataType::QSymmS16,
1573 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001574 };
1575
Teresa Charlin52664732020-06-29 16:27:03 +01001576 if (descriptor.m_Axis != 0)
1577 {
1578 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1579 supported &= false;
1580 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001581 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1582 "Reference Gather: input type not supported");
1583
1584 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1585 "Reference Gather: output type not supported");
1586
1587 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1588 "Reference Gather: indices (input1) type not supported");
1589
1590 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1591 "Reference Gather: input and output types not matching");
1592
1593 return supported;
narpra014951d842019-01-18 16:53:53 +00001594}
1595
Derek Lamberti901ea112019-12-10 22:07:09 +00001596bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1597 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001598{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001599 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001600}
1601
Kevin May09ca49c2019-10-09 12:37:34 +01001602bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1603 const TensorInfo& output,
1604 const InstanceNormalizationDescriptor& descriptor,
1605 Optional<std::string&> reasonIfUnsupported) const
1606{
Jan Eilers8eb25602020-03-09 12:13:48 +00001607 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001608 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001609 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001610 {
1611 DataType::Float32,
1612 DataType::Float16
1613 };
1614
1615 bool supported = true;
1616
1617 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1618 "Reference Instance Normalization: input type not supported.");
1619
1620 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1621 "Reference Instance Normalization: output type not supported.");
1622
1623 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1624 "Reference Instance Normalization: input and output types mismatched.");
1625
1626 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1627 "Reference Instance Normalization: input and output shapes have different "
1628 "num total elements.");
1629
1630 return supported;
1631}
1632
arovir011c7c81b2018-10-08 11:34:28 +01001633bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1634 const TensorInfo& output,
1635 const L2NormalizationDescriptor& descriptor,
1636 Optional<std::string&> reasonIfUnsupported) const
1637{
Jan Eilers8eb25602020-03-09 12:13:48 +00001638 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001639 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001640 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001641 {
1642 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001643 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001644 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001645 DataType::QAsymmU8,
1646 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001647 };
1648
1649 bool supported = true;
1650
1651 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1652 "Reference L2normalization: input type not supported.");
1653
1654 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1655 "Reference L2normalization: output type not supported.");
1656
1657 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1658 "Reference L2normalization: input and output types mismatched.");
1659
1660 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1661 "Reference L2normalization: input and output shapes have different "
1662 "num total elements.");
1663
1664 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001665}
1666
James Conroyaba90cd2020-11-06 16:28:18 +00001667bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1668 const TensorInfo& input1,
1669 const TensorInfo& output,
1670 const LogicalBinaryDescriptor& descriptor,
1671 Optional<std::string&> reasonIfUnsupported) const
1672{
1673 IgnoreUnused(descriptor);
1674
1675 std::array<DataType, 1> supportedTypes =
1676 {
1677 DataType::Boolean
1678 };
1679
1680 bool supported = true;
1681 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1682 "Reference LogicalBinary: input 0 type not supported");
1683 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1684 "Reference LogicalBinary: input 1 type not supported");
1685
1686 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1687 "Reference LogicalBinary: input and output types do not match");
1688
1689 return supported;
1690}
1691
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001692bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1693 const TensorInfo& output,
1694 const LogSoftmaxDescriptor& descriptor,
1695 Optional<std::string&> reasonIfUnsupported) const
1696{
Jan Eilers8eb25602020-03-09 12:13:48 +00001697 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001698
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001699 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001700 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001701 DataType::Float32,
1702 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001703 };
1704
1705 bool supported = true;
1706 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1707 "Reference LogSoftmax: input type not supported");
1708
1709 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1710 "Reference LogSoftmax: output type not supported");
1711
1712 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1713 "Reference LogSoftmax: input and output types do not match");
1714
1715 return supported;
1716}
1717
arovir011c7c81b2018-10-08 11:34:28 +01001718bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1719 const TensorInfo& outputStateIn,
1720 const TensorInfo& cellStateIn,
1721 const TensorInfo& scratchBuffer,
1722 const TensorInfo& outputStateOut,
1723 const TensorInfo& cellStateOut,
1724 const TensorInfo& output,
1725 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001726 const LstmInputParamsInfo& paramsInfo,
1727 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001728{
Jan Eilers8eb25602020-03-09 12:13:48 +00001729 IgnoreUnused(descriptor);
1730 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001731
1732 bool supported = true;
1733
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001734 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001735 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001736 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001737 };
1738
Jan Eilersd01a83c2019-07-03 18:20:40 +01001739 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001740 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1741 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001742 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1743 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001744 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1745 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001746 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1747 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001748 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1749 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001750 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1751 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001752
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001753 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1754 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001755 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001756 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001757 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001758 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001759 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001760 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001761 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001762 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001763 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001764 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001765 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001766 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001767 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001768 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001769 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001770 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001771 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001772 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001773 "Reference Lstm: input and OutputGateBias types are mismatched");
1774 if (!descriptor.m_CifgEnabled)
1775 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001776 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001777 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001778 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001779 reasonIfUnsupported,
1780 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001781 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001782 "Reference Lstm: input and InputGateBias types are mismatched");
1783 if (descriptor.m_PeepholeEnabled)
1784 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001785 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001786 reasonIfUnsupported,
1787 "Reference Lstm: input and CellToInputWeights types are mismatched");
1788 }
1789 }
1790 if (descriptor.m_PeepholeEnabled)
1791 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001792 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001793 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001794 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001795 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1796 }
1797 if (descriptor.m_ProjectionEnabled)
1798 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001799 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001800 "Reference Lstm: input and mProjectionWeights types are mismatched");
1801 if (paramsInfo.m_ProjectionBias != nullptr)
1802 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001803 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001804 "Reference Lstm: input and ProjectionBias types are mismatched");
1805 }
1806 }
1807 if (descriptor.m_LayerNormEnabled)
1808 {
1809 if (!descriptor.m_CifgEnabled)
1810 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001812 reasonIfUnsupported,
1813 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1814 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001815 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001816 reasonIfUnsupported,
1817 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001818 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001819 reasonIfUnsupported,
1820 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001821 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001822 reasonIfUnsupported,
1823 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1824 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001825
1826 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001827}
1828
saoste012df12b32018-11-28 16:57:20 +00001829bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1830 const TensorInfo& input1,
1831 const TensorInfo& output,
1832 Optional<std::string&> reasonIfUnsupported) const
1833{
Sadik Armagan2999a022019-04-09 14:20:12 +01001834 bool supported = true;
1835
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001836 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001837 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001838 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001839 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001840 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001841 DataType::QSymmS16,
1842 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001843 };
1844
1845 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1846 "Reference maximum: input 0 is not a supported type.");
1847
1848 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1849 "Reference maximum: input 1 is not a supported type.");
1850
1851 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1852 "Reference maximum: output is not a supported type.");
1853
1854 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1855 "Reference maximum: input 0 and Input 1 types are mismatched");
1856
1857 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1858 "Reference maximum: input and output types are mismatched");
1859
1860 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1861 "Reference maximum: shapes are not suitable for implicit broadcast.");
1862
1863 return supported;
saoste012df12b32018-11-28 16:57:20 +00001864}
1865
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001866bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1867 const TensorInfo& output,
1868 const MeanDescriptor& descriptor,
1869 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001870{
James Conroy4d1ff582019-06-10 17:06:39 +01001871 bool supported = true;
1872 std::string meanLayerStr = "Mean";
1873 std::string outputTensorStr = "output";
1874
Sadik Armagan303980c2020-04-17 12:45:14 +01001875 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001876 {
1877 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001878 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001879 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001880 DataType::QAsymmU8,
1881 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001882 };
1883
1884 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1885 "Reference Mean: input type not supported.");
1886
1887 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1888 "Reference Mean: input and output types are mismatched");
1889
1890 if (descriptor.m_KeepDims)
1891 {
1892 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1893 reasonIfUnsupported,
1894 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1895 output.GetNumDimensions(),
1896 meanLayerStr, outputTensorStr).data());
1897 }
1898 else if (descriptor.m_Axis.empty())
1899 {
1900 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1901 reasonIfUnsupported,
1902 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1903 meanLayerStr, outputTensorStr).data());
1904 }
1905 else
1906 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001907 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001908
1909 if (outputDim > 0)
1910 {
1911 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1912 reasonIfUnsupported,
1913 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1914 meanLayerStr, outputTensorStr).data());
1915 }
1916 else
1917 {
1918 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1919 reasonIfUnsupported,
1920 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1921 meanLayerStr, outputTensorStr).data());
1922 }
1923 }
1924
1925 return supported;
narpra0132b90462018-09-13 11:07:48 +01001926}
1927
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001928bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1929 const TensorInfo &output,
1930 Optional<std::string &> reasonIfUnsupported) const
1931{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001932 bool supported = true;
1933
Sadik Armagan303980c2020-04-17 12:45:14 +01001934 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001935 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001936 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001937 DataType::Float32,
1938 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001939 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001940 DataType::QAsymmU8,
1941 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001942 DataType::Boolean
1943 };
1944
1945 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1946 "Reference MemCopy: input type not supported");
1947
1948 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1949 "Reference MemCopy: output type not supported");
1950
1951 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1952 "Reference MemCopy: input and output types are mismatched");
1953
1954 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001955}
1956
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001957bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1958 const TensorInfo& input1,
1959 const TensorInfo& output,
1960 Optional<std::string&> reasonIfUnsupported) const
1961{
Sadik Armagan2999a022019-04-09 14:20:12 +01001962 bool supported = true;
1963
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001964 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001965 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001966 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001967 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001968 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001969 DataType::QSymmS16,
1970 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001971 };
1972
1973 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1974 "Reference minimum: input 0 is not a supported type.");
1975
1976 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1977 "Reference minimum: input 1 is not a supported type.");
1978
1979 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1980 "Reference minimum: output is not a supported type.");
1981
1982 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1983 "Reference minimum: input 0 and Input 1 types are mismatched");
1984
1985 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1986 "Reference minimum: input and output types are mismatched");
1987
1988 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1989 "Reference minimum: shapes are not suitable for implicit broadcast.");
1990
1991 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001992}
1993
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001994bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1995 const TensorInfo& input1,
1996 const TensorInfo& output,
1997 Optional<std::string&> reasonIfUnsupported) const
1998{
Sadik Armagan2999a022019-04-09 14:20:12 +01001999 bool supported = true;
2000
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002001 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002002 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002003 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002004 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002005 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002006 DataType::QSymmS16,
2007 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002008 };
2009
2010 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2011 "Reference multiplication: input 0 is not a supported type.");
2012
2013 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2014 "Reference multiplication: input 1 is not a supported type.");
2015
2016 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2017 "Reference multiplication: output is not a supported type.");
2018
2019 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2020 "Reference multiplication: input 0 and Input 1 types are mismatched");
2021
2022 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2023 "Reference multiplication: input and output types are mismatched");
2024
2025 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2026 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2027
2028 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002029}
2030
2031bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2032 const TensorInfo& output,
2033 const NormalizationDescriptor& descriptor,
2034 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002035{
Jan Eilers8eb25602020-03-09 12:13:48 +00002036 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002037
2038 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002039 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002040 {
2041 DataType::Float16,
2042 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002043 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002044 DataType::QAsymmU8,
2045 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002046 };
2047
2048 bool supported = true;
2049
2050 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2051 "Reference normalization: input type not supported.");
2052
2053 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2054 "Reference normalization: output type not supported.");
2055
2056 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2057 "Reference normalization: input and output shapes have different "
2058 "num total elements.");
2059
2060 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002061}
2062
// An Output layer performs no computation in the reference backend, so any
// tensor configuration is accepted; the arguments are intentionally unused.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
2068
2069bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2070 const TensorInfo& output,
2071 const PadDescriptor& descriptor,
2072 Optional<std::string&> reasonIfUnsupported) const
2073{
Jan Eilers8eb25602020-03-09 12:13:48 +00002074 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002075 bool supported = true;
2076
2077 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002078 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002079 {
2080 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002081 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002082 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002083 DataType::QAsymmU8,
2084 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002085 };
2086
2087 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2088 "Reference pad: input is not a supported type.");
2089
2090 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2091 "Reference pad: output is not a supported type.");
2092
2093 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2094 "Reference pad: input and output types are mismatched.");
2095
2096 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002097}
2098
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002099bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2100 const TensorInfo& output,
2101 const PermuteDescriptor& descriptor,
2102 Optional<std::string&> reasonIfUnsupported) const
2103{
Jan Eilers8eb25602020-03-09 12:13:48 +00002104 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002105 bool supported = true;
2106
2107 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002108 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002109 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002110 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002111 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002112 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002113 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002114 DataType::QAsymmU8,
2115 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002116 };
2117
2118 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2119 "Reference permute: input is not a supported type.");
2120
2121 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2122 "Reference permute: output is not a supported type.");
2123
2124 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2125 "Reference permute: input and output types are mismatched.");
2126
2127 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002128}
2129
2130bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2131 const TensorInfo& output,
2132 const Pooling2dDescriptor& descriptor,
2133 Optional<std::string&> reasonIfUnsupported) const
2134{
Jan Eilers8eb25602020-03-09 12:13:48 +00002135 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002136 bool supported = true;
2137
2138 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002139 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002140 {
2141 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002142 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002143 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002144 DataType::QAsymmU8,
2145 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002146 };
2147
2148 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2149 "Reference poolind2d: input is not a supported type.");
2150
2151 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2152 "Reference poolind2d: output is not a supported type.");
2153
2154 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2155 "Reference poolind2d: input and output types are mismatched.");
2156
2157 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002158}
2159
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002160bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2161 const TensorInfo& output,
2162 const Pooling3dDescriptor& descriptor,
2163 Optional<std::string&> reasonIfUnsupported) const
2164{
2165 IgnoreUnused(descriptor);
2166 bool supported = true;
2167
2168 // Define supported output and inputs types.
2169 std::array<DataType,6> supportedTypes =
2170 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002171 DataType::Float32,
2172 DataType::Float16,
2173 DataType::QAsymmS8,
2174 DataType::QAsymmU8,
2175 DataType::QSymmS16
2176 };
2177
2178 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2179 "Reference poolind3d: input is not a supported type.");
2180
2181 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2182 "Reference poolind3d: output is not a supported type.");
2183
2184 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2185 "Reference poolind3d: input and output types are mismatched.");
2186
2187 return supported;
2188}
2189
2190
James Conroy4f1f8992020-04-29 20:01:10 +01002191bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2192 const TensorInfo& previousOutputIn,
2193 const TensorInfo& previousCellStateIn,
2194 const TensorInfo& outputStateOut,
2195 const TensorInfo& cellStateOut,
2196 const TensorInfo& output,
2197 const QLstmDescriptor& descriptor,
2198 const LstmInputParamsInfo& paramsInfo,
2199 Optional<std::string&> reasonIfUnsupported) const
2200{
2201 IgnoreUnused(input);
2202 IgnoreUnused(previousOutputIn);
2203 IgnoreUnused(previousCellStateIn);
2204 IgnoreUnused(outputStateOut);
2205 IgnoreUnused(cellStateOut);
2206 IgnoreUnused(output);
2207 IgnoreUnused(descriptor);
2208 IgnoreUnused(paramsInfo);
2209
2210 IgnoreUnused(reasonIfUnsupported);
2211
2212 return true;
2213}
2214
Derek Lamberti5f400d62019-03-25 15:41:58 +00002215bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2216 const TensorInfo& output,
2217 Optional<std::string&> reasonIfUnsupported) const
2218{
2219 bool supported = true;
2220
Finn Williamsfd271062019-12-04 14:27:27 +00002221 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002222 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002223 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002224 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002225 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002226 DataType::QAsymmU8,
2227 DataType::QSymmS8,
2228 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002229 };
2230
2231 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2232 "Reference quantize: input type not supported.");
2233
2234 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002235 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002236 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002237 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002238 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002239 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002240 };
2241 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2242 "Reference quantize: output type not supported.");
2243
2244 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2245 "Reference quantize: input and output shapes have different num total elements.");
2246
2247 return supported;
2248}
2249
Finn Williams2605b232020-06-10 15:53:46 +01002250bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2251 const TensorInfo& output,
2252 Optional<std::string&> reasonIfUnsupported) const
2253{
2254 IgnoreUnused(input);
2255 // Define supported output types.
2256 std::array<DataType,1> supportedOutputTypes =
2257 {
2258 DataType::Signed32,
2259 };
2260
2261 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2262 "Reference rank: input type not supported.");
2263}
2264
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002265bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2266 const TensorInfo& output,
2267 const ReduceDescriptor& descriptor,
2268 Optional<std::string&> reasonIfUnsupported) const
2269{
2270 IgnoreUnused(descriptor);
2271 bool supported = true;
2272 std::array<DataType,7> supportedTypes =
2273 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002274 DataType::Float32,
2275 DataType::Float16,
2276 DataType::QAsymmS8,
2277 DataType::QAsymmU8,
2278 DataType::QSymmS16,
2279 DataType::Signed32
2280 };
2281
2282 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2283 "Reference Reduce: input type not supported");
2284
2285 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2286 "Reference Reduce: output type not supported");
2287
2288 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2289 "Reference Reduce: input and output types not matching");
2290
2291 return supported;
2292}
2293
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002294bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002295 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002296 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002297 Optional<std::string&> reasonIfUnsupported) const
2298{
Jan Eilers8eb25602020-03-09 12:13:48 +00002299 IgnoreUnused(output);
2300 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002301 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002302 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002303 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002304 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002305 DataType::Float32,
2306 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002307 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002308 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002309 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002310 DataType::QSymmS16,
2311 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002312 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002313
Nina Drozd2f2778f2019-05-27 10:37:05 +01002314 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2315 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002316}
2317
Teresa Charlin970f43b2019-07-01 13:51:07 +01002318bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2319 const TensorInfo& output,
2320 const ResizeDescriptor& descriptor,
2321 Optional<std::string&> reasonIfUnsupported) const
2322{
Jan Eilers8eb25602020-03-09 12:13:48 +00002323 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002324 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002325 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002326 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002327 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002328 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002329 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002330 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002331 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002332 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002333 };
2334
2335 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2336 "Reference Resize: input type not supported");
2337
2338 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2339 "Reference Resize: output type not supported");
2340
2341 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2342 "Reference Resize: input and output types not matching");
2343
2344 return supported;
2345}
2346
Keith Davis3ae3f972021-05-21 16:33:48 +01002347bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2348 const TensorInfo& output,
2349 Optional<std::string&> reasonIfUnsupported) const
2350{
2351 IgnoreUnused(input);
2352 bool supported = true;
2353
2354 std::array<DataType, 1> supportedTypes =
2355 {
2356 DataType::Signed32
2357 };
2358
2359 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2360 "Reference Shape: output type not supported");
2361
2362 return supported;
2363}
2364
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002365bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2366 const TensorInfo& output,
2367 const SliceDescriptor& descriptor,
2368 Optional<std::string&> reasonIfUnsupported) const
2369{
Jan Eilers8eb25602020-03-09 12:13:48 +00002370 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002371 bool supported = true;
2372
Sadik Armagan303980c2020-04-17 12:45:14 +01002373 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002374 {
2375 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002376 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002377 DataType::QAsymmU8,
2378 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002379 };
2380
2381 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2382 "Reference Slice: input type not supported");
2383
2384 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2385 "Reference Slice: output type not supported");
2386
2387 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2388 "Reference Slice: input and output types are mismatched");
2389
2390 return supported;
2391}
2392
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002393bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2394 const TensorInfo& output,
2395 const SoftmaxDescriptor& descriptor,
2396 Optional<std::string&> reasonIfUnsupported) const
2397{
Jan Eilers8eb25602020-03-09 12:13:48 +00002398 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002399 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002400 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002401 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002402 DataType::Float32,
2403 DataType::Float16,
2404 DataType::QSymmS8,
2405 DataType::QAsymmS8,
2406 DataType::QAsymmU8,
2407 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002408 };
2409
2410 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002411 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002412
2413 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002414 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002415
2416 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002417 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002418
2419 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002420}
2421
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002422bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2423 const TensorInfo& output,
2424 const SpaceToBatchNdDescriptor& descriptor,
2425 Optional<std::string&> reasonIfUnsupported) const
2426{
Jan Eilers8eb25602020-03-09 12:13:48 +00002427 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002428 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002429 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002430 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002431 DataType::Float32,
2432 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002433 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002434 DataType::QAsymmU8,
2435 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002436 };
2437
2438 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2439 "Reference SpaceToBatchNd: input type not supported");
2440
2441 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2442 "Reference SpaceToBatchNd: output type not supported");
2443
2444 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2445 "Reference SpaceToBatchNd: input and output types are mismatched");
2446
2447 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002448}
2449
Keith Davisa57eccb2019-06-14 17:33:22 +01002450bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002451 const TensorInfo& output,
2452 const SpaceToDepthDescriptor& descriptor,
2453 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002454{
2455
Jan Eilers8eb25602020-03-09 12:13:48 +00002456 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002457 bool supported = true;
2458
Sadik Armagan303980c2020-04-17 12:45:14 +01002459 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002460 {
2461 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002462 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002463 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002464 DataType::QAsymmU8,
2465 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002466 };
2467
2468 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2469 "Reference SpaceToDepth: input type not supported");
2470
2471 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2472 "Reference SpaceToDepth: output type not supported");
2473
2474 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2475 "Reference SpaceToDepth: input and output types are mismatched");
2476
2477 return supported;
2478}
2479
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002480bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002481 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2482 const ViewsDescriptor& descriptor,
2483 Optional<std::string&> reasonIfUnsupported) const
2484{
Jan Eilers8eb25602020-03-09 12:13:48 +00002485 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002486 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002487 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002488 {
2489 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002490 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002491 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002492 DataType::QAsymmU8,
2493 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002494 };
2495
2496 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2497 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002498 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002499 {
2500 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2501 "Reference splitter: input type not supported");
2502
2503 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2504 "Reference splitter: input and output types mismatched.");
2505 }
2506
2507 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002508}
2509
Matthew Jackson81e601c2019-07-11 12:07:09 +01002510bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2511 const TensorInfo& output,
2512 const StackDescriptor& descriptor,
2513 Optional<std::string&> reasonIfUnsupported) const
2514{
Jan Eilers8eb25602020-03-09 12:13:48 +00002515 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002516
2517 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002518 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002519 {
2520 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002521 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002522 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002523 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002524 DataType::QSymmS16,
2525 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002526 };
2527
2528 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2529 "Reference stack: output type not supported");
2530 for (const TensorInfo* input : inputs)
2531 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002532 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002533 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2534 "Reference stack: input type not supported");
2535
2536 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2537 "Reference stack: input and output types mismatched.");
2538 }
2539
2540 return supported;
2541}
2542
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002543bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2544 const TensorInfo& output,
2545 const StridedSliceDescriptor& descriptor,
2546 Optional<std::string&> reasonIfUnsupported) const
2547{
Jan Eilers8eb25602020-03-09 12:13:48 +00002548 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002549 bool supported = true;
2550
Sadik Armagan303980c2020-04-17 12:45:14 +01002551 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002552 {
2553 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002554 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002555 DataType::QAsymmU8,
2556 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002557 };
2558
2559 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2560 "Reference StridedSlice: input type not supported");
2561
2562 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2563 "Reference StridedSlice: output type not supported");
2564
2565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2566 "Reference StridedSlice: input and output types are mismatched");
2567
2568 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002569}
2570
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002571bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2572 const TensorInfo& input1,
2573 const TensorInfo& output,
2574 Optional<std::string&> reasonIfUnsupported) const
2575{
Sadik Armagan2999a022019-04-09 14:20:12 +01002576 bool supported = true;
2577
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002578 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002579 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002580 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002581 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002582 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002583 DataType::QSymmS16,
2584 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002585 };
2586
2587 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2588 "Reference subtraction: input 0 is not a supported type.");
2589
2590 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2591 "Reference subtraction: input 1 is not a supported type.");
2592
2593 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2594 "Reference subtraction: output is not a supported type.");
2595
2596 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2597 "Reference subtraction: input 0 and Input 1 types are mismatched");
2598
2599 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2600 "Reference subtraction: input and output types are mismatched");
2601
2602 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2603 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2604
2605 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002606}
2607
Matteo Martincighab9e5252019-06-13 17:27:46 +01002608bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2609 const TensorInfo& alpha,
2610 const TensorInfo& output,
2611 Optional<std::string&> reasonIfUnsupported) const
2612{
2613 bool supported = true;
2614
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002615 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002616 {
2617 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002618 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002620 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002621 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002622 };
2623
2624 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2625 "PReLU: input is not a supported type.");
2626
2627 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2628 "PReLU: alpha is not a supported type.");
2629
2630 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2631 "PReLU: output is not a supported type.");
2632
2633 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2634 "PReLU: input, alpha and output types are mismatched");
2635
2636 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2637 "PReLU: shapes are not suitable for implicit broadcast");
2638
2639 return supported;
2640}
2641
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002642bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2643 const TensorInfo& output,
2644 const TransposeConvolution2dDescriptor& descriptor,
2645 const TensorInfo& weights,
2646 const Optional<TensorInfo>& biases,
2647 Optional<std::string&> reasonIfUnsupported) const
2648{
Jan Eilers8eb25602020-03-09 12:13:48 +00002649 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002650 bool supported = true;
2651
Sadik Armagan303980c2020-04-17 12:45:14 +01002652 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002653 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002654 DataType::Float32,
2655 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002656 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002657 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002658 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002659 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002660 };
2661
2662 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2663 "Reference TransposeConvolution2d: input is not a supported type.");
2664
2665 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2666 "Reference TransposeConvolution2d: output is not a supported type.");
2667
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002668 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2669 "Reference TransposeConvolution2d: input and output types mismatched.");
2670
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002671
2672 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002673 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002674 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002675 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002676 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002677 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002678 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002679 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002680 };
2681
2682 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2683 "Reference TransposeConvolution2d: weights type not supported for "
2684 "quantized input.");
2685 }
2686 else
2687 {
2688 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2689 "Reference TransposeConvolution2d: weights is not a supported type.");
2690
2691 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2692 "Reference TransposeConvolution2d: input and weights types mismatched.");
2693 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002694
2695 if (biases.has_value())
2696 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002697 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002698 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002699 DataType::Float32,
2700 DataType::Float16,
2701 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002702 };
2703 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2704 "Reference TransposeConvolution2d: biases is not a supported type.");
2705 }
2706
2707 return supported;
2708}
2709
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002710bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2711 const TensorInfo& output,
2712 const TransposeDescriptor& descriptor,
2713 Optional<std::string&> reasonIfUnsupported) const
2714{
Jan Eilers8eb25602020-03-09 12:13:48 +00002715 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002716 bool supported = true;
2717
2718 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002719 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002720 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002721 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002722 DataType::Float32,
2723 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002724 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002725 DataType::QAsymmU8,
2726 DataType::QSymmS16
2727 };
2728
2729 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2730 "Reference transpose: input is not a supported type.");
2731
2732 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2733 "Reference transpose: output is not a supported type.");
2734
2735 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2736 "Reference transpose: input and output types are mismatched.");
2737
2738 return supported;
2739}
2740
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002741bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2742 const TensorInfo& input,
2743 const TensorInfo& outputStateIn,
2744 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002745 const TensorInfo& outputStateOut,
2746 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002747 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002748 const UnidirectionalSequenceLstmDescriptor& descriptor,
2749 const LstmInputParamsInfo& paramsInfo,
2750 Optional<std::string&> reasonIfUnsupported) const
2751{
2752 IgnoreUnused(descriptor);
2753 IgnoreUnused(paramsInfo);
2754 IgnoreUnused(outputStateIn);
2755 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002756 IgnoreUnused(outputStateOut);
2757 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002758 bool supported = true;
2759
Mike Kelly12994962022-04-21 11:57:09 +01002760 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002761 {
Mike Kelly12994962022-04-21 11:57:09 +01002762 DataType::Float32,
2763 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002764 };
2765
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002766 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002767 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002768 DataType::Float32,
2769 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002770 };
2771
Mike Kelly12994962022-04-21 11:57:09 +01002772 std::array<DataType, 3> supportedBiasTypes =
2773 {
2774 DataType::Float32,
2775 DataType::QAsymmS8,
2776 DataType::Signed32
2777 };
2778
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002779 // check inputs and outputs
2780 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2781 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002782 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2783 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002784
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002785 // check layer parameters
2786 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2787 reasonIfUnsupported,
2788 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2789 "is not a supported type.");
2790 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2791 reasonIfUnsupported,
2792 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2793 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2794 reasonIfUnsupported,
2795 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2796 "is not a supported type.");
2797 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2798 reasonIfUnsupported,
2799 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2800 "is not a supported type.");
2801 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2802 reasonIfUnsupported,
2803 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2804 "is not a supported type.");
2805 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2806 reasonIfUnsupported,
2807 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2808 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002809
2810 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2811 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2812 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2813 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2814 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2815 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002816 if (!descriptor.m_CifgEnabled)
2817 {
2818 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2819 reasonIfUnsupported,
2820 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2821 "is not a supported type.");
2822 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2823 reasonIfUnsupported,
2824 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2825 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002826 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2827 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002828 if (descriptor.m_PeepholeEnabled)
2829 {
2830 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2831 reasonIfUnsupported,
2832 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2833 "is not a supported type.");
2834 }
2835 }
2836 if (descriptor.m_PeepholeEnabled)
2837 {
2838 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2839 reasonIfUnsupported,
2840 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2841 "is not a supported type.");
2842 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2843 reasonIfUnsupported,
2844 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2845 "is not a supported type.");
2846 }
2847 if (descriptor.m_ProjectionEnabled)
2848 {
2849 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2850 reasonIfUnsupported,
2851 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2852 "is not a supported type.");
2853 if (paramsInfo.m_ProjectionBias != nullptr)
2854 {
2855 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2856 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2857 "are mismatched");
2858 }
2859 }
2860 if (descriptor.m_LayerNormEnabled)
2861 {
2862 if (!descriptor.m_CifgEnabled)
2863 {
2864 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2865 reasonIfUnsupported,
2866 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2867 "is not a supported type.");
2868 }
2869 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2870 reasonIfUnsupported,
2871 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2872 "is not a supported type.");
2873 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2874 reasonIfUnsupported,
2875 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2876 "is not a supported type.");
2877 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2878 reasonIfUnsupported,
2879 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2880 "is not a supported type.");
2881 }
2882
2883 return supported;
2884}
2885
arovir011c7c81b2018-10-08 11:34:28 +01002886} // namespace armnn