blob: cbc6723dbc99f568ccb2960774e4f9f336d38f6b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
namespace
{

// Builds the standard "wrong tensor rank" diagnostic used by the reference
// layer-support checks.
// Parameters are taken by const reference (they are read-only), which also
// allows callers to pass temporaries.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
103 case LayerType::Comparison:
104 return IsComparisonSupported(infos[0],
105 infos[1],
106 infos[2],
107 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108 reasonIfUnsupported);
109 case LayerType::Concat:
110 {
111 std::vector<const TensorInfo*> inputInfos;
112 for (uint32_t i = 0; i < (infos.size() - 1); i++)
113 {
114 inputInfos.push_back(&infos[i]);
115 }
116 return IsConcatSupported(inputInfos,
117 infos[infos.size() - 1],
118 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119 reasonIfUnsupported);
120 }
121 case LayerType::Constant:
122 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000123 case LayerType::ConvertFp16ToFp32:
124 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000125 case LayerType::ConvertFp32ToFp16:
126 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127 case LayerType::Convolution2d:
128 {
129 if (infos.size() != 4)
130 {
131 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132 "TensorInfos should be of format: {input, output, weights, biases}.");
133 }
134
135 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136 if (infos[3] == TensorInfo())
137 {
138 return IsConvolution2dSupported(infos[0],
139 infos[1],
140 desc,
141 infos[2],
142 EmptyOptional(),
143 reasonIfUnsupported);
144 }
145 else
146 {
147 return IsConvolution2dSupported(infos[0],
148 infos[1],
149 desc,
150 infos[2],
151 infos[3],
152 reasonIfUnsupported);
153 }
154 }
155 case LayerType::DepthToSpace:
156 return IsDepthToSpaceSupported(infos[0],
157 infos[1],
158 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159 reasonIfUnsupported);
160 case LayerType::DepthwiseConvolution2d:
161 {
162 if (infos.size() != 4)
163 {
164 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165 "TensorInfos should be of format: {input, output, weights, biases}.");
166 }
167
168 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169 if (infos[3] == TensorInfo())
170 {
171 return IsDepthwiseConvolutionSupported(infos[0],
172 infos[1],
173 desc,
174 infos[2],
175 EmptyOptional(),
176 reasonIfUnsupported);
177 }
178 else
179 {
180 return IsDepthwiseConvolutionSupported(infos[0],
181 infos[1],
182 desc,
183 infos[2],
184 infos[3],
185 reasonIfUnsupported);
186 }
187 }
188 case LayerType::Dequantize:
189 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190 case LayerType::Division:
191 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000192 case LayerType::ElementwiseBinary:
193 {
194 std::array<DataType, 7> supportedTypes =
195 {
196 DataType::Float32,
197 DataType::Float16,
198 DataType::QAsymmS8,
199 DataType::QAsymmU8,
200 DataType::QSymmS16,
201 DataType::Signed32
202 };
203
204 bool supported = true;
205 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206 "Reference elementwise unary: input type not supported");
207
208 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209 "Reference elementwise unary: input type not supported");
210
211 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212 "Reference elementwise unary: output type not supported");
213
214 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215 "Reference elementwise unary: input types not matching");
216
217 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218 "Reference elementwise unary: input and output types not matching");
219
220 return supported;
221 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000222 case LayerType::ElementwiseUnary:
223 return IsElementwiseUnarySupported(infos[0],
224 infos[1],
225 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226 reasonIfUnsupported);
227 case LayerType::Fill:
228 return IsFillSupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Floor:
233 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234 case LayerType::FullyConnected:
235 return IsFullyConnectedSupported(infos[0],
236 infos[1],
237 infos[2],
238 infos[3],
239 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240 reasonIfUnsupported);
241 case LayerType::Gather:
242 return IsGatherSupported(infos[0],
243 infos[1],
244 infos[2],
245 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100247 case LayerType::GatherNd:
248 return IsGatherNdSupported(infos[0],
249 infos[1],
250 infos[2],
251 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000252 case LayerType::Input:
253 return IsInputSupported(infos[0], reasonIfUnsupported);
254 case LayerType::InstanceNormalization:
255 return IsInstanceNormalizationSupported(infos[0],
256 infos[1],
257 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258 (&descriptor)),
259 reasonIfUnsupported);
260 case LayerType::L2Normalization:
261 return IsL2NormalizationSupported(infos[0],
262 infos[1],
263 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::LogicalBinary:
266 return IsLogicalBinarySupported(infos[0],
267 infos[1],
268 infos[2],
269 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270 reasonIfUnsupported);
271 case LayerType::LogSoftmax:
272 return IsLogSoftmaxSupported(infos[0],
273 infos[1],
274 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::Lstm:
277 return IsLstmSupported(infos[0],
278 infos[1],
279 infos[2],
280 infos[3],
281 infos[4],
282 infos[5],
283 infos[6],
284 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285 lstmParamsInfo.value(),
286 reasonIfUnsupported);
287 case LayerType::QLstm:
288 return IsQLstmSupported(infos[0],
289 infos[1],
290 infos[2],
291 infos[3],
292 infos[4],
293 infos[5],
294 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295 lstmParamsInfo.value(),
296 reasonIfUnsupported);
297 case LayerType::Maximum:
298 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299 case LayerType::Mean:
300 return IsMeanSupported(infos[0],
301 infos[1],
302 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303 reasonIfUnsupported);
304 case LayerType::Minimum:
305 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306 case LayerType::Multiplication:
307 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308 case LayerType::Normalization:
309 return IsNormalizationSupported(infos[0],
310 infos[1],
311 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312 reasonIfUnsupported);
313 case LayerType::Output:
314 return IsOutputSupported(infos[0], reasonIfUnsupported);
315 case LayerType::Pad:
316 return IsPadSupported(infos[0],
317 infos[1],
318 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319 reasonIfUnsupported);
320 case LayerType::Permute:
321 return IsPermuteSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Pooling2d:
326 return IsPooling2dSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Prelu:
331 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332 case LayerType::Quantize:
333 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334 case LayerType::Reshape:
335 return IsReshapeSupported(infos[0],
336 infos[1],
337 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338 reasonIfUnsupported);
339 case LayerType::Resize:
340 return IsResizeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Reduce:
345 return IsReduceSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
349 case LayerType::Slice:
350 return IsSliceSupported(infos[0],
351 infos[1],
352 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
353 reasonIfUnsupported);
354 case LayerType::Softmax:
355 return IsSoftmaxSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::SpaceToBatchNd:
360 return IsSpaceToBatchNdSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::SpaceToDepth:
365 return IsSpaceToDepthSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::Splitter:
370 {
371 std::vector<TensorInfo> outputInfos;
372 for (uint32_t i = 1; i < infos.size(); i++)
373 {
374 outputInfos.push_back(infos[i]);
375 }
376 return IsSplitterSupported(infos[0],
377 {outputInfos.begin(), outputInfos.end()},
378 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
379 reasonIfUnsupported);
380 }
381 case LayerType::Stack:
382 {
383 std::vector<const TensorInfo*> inputInfos;
384 for (uint32_t i = 0; i < infos.size() - 1; i++)
385 {
386 inputInfos.push_back(&infos[i]);
387 }
388 return IsStackSupported(inputInfos,
389 infos[infos.size() - 1],
390 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
391 reasonIfUnsupported);
392 }
393 case LayerType::StridedSlice:
394 return IsStridedSliceSupported(infos[0],
395 infos[1],
396 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
397 reasonIfUnsupported);
398 case LayerType::Subtraction:
399 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
400 case LayerType::Transpose:
401 return IsTransposeSupported(infos[0],
402 infos[1],
403 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
404 reasonIfUnsupported);
405 case LayerType::TransposeConvolution2d:
406 {
407 if (infos.size() != 4)
408 {
409 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
410 "TensorInfos should be of format: {input, output, weights, biases}.");
411 }
412
413 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
414 if (infos[3] == TensorInfo())
415 {
416 return IsTransposeConvolution2dSupported(infos[0],
417 infos[1],
418 desc,
419 infos[2],
420 EmptyOptional(),
421 reasonIfUnsupported);
422 }
423 else
424 {
425 return IsTransposeConvolution2dSupported(infos[0],
426 infos[1],
427 desc,
428 infos[2],
429 infos[3],
430 reasonIfUnsupported);
431 }
432 }
433 case LayerType::Cast:
434 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
435 case LayerType::ChannelShuffle:
436 return IsChannelShuffleSupported(infos[0],
437 infos[1],
438 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
439 reasonIfUnsupported);
440 case LayerType::Convolution3d:
441 {
442 if (infos.size() != 4)
443 {
444 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
445 "TensorInfos should be of format: {input, output, weights, biases}.");
446 }
447
448 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
449 if (infos[3] == TensorInfo())
450 {
451 return IsConvolution3dSupported(infos[0],
452 infos[1],
453 desc,
454 infos[2],
455 EmptyOptional(),
456 reasonIfUnsupported);
457 }
458 else
459 {
460 return IsConvolution3dSupported(infos[0],
461 infos[1],
462 desc,
463 infos[2],
464 infos[3],
465 reasonIfUnsupported);
466 }
467 }
468 case LayerType::Debug:
469 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
470 case LayerType::DetectionPostProcess:
471 return IsDetectionPostProcessSupported(infos[0],
472 infos[1],
473 infos[2],
474 infos[3],
475 infos[4],
476 infos[5],
477 infos[6],
478 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
479 (&descriptor)),
480 reasonIfUnsupported);
481 case LayerType::FakeQuantization:
482 return IsFakeQuantizationSupported(infos[0],
483 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
484 reasonIfUnsupported);
485 case LayerType::MemCopy:
486 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
487 case LayerType::Rank:
488 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
489 case LayerType::Shape:
490 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
491 case LayerType::UnidirectionalSequenceLstm:
492 {
493 if (infos.size() != 6)
494 {
495 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
496 "should be of format: {input, outputStateIn, cellStateIn, "
497 "hiddenStateOutputVal, cellStateOutputVal, output}");
498 }
499 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100500 return IsUnidirectionalSequenceLstmSupported(infos[0],
501 infos[1],
502 infos[2],
503 infos[3],
504 infos[4],
505 infos[5],
506 desc,
507 lstmParamsInfo.value(),
508 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000509 }
510 case LayerType::Pooling3d:
511 return IsPooling3dSupported(infos[0],
512 infos[1],
513 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
514 reasonIfUnsupported);
515 case LayerType::Map:
516 return true;
517 case LayerType::Unmap:
518 return true;
519 case LayerType::MemImport:
520 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
521 case LayerType::Merge:
522 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
523 case LayerType::QuantizedLstm:
524 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
525 infos[1],
526 infos[2],
527 infos[3],
528 infos[4],
529 quantizedLstmInputParamsInfo.value(),
530 reasonIfUnsupported);
531 default:
532 // layers not supported in neon by default:
533 // precompiled, standin, switch
534 return false;
535 }
536}
537
arovir011c7c81b2018-10-08 11:34:28 +0100538bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
539 const TensorInfo& output,
540 const ActivationDescriptor& descriptor,
541 Optional<std::string&> reasonIfUnsupported) const
542{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000543 bool supported = true;
544
545 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000546 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000547 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100548 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000549 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000550 DataType::QAsymmU8,
551 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000552 };
553
554 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
555 "Reference activation: input type not supported.");
556
557 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
558 "Reference activation: output type not supported.");
559
560 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
561 "Reference activation: input and output types mismatched.");
562
563 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
564 "Reference activation: input and output shapes are of different rank.");
565
566
567 struct ActivationFunctionSupported : public Rule
568 {
569 ActivationFunctionSupported(const ActivationDescriptor& desc)
570 {
571 switch(desc.m_Function)
572 {
573 case ActivationFunction::Abs:
574 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000575 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000576 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000577 case ActivationFunction::LeakyReLu:
578 case ActivationFunction::Linear:
579 case ActivationFunction::ReLu:
580 case ActivationFunction::Sigmoid:
581 case ActivationFunction::SoftReLu:
582 case ActivationFunction::Sqrt:
583 case ActivationFunction::Square:
584 case ActivationFunction::TanH:
585 {
586 m_Res = true;
587 break;
588 }
589 default:
590 {
591 m_Res = false;
592 break;
593 }
594 }
595 }
596 };
597
598 // Function is supported
599 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
600 "Reference activation: function not supported.");
601
602 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100603}
604
605bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
606 const TensorInfo& input1,
607 const TensorInfo& output,
608 Optional<std::string&> reasonIfUnsupported) const
609{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000610 bool supported = true;
611
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100612 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000613 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100614 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000615 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000616 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100617 DataType::QSymmS16,
618 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000619 };
620
621 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
622 "Reference addition: input 0 is not a supported type.");
623
624 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
625 "Reference addition: input 1 is not a supported type.");
626
627 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
628 "Reference addition: output is not a supported type.");
629
630 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
631 "Reference addition: input 0 and Input 1 types are mismatched");
632
633 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
634 "Reference addition: input and output types are mismatched");
635
636 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
637 "Reference addition: shapes are not suitable for implicit broadcast.");
638
639 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100640}
641
Nikhil Raj68c2c902019-09-19 11:21:11 +0100642bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
643 const armnn::ArgMinMaxDescriptor &descriptor,
644 armnn::Optional<std::string &> reasonIfUnsupported) const
645{
Jan Eilers8eb25602020-03-09 12:13:48 +0000646 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100647
Mike Kelly1f140f72021-04-06 12:25:55 +0100648 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100649 {
Teresa Charline300b362020-05-25 10:01:03 +0100650 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100651 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100652 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000653 DataType::QAsymmU8,
654 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100655 DataType::Signed32,
656 DataType::Signed64
657 };
658
659 std::array<DataType,2> supportedOutputTypes = {
660 DataType::Signed32,
661 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100662 };
663
664 bool supported = true;
665
Mike Kelly1f140f72021-04-06 12:25:55 +0100666 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100667 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100668 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100669 "Reference ArgMinMax: output type not supported");
670
671 return supported;
672}
673
Samuel Yap6b478092022-07-06 15:36:03 +0100674bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
675 const TensorInfo& inputY,
676 const TensorInfo& output,
677 const BatchMatMulDescriptor& descriptor,
678 Optional<std::string &> reasonIfUnsupported) const
679{
680 IgnoreUnused(descriptor);
681
682 std::array<DataType, 6> supportedTypes =
683 {
Samuel Yap6b478092022-07-06 15:36:03 +0100684 DataType::Float16,
685 DataType::Float32,
686 DataType::QAsymmS8,
687 DataType::QAsymmU8,
688 DataType::QSymmS16
689 };
690
691 bool supported = true;
692
693 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
694 "Reference batch matrix multiplication: input X is not a supported type");
695
696 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
697 "Reference batch matrix multiplication: input Y is not a supported type");
698
699 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
700 "Reference batch matrix multiplication: output is not a supported type");
701
702 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
703 "Reference batch matrix multiplication: input X and input Y types are mismatched");
704
705 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
706 "Reference batch matrix multiplication: inputs and output types are mismatched");
707
708 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
709 reasonIfUnsupported,
710 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
711
712 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
713 reasonIfUnsupported,
714 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
715
716 return supported;
717}
718
arovir011c7c81b2018-10-08 11:34:28 +0100719bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
720 const TensorInfo& output,
721 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100722 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100723 const TensorInfo& beta,
724 const TensorInfo& gamma,
725 const BatchNormalizationDescriptor& descriptor,
726 Optional<std::string&> reasonIfUnsupported) const
727{
Jan Eilers8eb25602020-03-09 12:13:48 +0000728 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100729
Sadik Armagan303980c2020-04-17 12:45:14 +0100730 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100731 {
732 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100733 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100734 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000735 DataType::QAsymmU8,
736 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100737 };
738
739 bool supported = true;
740
741 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742 "Reference batch normalization: input is not a supported type.");
743
744 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
745 "Reference batch normalization: output is not a supported type.");
746
747 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
748 "Reference batch normalization: input and output types are mismatched");
749
750 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
751 "Reference batch normalization: mean is not a supported type.");
752
753 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
754 "Reference batch normalization: variance is not a supported type.");
755
756 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
757 "Reference batch normalization: beta is not a supported type.");
758
759 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
760 "Reference batch normalization: gamma is not a supported type.");
761
762 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100763}
764
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000765bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
766 const TensorInfo& output,
767 const BatchToSpaceNdDescriptor& descriptor,
768 Optional<std::string&> reasonIfUnsupported) const
769{
Jan Eilers8eb25602020-03-09 12:13:48 +0000770 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100771
772 bool supported = true;
773
774 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
775 std::string inputTensorStr = "input";
776 std::string outputTensorStr = "output";
777
778 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100779 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100780 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000781 DataType::Float32,
782 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100783 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000784 DataType::QAsymmU8,
785 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100786 };
787
788 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789 "Reference BatchToSpaceNd: input type not supported.");
790
791 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792 "Reference BatchToSpaceNd: output type not supported.");
793
794 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795 "Reference BatchToSpaceNd: input and output types mismatched.");
796
797 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
798 reasonIfUnsupported,
799 CreateIncorrectDimensionsErrorMsg(4,
800 output.GetNumDimensions(),
801 batchToSpaceNdLayerStr,
802 outputTensorStr).data());
803
804 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
805 reasonIfUnsupported,
806 CreateIncorrectDimensionsErrorMsg(4,
807 input.GetNumDimensions(),
808 batchToSpaceNdLayerStr,
809 inputTensorStr).data());
810
811 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000812}
813
mathad01b392e982021-04-07 12:07:30 +0100814bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
815 const TensorInfo& output,
816 Optional<std::string&> reasonIfUnsupported) const
817{
818 std::array<DataType, 9> supportedInputTypes =
819 {
mathad01b392e982021-04-07 12:07:30 +0100820 DataType::Float32,
821 DataType::Float16,
822 DataType::QSymmS8,
823 DataType::QAsymmS8,
824 DataType::QAsymmU8,
825 DataType::QSymmS16,
826 DataType::Signed32
827 };
828
829 bool supported = true;
830 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
831 "Reference cast: input is not a supported type");
832
833
834 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
835 "Reference cast: output is not a supported type");
836
837 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
838 "Reference cast: input and output shapes have different number of total elements");
839
840 return supported;
841}
842
Simon Obute51f67772021-09-03 15:50:13 +0100843bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
844 const TensorInfo& output,
845 const ChannelShuffleDescriptor& descriptor,
846 Optional<std::string&> reasonIfUnsupported) const
847{
848 IgnoreUnused(descriptor);
849 bool supported = true;
850
851 // Define supported output and inputs types.
852 std::array<DataType, 7> supportedTypes =
853 {
Simon Obute51f67772021-09-03 15:50:13 +0100854 DataType::Float32,
855 DataType::Float16,
856 DataType::QAsymmS8,
857 DataType::QAsymmU8,
858 DataType::QSymmS8,
859 DataType::QSymmS16
860 };
861
862 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
863 "Reference ChannelShuffle: input is not a supported type.");
864
865 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
866 "Reference ChannelShuffle: output is not a supported type.");
867
868 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
869 "Reference ChannelShuffle: input and output types are mismatched.");
870
871 return supported;
872}
873
874
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100875bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
876 const TensorInfo& input1,
877 const TensorInfo& output,
878 const ComparisonDescriptor& descriptor,
879 Optional<std::string&> reasonIfUnsupported) const
880{
Jan Eilers8eb25602020-03-09 12:13:48 +0000881 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100882 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100883 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000884 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100885 DataType::Float32,
886 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100887 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000888 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000889 DataType::QSymmS16,
890 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100891 };
892
893 bool supported = true;
894 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
895 "Reference comparison: input 0 is not a supported type");
896
897 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
898 "Reference comparison: input 0 and Input 1 types are mismatched");
899
900 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
901 "Reference comparison: output is not of type Boolean");
902
903 return supported;
904}
905
Jim Flynn906f9462019-05-10 13:55:21 +0100906bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
907 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000908 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100909 Optional<std::string&> reasonIfUnsupported) const
910{
Jan Eilers8eb25602020-03-09 12:13:48 +0000911 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100912
913 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000914 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100915 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000916 DataType::Float32,
917 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000918 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100919 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000920 DataType::QSymmS16,
921 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100922 };
923
924 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
925 "Reference concatenation: output type not supported");
926 for (const TensorInfo* input : inputs)
927 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100928 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100929 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
930 "Reference concatenation: input type not supported");
931
932 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
933 "Reference concatenation: input and output types mismatched.");
934 }
935
936 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100937}
938
arovir011c7c81b2018-10-08 11:34:28 +0100939bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
940 Optional<std::string&> reasonIfUnsupported) const
941{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100942 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100943 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100944 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100945 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000946 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100947 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000948 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100949 DataType::QSymmS16,
950 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100951 };
952
953 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
954 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100955}
956
957bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
958 const TensorInfo& output,
959 Optional<std::string&> reasonIfUnsupported) const
960{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100961 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
962 input.GetDataType(),
963 &TrueFunc<>,
964 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000965 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000966 &FalseFuncI32<>,
967 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100968 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
969 output.GetDataType(),
970 &FalseOutputFuncF16<>,
971 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000972 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000973 &FalseFuncI32<>,
974 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100975}
976
977bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
978 const TensorInfo& output,
979 Optional<std::string&> reasonIfUnsupported) const
980{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100981 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
982 input.GetDataType(),
983 &FalseInputFuncF16<>,
984 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000985 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000986 &FalseFuncI32<>,
987 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100988 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
989 output.GetDataType(),
990 &TrueFunc<>,
991 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000992 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000993 &FalseFuncI32<>,
994 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100995}
996
997bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
998 const TensorInfo& output,
999 const Convolution2dDescriptor& descriptor,
1000 const TensorInfo& weights,
1001 const Optional<TensorInfo>& biases,
1002 Optional<std::string&> reasonIfUnsupported) const
1003{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001004 bool supported = true;
1005
1006 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001007 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001008 {
1009 DataType::Float32,
1010 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001011 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001012 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001013 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001014 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001015 };
1016
1017 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001018 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001019
1020 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001021 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001022
Ryan OShea31441592022-11-07 16:20:48 +00001023 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1024 "Reference Convolution2d: input and output types mismatched.");
1025
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001026
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001027 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001028 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001029 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001030 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001031 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001032 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001033 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001034 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001035 };
1036
1037 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001038 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001039 }
1040 else
1041 {
1042 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001043 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001044
1045 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001046 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001047 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001048
1049 if (biases.has_value())
1050 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001051 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001052 {
1053 DataType::Float32,
1054 DataType::Float16,
1055 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001056 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001057
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001058 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001059 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001060 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001061 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001062
1063 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001064}
1065
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001066bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1067 const TensorInfo& output,
1068 const Convolution3dDescriptor& descriptor,
1069 const TensorInfo& weights,
1070 const Optional<TensorInfo>& biases,
1071 Optional<std::string&> reasonIfUnsupported) const
1072{
1073 bool supported = true;
1074
1075 // Define supported types.
1076 std::array<DataType,7> supportedTypes =
1077 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001078 DataType::Float32,
1079 DataType::Float16,
1080 DataType::QAsymmS8,
1081 DataType::QAsymmU8,
1082 DataType::QSymmS8,
1083 DataType::QSymmS16
1084 };
1085
1086 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1087 "Reference Convolution3d: input is not a supported type.");
1088
1089 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1090 "Reference Convolution3d: output is not a supported type.");
1091
1092 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1093 "Reference Convolution3d: input and output types mismatched.");
1094
1095 const DataType inputType = input.GetDataType();
1096 if (IsQuantized8BitType(inputType))
1097 {
1098 std::array<DataType, 3> supportedWeightTypes =
1099 {
1100 DataType::QAsymmS8,
1101 DataType::QAsymmU8,
1102 DataType::QSymmS8
1103 };
1104
1105 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1106 "Reference Convolution3d: weights type not supported for quantized input.");
1107 }
1108 else
1109 {
1110 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1111 "Reference Convolution3d: weights is not a supported type.");
1112
1113 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1114 "Reference Convolution3d: input and weights types mismatched.");
1115 }
1116
1117 if (biases.has_value())
1118 {
1119 std::array<DataType,4> biasesSupportedTypes =
1120 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001121 DataType::Float32,
1122 DataType::Float16,
1123 DataType::Signed32
1124 };
1125
1126 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1127 "Reference Convolution3d: biases is not a supported type.");
1128 }
1129 IgnoreUnused(descriptor);
1130
1131 return supported;
1132}
1133
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001134bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1135 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001136 Optional<std::string&> reasonIfUnsupported) const
1137{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001138 bool supported = true;
1139
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001140 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001141 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001142 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001143 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001144 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001145 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001146 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001147 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001148 DataType::QSymmS16,
1149 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001150 };
1151
1152 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001153 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001154
1155 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001156 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001157
1158 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001159 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001160
1161 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001162}
1163
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001164bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1165 const TensorInfo& output,
1166 const DepthToSpaceDescriptor& descriptor,
1167 Optional<std::string&> reasonIfUnsupported) const
1168{
Jan Eilers8eb25602020-03-09 12:13:48 +00001169 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001170 bool supported = true;
1171
Sadik Armagan303980c2020-04-17 12:45:14 +01001172 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001173 {
1174 DataType::Float32,
1175 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001176 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001177 DataType::QAsymmU8,
1178 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001179 };
1180
1181 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1182 "Reference DepthToSpace: input type not supported");
1183
1184 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1185 "Reference DepthToSpace: output type not supported");
1186
1187 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1188 "Reference DepthToSpace: input and output types are mismatched");
1189
1190 return supported;
1191}
1192
arovir011c7c81b2018-10-08 11:34:28 +01001193bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1194 const TensorInfo& output,
1195 const DepthwiseConvolution2dDescriptor& descriptor,
1196 const TensorInfo& weights,
1197 const Optional<TensorInfo>& biases,
1198 Optional<std::string&> reasonIfUnsupported) const
1199{
Sadik Armagan303980c2020-04-17 12:45:14 +01001200 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001201 bool supported = true;
1202
1203 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001204 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001205 {
1206 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001207 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001208 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001209 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001210 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001211 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001212 };
1213
1214 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1215 "Reference DepthwiseConvolution2d: input is not a supported type.");
1216
1217 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1218 "Reference DepthwiseConvolution2d: output is not a supported type.");
1219
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001220 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1221 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1222
Teresa Charlind8df0262019-11-11 12:28:15 +00001223 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001224 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001225 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001226 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001227 {
1228 DataType::QAsymmS8,
1229 DataType::QAsymmU8,
1230 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001231 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001232
1233 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001234 "Reference DepthwiseConvolution2d: weights type not supported for "
1235 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001236 }
1237 else
1238 {
1239 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1240 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1241
1242 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1243 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1244 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001245
1246 if (biases.has_value())
1247 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001248 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001249 {
1250 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001251 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001252 DataType::Signed32
1253 };
1254 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1255 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1256 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001257
1258 return supported;
1259
arovir011c7c81b2018-10-08 11:34:28 +01001260}
1261
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001262bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1263 const TensorInfo& output,
1264 Optional<std::string&> reasonIfUnsupported) const
1265{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001266 bool supported = true;
1267
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001268 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001269 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001270 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001271 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001272 DataType::QSymmS16,
1273 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001274 };
1275
1276 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001277 "Reference for Dequantize layer: input type not supported.");
1278
Derek Lambertid466a542020-01-22 15:37:29 +00001279 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001280 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001281
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001282 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001283 DataType::Float32,
1284 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001285 };
1286
1287 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001288 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001289
1290 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001291 "Reference for Dequantize layer: input/output shapes have different num total "
1292 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001293
1294 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001295}
1296
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001297bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1298 const TensorInfo& scores,
1299 const TensorInfo& anchors,
1300 const TensorInfo& detectionBoxes,
1301 const TensorInfo& detectionClasses,
1302 const TensorInfo& detectionScores,
1303 const TensorInfo& numDetections,
1304 const DetectionPostProcessDescriptor& descriptor,
1305 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001306{
Jan Eilers8eb25602020-03-09 12:13:48 +00001307 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001308
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001309 bool supported = true;
1310
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001311 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001312 {
1313 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001314 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001315 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001316 DataType::QAsymmU8,
1317 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001318 };
1319
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001320 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001321 "Reference DetectionPostProcess: input 0 is not a supported type.");
1322
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001323 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001324 "Reference DetectionPostProcess: input 1 is not a supported type.");
1325
1326 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001327}
1328
Pablo Tellof0bd6832019-04-26 17:58:13 +01001329bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1330 const TensorInfo& output,
1331 const DepthwiseConvolution2dDescriptor& descriptor,
1332 const TensorInfo& weights,
1333 const Optional<TensorInfo>& biases,
1334 Optional<std::string&> reasonIfUnsupported) const
1335{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001336 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +01001337}
1338
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001339bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001340 const TensorInfo& input1,
1341 const TensorInfo& output,
1342 Optional<std::string&> reasonIfUnsupported) const
1343{
Sadik Armagan2999a022019-04-09 14:20:12 +01001344 bool supported = true;
1345
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001346 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001347 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001348 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001349 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001350 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001351 DataType::QSymmS16,
1352 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001353 };
1354
1355 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1356 "Reference division: input 0 is not a supported type.");
1357
1358 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1359 "Reference division: input 1 is not a supported type.");
1360
1361 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1362 "Reference division: output is not a supported type.");
1363
1364 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1365 "Reference division: input 0 and Input 1 types are mismatched");
1366
1367 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1368 "Reference division: input and output types are mismatched");
1369
1370 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1371 "Reference division: shapes are not suitable for implicit broadcast.");
1372
1373 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001374}
1375
josh minor4a3c6102020-01-06 16:40:46 -06001376bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1377 const TensorInfo& output,
1378 const ElementwiseUnaryDescriptor& descriptor,
1379 Optional<std::string&> reasonIfUnsupported) const
1380{
Jan Eilers8eb25602020-03-09 12:13:48 +00001381 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001382
Sadik Armagan303980c2020-04-17 12:45:14 +01001383 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001384 {
1385 DataType::Float32,
1386 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001387 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001388 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001389 DataType::QSymmS16,
1390 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001391 };
1392
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001393 std::array<DataType, 1> logicalSupportedTypes =
1394 {
1395 DataType::Boolean
1396 };
1397
josh minor4a3c6102020-01-06 16:40:46 -06001398 bool supported = true;
1399
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001400 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1401 {
1402 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1403 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001404
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001405 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1406 "Reference elementwise unary: output type not supported");
1407 }
1408 else
1409 {
1410 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1411 "Reference elementwise unary: input type not supported");
1412
1413 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1414 "Reference elementwise unary: output type not supported");
1415 }
josh minor4a3c6102020-01-06 16:40:46 -06001416
1417 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1418 "Reference elementwise unary: input and output types not matching");
1419
1420 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1421 "Reference elementwise unary: input and output shapes"
1422 "have different number of total elements");
1423
1424 return supported;
1425}
1426
arovir011c7c81b2018-10-08 11:34:28 +01001427bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1428 const FakeQuantizationDescriptor& descriptor,
1429 Optional<std::string&> reasonIfUnsupported) const
1430{
Jan Eilers8eb25602020-03-09 12:13:48 +00001431 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001432 bool supported = true;
1433
1434 std::array<DataType,1> supportedTypes =
1435 {
1436 DataType::Float32
1437 };
1438
1439 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440 "Reference fake quantization: input type not supported.");
1441
1442 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001443}
1444
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001445bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1446 const TensorInfo& output,
1447 const FillDescriptor& descriptor,
1448 Optional<std::string&> reasonIfUnsupported) const
1449{
1450 IgnoreUnused(descriptor);
1451 IgnoreUnused(output);
1452
1453 bool supported = true;
1454
Sadik Armagana792a052020-06-23 16:22:23 +01001455 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001456 {
1457 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001458 DataType::Float16,
1459 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001460 };
1461
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001462 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001463 "Reference Fill: input type not supported.");
1464
Teresa Charlin44088502020-07-27 11:27:19 +01001465 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1466 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001467 return supported;
1468}
1469
arovir011c7c81b2018-10-08 11:34:28 +01001470bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1471 const TensorInfo& output,
1472 Optional<std::string&> reasonIfUnsupported) const
1473{
Jan Eilers8eb25602020-03-09 12:13:48 +00001474 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001475 bool supported = true;
1476
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001477 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001478 {
James Conroyb40d7102019-06-04 12:32:09 +01001479 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001480 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001481 };
1482
1483 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484 "Reference Floor: input type not supported.");
1485
1486 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487 "Reference Floor: output type not supported.");
1488
1489 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001490}
1491
1492bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1493 const TensorInfo& output,
1494 const TensorInfo& weights,
1495 const TensorInfo& biases,
1496 const FullyConnectedDescriptor& descriptor,
1497 Optional<std::string&> reasonIfUnsupported) const
1498{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001499 bool supported = true;
1500
1501 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001502 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001503 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001504 DataType::Float32,
1505 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001506 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001507 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001508 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001509 };
1510
1511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512 "Reference Fully Connected: input type not supported.");
1513
1514 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1515 "Reference Fully Connected: output type not supported.");
1516
Francis Murtagh46c09d02019-05-28 08:15:28 +01001517 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1518 "Reference Fully Connected: weights type not supported.");
1519
Ryan OShea31441592022-11-07 16:20:48 +00001520 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1521 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001522
Jan Eilers1f45dc32020-06-15 11:43:03 +01001523 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1524 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001525
Jan Eilers1f45dc32020-06-15 11:43:03 +01001526 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1527 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001528
1529 if (descriptor.m_BiasEnabled)
1530 {
1531 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001532 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001533 supportedBiasTypes =
1534 {
1535 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001536 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001537 DataType::Signed32,
1538 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001539 };
1540
1541 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1542 "Reference Fully Connected: bias type not supported.");
1543
1544 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1545 "Reference Fully Connected: bias and weight types mismatch.");
1546
1547 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1548 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1549
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001550 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1551 "Reference Fully Connected: bias must have 1 dimension.");
1552
Francis Murtagh46c09d02019-05-28 08:15:28 +01001553 }
1554
1555 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001556}
1557
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001558bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1559 const armnn::TensorInfo& input1,
1560 const armnn::TensorInfo& output,
1561 armnn::Optional<std::string&> reasonIfUnsupported) const
1562{
1563 bool supported = true;
1564 std::array<DataType,7> supportedTypes =
1565 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001566 DataType::Float32,
1567 DataType::Float16,
1568 DataType::QAsymmS8,
1569 DataType::QAsymmU8,
1570 DataType::QSymmS16,
1571 DataType::Signed32
1572 };
1573
1574 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1575 "Reference GatherNd: input type not supported");
1576
1577 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578 "Reference GatherNd: output type not supported");
1579
1580 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1581 "Reference GatherNd: indices (input1) type not supported");
1582
1583 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584 "Reference GatherNd: input and output types not matching");
1585
1586 return supported;
1587}
1588
narpra014951d842019-01-18 16:53:53 +00001589bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1590 const armnn::TensorInfo& input1,
1591 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001592 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001593 armnn::Optional<std::string&> reasonIfUnsupported) const
1594{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001595 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001596 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001597 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001598 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001599 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001600 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001601 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001602 DataType::QSymmS16,
1603 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001604 };
1605
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001606 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001607 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1608 "Reference Gather: input type not supported");
1609
1610 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1611 "Reference Gather: output type not supported");
1612
1613 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1614 "Reference Gather: indices (input1) type not supported");
1615
1616 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1617 "Reference Gather: input and output types not matching");
1618
1619 return supported;
narpra014951d842019-01-18 16:53:53 +00001620}
1621
// The reference backend places no restriction on graph input tensors:
// any tensor can act as an input, so this unconditionally reports support.
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
1627
Kevin May09ca49c2019-10-09 12:37:34 +01001628bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1629 const TensorInfo& output,
1630 const InstanceNormalizationDescriptor& descriptor,
1631 Optional<std::string&> reasonIfUnsupported) const
1632{
Jan Eilers8eb25602020-03-09 12:13:48 +00001633 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001634 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001635 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001636 {
1637 DataType::Float32,
1638 DataType::Float16
1639 };
1640
1641 bool supported = true;
1642
1643 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1644 "Reference Instance Normalization: input type not supported.");
1645
1646 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1647 "Reference Instance Normalization: output type not supported.");
1648
1649 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1650 "Reference Instance Normalization: input and output types mismatched.");
1651
1652 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1653 "Reference Instance Normalization: input and output shapes have different "
1654 "num total elements.");
1655
1656 return supported;
1657}
1658
arovir011c7c81b2018-10-08 11:34:28 +01001659bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1660 const TensorInfo& output,
1661 const L2NormalizationDescriptor& descriptor,
1662 Optional<std::string&> reasonIfUnsupported) const
1663{
Jan Eilers8eb25602020-03-09 12:13:48 +00001664 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001665 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001666 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001667 {
1668 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001669 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001670 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001671 DataType::QAsymmU8,
1672 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001673 };
1674
1675 bool supported = true;
1676
1677 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1678 "Reference L2normalization: input type not supported.");
1679
1680 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1681 "Reference L2normalization: output type not supported.");
1682
1683 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1684 "Reference L2normalization: input and output types mismatched.");
1685
1686 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1687 "Reference L2normalization: input and output shapes have different "
1688 "num total elements.");
1689
1690 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001691}
1692
James Conroyaba90cd2020-11-06 16:28:18 +00001693bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1694 const TensorInfo& input1,
1695 const TensorInfo& output,
1696 const LogicalBinaryDescriptor& descriptor,
1697 Optional<std::string&> reasonIfUnsupported) const
1698{
1699 IgnoreUnused(descriptor);
1700
1701 std::array<DataType, 1> supportedTypes =
1702 {
1703 DataType::Boolean
1704 };
1705
1706 bool supported = true;
1707 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1708 "Reference LogicalBinary: input 0 type not supported");
1709 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1710 "Reference LogicalBinary: input 1 type not supported");
1711
1712 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1713 "Reference LogicalBinary: input and output types do not match");
1714
1715 return supported;
1716}
1717
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001718bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1719 const TensorInfo& output,
1720 const LogSoftmaxDescriptor& descriptor,
1721 Optional<std::string&> reasonIfUnsupported) const
1722{
Jan Eilers8eb25602020-03-09 12:13:48 +00001723 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001724
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001725 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001726 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001727 DataType::Float32,
1728 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001729 };
1730
1731 bool supported = true;
1732 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1733 "Reference LogSoftmax: input type not supported");
1734
1735 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1736 "Reference LogSoftmax: output type not supported");
1737
1738 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1739 "Reference LogSoftmax: input and output types do not match");
1740
1741 return supported;
1742}
1743
arovir011c7c81b2018-10-08 11:34:28 +01001744bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1745 const TensorInfo& outputStateIn,
1746 const TensorInfo& cellStateIn,
1747 const TensorInfo& scratchBuffer,
1748 const TensorInfo& outputStateOut,
1749 const TensorInfo& cellStateOut,
1750 const TensorInfo& output,
1751 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001752 const LstmInputParamsInfo& paramsInfo,
1753 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001754{
Jan Eilers8eb25602020-03-09 12:13:48 +00001755 IgnoreUnused(descriptor);
1756 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001757
1758 bool supported = true;
1759
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001760 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001761 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001762 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001763 };
1764
Jan Eilersd01a83c2019-07-03 18:20:40 +01001765 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001766 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1767 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001768 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1769 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001770 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1771 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001772 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1773 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001774 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1775 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001776 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1777 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001778
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001779 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1780 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001781 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001782 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001783 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001784 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001785 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001786 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001787 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001788 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001789 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001790 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001791 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001792 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001793 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001794 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001795 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001796 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001797 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001798 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001799 "Reference Lstm: input and OutputGateBias types are mismatched");
1800 if (!descriptor.m_CifgEnabled)
1801 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001802 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001803 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001804 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001805 reasonIfUnsupported,
1806 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001808 "Reference Lstm: input and InputGateBias types are mismatched");
1809 if (descriptor.m_PeepholeEnabled)
1810 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001812 reasonIfUnsupported,
1813 "Reference Lstm: input and CellToInputWeights types are mismatched");
1814 }
1815 }
1816 if (descriptor.m_PeepholeEnabled)
1817 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001818 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001819 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001820 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001821 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1822 }
1823 if (descriptor.m_ProjectionEnabled)
1824 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001825 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001826 "Reference Lstm: input and mProjectionWeights types are mismatched");
1827 if (paramsInfo.m_ProjectionBias != nullptr)
1828 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001829 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001830 "Reference Lstm: input and ProjectionBias types are mismatched");
1831 }
1832 }
1833 if (descriptor.m_LayerNormEnabled)
1834 {
1835 if (!descriptor.m_CifgEnabled)
1836 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001837 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001838 reasonIfUnsupported,
1839 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1840 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001841 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001842 reasonIfUnsupported,
1843 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001844 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001845 reasonIfUnsupported,
1846 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001847 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001848 reasonIfUnsupported,
1849 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1850 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001851
1852 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001853}
1854
saoste012df12b32018-11-28 16:57:20 +00001855bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1856 const TensorInfo& input1,
1857 const TensorInfo& output,
1858 Optional<std::string&> reasonIfUnsupported) const
1859{
Sadik Armagan2999a022019-04-09 14:20:12 +01001860 bool supported = true;
1861
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001862 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001863 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001864 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001865 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001866 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001867 DataType::QSymmS16,
1868 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001869 };
1870
1871 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1872 "Reference maximum: input 0 is not a supported type.");
1873
1874 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1875 "Reference maximum: input 1 is not a supported type.");
1876
1877 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1878 "Reference maximum: output is not a supported type.");
1879
1880 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1881 "Reference maximum: input 0 and Input 1 types are mismatched");
1882
1883 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1884 "Reference maximum: input and output types are mismatched");
1885
1886 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1887 "Reference maximum: shapes are not suitable for implicit broadcast.");
1888
1889 return supported;
saoste012df12b32018-11-28 16:57:20 +00001890}
1891
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001892bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1893 const TensorInfo& output,
1894 const MeanDescriptor& descriptor,
1895 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001896{
James Conroy4d1ff582019-06-10 17:06:39 +01001897 bool supported = true;
1898 std::string meanLayerStr = "Mean";
1899 std::string outputTensorStr = "output";
1900
Sadik Armagan303980c2020-04-17 12:45:14 +01001901 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001902 {
1903 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001904 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001905 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001906 DataType::QAsymmU8,
1907 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001908 };
1909
1910 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1911 "Reference Mean: input type not supported.");
1912
1913 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1914 "Reference Mean: input and output types are mismatched");
1915
1916 if (descriptor.m_KeepDims)
1917 {
1918 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1919 reasonIfUnsupported,
1920 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1921 output.GetNumDimensions(),
1922 meanLayerStr, outputTensorStr).data());
1923 }
1924 else if (descriptor.m_Axis.empty())
1925 {
1926 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1927 reasonIfUnsupported,
1928 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1929 meanLayerStr, outputTensorStr).data());
1930 }
1931 else
1932 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001933 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001934
1935 if (outputDim > 0)
1936 {
1937 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1938 reasonIfUnsupported,
1939 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1940 meanLayerStr, outputTensorStr).data());
1941 }
1942 else
1943 {
1944 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1945 reasonIfUnsupported,
1946 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1947 meanLayerStr, outputTensorStr).data());
1948 }
1949 }
1950
1951 return supported;
narpra0132b90462018-09-13 11:07:48 +01001952}
1953
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001954bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1955 const TensorInfo &output,
1956 Optional<std::string &> reasonIfUnsupported) const
1957{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001958 bool supported = true;
1959
Sadik Armagan303980c2020-04-17 12:45:14 +01001960 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001961 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001962 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001963 DataType::Float32,
1964 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001965 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001966 DataType::QAsymmU8,
1967 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001968 DataType::Boolean
1969 };
1970
1971 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1972 "Reference MemCopy: input type not supported");
1973
1974 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1975 "Reference MemCopy: output type not supported");
1976
1977 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1978 "Reference MemCopy: input and output types are mismatched");
1979
1980 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001981}
1982
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001983bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1984 const TensorInfo& input1,
1985 const TensorInfo& output,
1986 Optional<std::string&> reasonIfUnsupported) const
1987{
Sadik Armagan2999a022019-04-09 14:20:12 +01001988 bool supported = true;
1989
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001990 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001991 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001992 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001993 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001994 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001995 DataType::QSymmS16,
1996 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001997 };
1998
1999 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2000 "Reference minimum: input 0 is not a supported type.");
2001
2002 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2003 "Reference minimum: input 1 is not a supported type.");
2004
2005 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2006 "Reference minimum: output is not a supported type.");
2007
2008 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2009 "Reference minimum: input 0 and Input 1 types are mismatched");
2010
2011 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2012 "Reference minimum: input and output types are mismatched");
2013
2014 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2015 "Reference minimum: shapes are not suitable for implicit broadcast.");
2016
2017 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002018}
2019
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002020bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2021 const TensorInfo& input1,
2022 const TensorInfo& output,
2023 Optional<std::string&> reasonIfUnsupported) const
2024{
Sadik Armagan2999a022019-04-09 14:20:12 +01002025 bool supported = true;
2026
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002027 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002028 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002029 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002030 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002031 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002032 DataType::QSymmS16,
2033 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002034 };
2035
2036 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2037 "Reference multiplication: input 0 is not a supported type.");
2038
2039 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2040 "Reference multiplication: input 1 is not a supported type.");
2041
2042 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2043 "Reference multiplication: output is not a supported type.");
2044
2045 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2046 "Reference multiplication: input 0 and Input 1 types are mismatched");
2047
2048 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2049 "Reference multiplication: input and output types are mismatched");
2050
2051 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2052 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2053
2054 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002055}
2056
2057bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2058 const TensorInfo& output,
2059 const NormalizationDescriptor& descriptor,
2060 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002061{
Jan Eilers8eb25602020-03-09 12:13:48 +00002062 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002063
2064 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002065 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002066 {
2067 DataType::Float16,
2068 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002069 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002070 DataType::QAsymmU8,
2071 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002072 };
2073
2074 bool supported = true;
2075
2076 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2077 "Reference normalization: input type not supported.");
2078
2079 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2080 "Reference normalization: output type not supported.");
2081
2082 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2083 "Reference normalization: input and output shapes have different "
2084 "num total elements.");
2085
2086 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002087}
2088
// Output (binding) layers perform no computation in the reference backend,
// so any tensor configuration is accepted unconditionally.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // No type or shape restrictions: always supported.
    return true;
}
2094
2095bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2096 const TensorInfo& output,
2097 const PadDescriptor& descriptor,
2098 Optional<std::string&> reasonIfUnsupported) const
2099{
Jan Eilers8eb25602020-03-09 12:13:48 +00002100 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002101 bool supported = true;
2102
2103 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002104 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002105 {
2106 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002107 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002108 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002109 DataType::QAsymmU8,
2110 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002111 };
2112
2113 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2114 "Reference pad: input is not a supported type.");
2115
2116 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2117 "Reference pad: output is not a supported type.");
2118
2119 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2120 "Reference pad: input and output types are mismatched.");
2121
2122 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002123}
2124
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002125bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2126 const TensorInfo& output,
2127 const PermuteDescriptor& descriptor,
2128 Optional<std::string&> reasonIfUnsupported) const
2129{
Jan Eilers8eb25602020-03-09 12:13:48 +00002130 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002131 bool supported = true;
2132
2133 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002134 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002135 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002136 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002137 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002138 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002139 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002140 DataType::QAsymmU8,
2141 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002142 };
2143
2144 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145 "Reference permute: input is not a supported type.");
2146
2147 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148 "Reference permute: output is not a supported type.");
2149
2150 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151 "Reference permute: input and output types are mismatched.");
2152
2153 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002154}
2155
2156bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2157 const TensorInfo& output,
2158 const Pooling2dDescriptor& descriptor,
2159 Optional<std::string&> reasonIfUnsupported) const
2160{
Jan Eilers8eb25602020-03-09 12:13:48 +00002161 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002162 bool supported = true;
2163
2164 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002165 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002166 {
2167 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002168 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002169 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002170 DataType::QAsymmU8,
2171 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002172 };
2173
2174 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2175 "Reference poolind2d: input is not a supported type.");
2176
2177 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2178 "Reference poolind2d: output is not a supported type.");
2179
2180 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2181 "Reference poolind2d: input and output types are mismatched.");
2182
2183 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002184}
2185
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002186bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2187 const TensorInfo& output,
2188 const Pooling3dDescriptor& descriptor,
2189 Optional<std::string&> reasonIfUnsupported) const
2190{
2191 IgnoreUnused(descriptor);
2192 bool supported = true;
2193
2194 // Define supported output and inputs types.
2195 std::array<DataType,6> supportedTypes =
2196 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002197 DataType::Float32,
2198 DataType::Float16,
2199 DataType::QAsymmS8,
2200 DataType::QAsymmU8,
2201 DataType::QSymmS16
2202 };
2203
2204 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2205 "Reference poolind3d: input is not a supported type.");
2206
2207 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2208 "Reference poolind3d: output is not a supported type.");
2209
2210 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2211 "Reference poolind3d: input and output types are mismatched.");
2212
2213 return supported;
2214}
2215
2216
James Conroy4f1f8992020-04-29 20:01:10 +01002217bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2218 const TensorInfo& previousOutputIn,
2219 const TensorInfo& previousCellStateIn,
2220 const TensorInfo& outputStateOut,
2221 const TensorInfo& cellStateOut,
2222 const TensorInfo& output,
2223 const QLstmDescriptor& descriptor,
2224 const LstmInputParamsInfo& paramsInfo,
2225 Optional<std::string&> reasonIfUnsupported) const
2226{
2227 IgnoreUnused(input);
2228 IgnoreUnused(previousOutputIn);
2229 IgnoreUnused(previousCellStateIn);
2230 IgnoreUnused(outputStateOut);
2231 IgnoreUnused(cellStateOut);
2232 IgnoreUnused(output);
2233 IgnoreUnused(descriptor);
2234 IgnoreUnused(paramsInfo);
2235
2236 IgnoreUnused(reasonIfUnsupported);
2237
2238 return true;
2239}
2240
Derek Lamberti5f400d62019-03-25 15:41:58 +00002241bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2242 const TensorInfo& output,
2243 Optional<std::string&> reasonIfUnsupported) const
2244{
2245 bool supported = true;
2246
Finn Williamsfd271062019-12-04 14:27:27 +00002247 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002248 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002249 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002250 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002251 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002252 DataType::QAsymmU8,
2253 DataType::QSymmS8,
2254 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002255 };
2256
2257 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2258 "Reference quantize: input type not supported.");
2259
2260 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002261 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002262 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002263 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002264 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002265 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002266 };
2267 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2268 "Reference quantize: output type not supported.");
2269
2270 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2271 "Reference quantize: input and output shapes have different num total elements.");
2272
2273 return supported;
2274}
2275
Finn Williams2605b232020-06-10 15:53:46 +01002276bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2277 const TensorInfo& output,
2278 Optional<std::string&> reasonIfUnsupported) const
2279{
2280 IgnoreUnused(input);
2281 // Define supported output types.
2282 std::array<DataType,1> supportedOutputTypes =
2283 {
2284 DataType::Signed32,
2285 };
2286
2287 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2288 "Reference rank: input type not supported.");
2289}
2290
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002291bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2292 const TensorInfo& output,
2293 const ReduceDescriptor& descriptor,
2294 Optional<std::string&> reasonIfUnsupported) const
2295{
2296 IgnoreUnused(descriptor);
2297 bool supported = true;
2298 std::array<DataType,7> supportedTypes =
2299 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002300 DataType::Float32,
2301 DataType::Float16,
2302 DataType::QAsymmS8,
2303 DataType::QAsymmU8,
2304 DataType::QSymmS16,
2305 DataType::Signed32
2306 };
2307
2308 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2309 "Reference Reduce: input type not supported");
2310
2311 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2312 "Reference Reduce: output type not supported");
2313
2314 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2315 "Reference Reduce: input and output types not matching");
2316
2317 return supported;
2318}
2319
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002320bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002321 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002322 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002323 Optional<std::string&> reasonIfUnsupported) const
2324{
Jan Eilers8eb25602020-03-09 12:13:48 +00002325 IgnoreUnused(output);
2326 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002327 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002328 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002329 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002330 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002331 DataType::Float32,
2332 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002333 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002334 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002335 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002336 DataType::QSymmS16,
2337 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002338 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002339
Nina Drozd2f2778f2019-05-27 10:37:05 +01002340 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2341 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002342}
2343
Teresa Charlin970f43b2019-07-01 13:51:07 +01002344bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2345 const TensorInfo& output,
2346 const ResizeDescriptor& descriptor,
2347 Optional<std::string&> reasonIfUnsupported) const
2348{
Jan Eilers8eb25602020-03-09 12:13:48 +00002349 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002350 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002351 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002352 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002353 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002354 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002355 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002356 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002357 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002358 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002359 };
2360
2361 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2362 "Reference Resize: input type not supported");
2363
2364 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2365 "Reference Resize: output type not supported");
2366
2367 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2368 "Reference Resize: input and output types not matching");
2369
2370 return supported;
2371}
2372
Keith Davis3ae3f972021-05-21 16:33:48 +01002373bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2374 const TensorInfo& output,
2375 Optional<std::string&> reasonIfUnsupported) const
2376{
2377 IgnoreUnused(input);
2378 bool supported = true;
2379
2380 std::array<DataType, 1> supportedTypes =
2381 {
2382 DataType::Signed32
2383 };
2384
2385 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2386 "Reference Shape: output type not supported");
2387
2388 return supported;
2389}
2390
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002391bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2392 const TensorInfo& output,
2393 const SliceDescriptor& descriptor,
2394 Optional<std::string&> reasonIfUnsupported) const
2395{
Jan Eilers8eb25602020-03-09 12:13:48 +00002396 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002397 bool supported = true;
2398
Sadik Armagan303980c2020-04-17 12:45:14 +01002399 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002400 {
2401 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002402 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002403 DataType::QAsymmU8,
2404 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002405 };
2406
2407 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2408 "Reference Slice: input type not supported");
2409
2410 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2411 "Reference Slice: output type not supported");
2412
2413 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2414 "Reference Slice: input and output types are mismatched");
2415
2416 return supported;
2417}
2418
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002419bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2420 const TensorInfo& output,
2421 const SoftmaxDescriptor& descriptor,
2422 Optional<std::string&> reasonIfUnsupported) const
2423{
Jan Eilers8eb25602020-03-09 12:13:48 +00002424 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002425 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002426 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002427 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002428 DataType::Float32,
2429 DataType::Float16,
2430 DataType::QSymmS8,
2431 DataType::QAsymmS8,
2432 DataType::QAsymmU8,
2433 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002434 };
2435
2436 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002437 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002438
2439 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002440 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002441
2442 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002443 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002444
2445 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002446}
2447
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002448bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2449 const TensorInfo& output,
2450 const SpaceToBatchNdDescriptor& descriptor,
2451 Optional<std::string&> reasonIfUnsupported) const
2452{
Jan Eilers8eb25602020-03-09 12:13:48 +00002453 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002454 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002455 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002456 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002457 DataType::Float32,
2458 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002459 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002460 DataType::QAsymmU8,
2461 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002462 };
2463
2464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2465 "Reference SpaceToBatchNd: input type not supported");
2466
2467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2468 "Reference SpaceToBatchNd: output type not supported");
2469
2470 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2471 "Reference SpaceToBatchNd: input and output types are mismatched");
2472
2473 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002474}
2475
Keith Davisa57eccb2019-06-14 17:33:22 +01002476bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002477 const TensorInfo& output,
2478 const SpaceToDepthDescriptor& descriptor,
2479 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002480{
2481
Jan Eilers8eb25602020-03-09 12:13:48 +00002482 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002483 bool supported = true;
2484
Sadik Armagan303980c2020-04-17 12:45:14 +01002485 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002486 {
2487 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002488 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002489 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002490 DataType::QAsymmU8,
2491 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002492 };
2493
2494 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2495 "Reference SpaceToDepth: input type not supported");
2496
2497 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2498 "Reference SpaceToDepth: output type not supported");
2499
2500 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2501 "Reference SpaceToDepth: input and output types are mismatched");
2502
2503 return supported;
2504}
2505
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002506bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002507 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2508 const ViewsDescriptor& descriptor,
2509 Optional<std::string&> reasonIfUnsupported) const
2510{
Jan Eilers8eb25602020-03-09 12:13:48 +00002511 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002512 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002513 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002514 {
2515 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002516 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002517 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002518 DataType::QAsymmU8,
2519 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002520 };
2521
2522 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2523 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002524 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002525 {
2526 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2527 "Reference splitter: input type not supported");
2528
2529 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2530 "Reference splitter: input and output types mismatched.");
2531 }
2532
2533 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002534}
2535
Matthew Jackson81e601c2019-07-11 12:07:09 +01002536bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2537 const TensorInfo& output,
2538 const StackDescriptor& descriptor,
2539 Optional<std::string&> reasonIfUnsupported) const
2540{
Jan Eilers8eb25602020-03-09 12:13:48 +00002541 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002542
2543 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002544 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002545 {
2546 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002547 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002548 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002549 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002550 DataType::QSymmS16,
2551 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002552 };
2553
2554 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2555 "Reference stack: output type not supported");
2556 for (const TensorInfo* input : inputs)
2557 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002558 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002559 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2560 "Reference stack: input type not supported");
2561
2562 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2563 "Reference stack: input and output types mismatched.");
2564 }
2565
2566 return supported;
2567}
2568
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002569bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2570 const TensorInfo& output,
2571 const StridedSliceDescriptor& descriptor,
2572 Optional<std::string&> reasonIfUnsupported) const
2573{
Jan Eilers8eb25602020-03-09 12:13:48 +00002574 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002575 bool supported = true;
2576
Sadik Armagan303980c2020-04-17 12:45:14 +01002577 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002578 {
2579 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002580 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002581 DataType::QAsymmU8,
2582 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002583 };
2584
2585 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2586 "Reference StridedSlice: input type not supported");
2587
2588 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2589 "Reference StridedSlice: output type not supported");
2590
2591 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2592 "Reference StridedSlice: input and output types are mismatched");
2593
2594 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002595}
2596
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002597bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2598 const TensorInfo& input1,
2599 const TensorInfo& output,
2600 Optional<std::string&> reasonIfUnsupported) const
2601{
Sadik Armagan2999a022019-04-09 14:20:12 +01002602 bool supported = true;
2603
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002604 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002605 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002606 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002607 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002608 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002609 DataType::QSymmS16,
2610 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002611 };
2612
2613 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2614 "Reference subtraction: input 0 is not a supported type.");
2615
2616 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2617 "Reference subtraction: input 1 is not a supported type.");
2618
2619 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2620 "Reference subtraction: output is not a supported type.");
2621
2622 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2623 "Reference subtraction: input 0 and Input 1 types are mismatched");
2624
2625 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2626 "Reference subtraction: input and output types are mismatched");
2627
2628 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2629 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2630
2631 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002632}
2633
Matteo Martincighab9e5252019-06-13 17:27:46 +01002634bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2635 const TensorInfo& alpha,
2636 const TensorInfo& output,
2637 Optional<std::string&> reasonIfUnsupported) const
2638{
2639 bool supported = true;
2640
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002641 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002642 {
2643 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002644 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002645 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002646 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002647 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002648 };
2649
2650 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2651 "PReLU: input is not a supported type.");
2652
2653 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2654 "PReLU: alpha is not a supported type.");
2655
2656 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2657 "PReLU: output is not a supported type.");
2658
2659 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2660 "PReLU: input, alpha and output types are mismatched");
2661
2662 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2663 "PReLU: shapes are not suitable for implicit broadcast");
2664
2665 return supported;
2666}
2667
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002668bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2669 const TensorInfo& output,
2670 const TransposeConvolution2dDescriptor& descriptor,
2671 const TensorInfo& weights,
2672 const Optional<TensorInfo>& biases,
2673 Optional<std::string&> reasonIfUnsupported) const
2674{
Jan Eilers8eb25602020-03-09 12:13:48 +00002675 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002676 bool supported = true;
2677
Sadik Armagan303980c2020-04-17 12:45:14 +01002678 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002679 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002680 DataType::Float32,
2681 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002682 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002683 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002684 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002685 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002686 };
2687
2688 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2689 "Reference TransposeConvolution2d: input is not a supported type.");
2690
2691 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2692 "Reference TransposeConvolution2d: output is not a supported type.");
2693
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002694 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2695 "Reference TransposeConvolution2d: input and output types mismatched.");
2696
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002697
2698 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002699 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002700 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002701 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002702 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002703 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002704 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002705 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002706 };
2707
2708 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2709 "Reference TransposeConvolution2d: weights type not supported for "
2710 "quantized input.");
2711 }
2712 else
2713 {
2714 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2715 "Reference TransposeConvolution2d: weights is not a supported type.");
2716
2717 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2718 "Reference TransposeConvolution2d: input and weights types mismatched.");
2719 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002720
2721 if (biases.has_value())
2722 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002723 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002724 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002725 DataType::Float32,
2726 DataType::Float16,
2727 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002728 };
2729 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2730 "Reference TransposeConvolution2d: biases is not a supported type.");
2731 }
2732
2733 return supported;
2734}
2735
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002736bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2737 const TensorInfo& output,
2738 const TransposeDescriptor& descriptor,
2739 Optional<std::string&> reasonIfUnsupported) const
2740{
Jan Eilers8eb25602020-03-09 12:13:48 +00002741 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002742 bool supported = true;
2743
2744 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002745 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002746 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002747 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002748 DataType::Float32,
2749 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002750 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002751 DataType::QAsymmU8,
2752 DataType::QSymmS16
2753 };
2754
2755 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2756 "Reference transpose: input is not a supported type.");
2757
2758 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2759 "Reference transpose: output is not a supported type.");
2760
2761 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2762 "Reference transpose: input and output types are mismatched.");
2763
2764 return supported;
2765}
2766
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002767bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2768 const TensorInfo& input,
2769 const TensorInfo& outputStateIn,
2770 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002771 const TensorInfo& outputStateOut,
2772 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002773 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002774 const UnidirectionalSequenceLstmDescriptor& descriptor,
2775 const LstmInputParamsInfo& paramsInfo,
2776 Optional<std::string&> reasonIfUnsupported) const
2777{
2778 IgnoreUnused(descriptor);
2779 IgnoreUnused(paramsInfo);
2780 IgnoreUnused(outputStateIn);
2781 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002782 IgnoreUnused(outputStateOut);
2783 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002784 bool supported = true;
2785
Mike Kelly12994962022-04-21 11:57:09 +01002786 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002787 {
Mike Kelly12994962022-04-21 11:57:09 +01002788 DataType::Float32,
2789 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002790 };
2791
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002792 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002793 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002794 DataType::Float32,
2795 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002796 };
2797
Mike Kelly12994962022-04-21 11:57:09 +01002798 std::array<DataType, 3> supportedBiasTypes =
2799 {
2800 DataType::Float32,
2801 DataType::QAsymmS8,
2802 DataType::Signed32
2803 };
2804
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002805 // check inputs and outputs
2806 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2807 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002808 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2809 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002810
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002811 // check layer parameters
2812 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2813 reasonIfUnsupported,
2814 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2815 "is not a supported type.");
2816 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2817 reasonIfUnsupported,
2818 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2819 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2820 reasonIfUnsupported,
2821 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2822 "is not a supported type.");
2823 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2824 reasonIfUnsupported,
2825 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2826 "is not a supported type.");
2827 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2828 reasonIfUnsupported,
2829 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2830 "is not a supported type.");
2831 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2832 reasonIfUnsupported,
2833 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2834 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002835
2836 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2837 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2838 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2839 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2840 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2841 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002842 if (!descriptor.m_CifgEnabled)
2843 {
2844 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2845 reasonIfUnsupported,
2846 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2847 "is not a supported type.");
2848 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2849 reasonIfUnsupported,
2850 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2851 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002852 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2853 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002854 if (descriptor.m_PeepholeEnabled)
2855 {
2856 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2857 reasonIfUnsupported,
2858 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2859 "is not a supported type.");
2860 }
2861 }
2862 if (descriptor.m_PeepholeEnabled)
2863 {
2864 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2865 reasonIfUnsupported,
2866 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2867 "is not a supported type.");
2868 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2869 reasonIfUnsupported,
2870 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2871 "is not a supported type.");
2872 }
2873 if (descriptor.m_ProjectionEnabled)
2874 {
2875 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2876 reasonIfUnsupported,
2877 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2878 "is not a supported type.");
2879 if (paramsInfo.m_ProjectionBias != nullptr)
2880 {
2881 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2882 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2883 "are mismatched");
2884 }
2885 }
2886 if (descriptor.m_LayerNormEnabled)
2887 {
2888 if (!descriptor.m_CifgEnabled)
2889 {
2890 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2891 reasonIfUnsupported,
2892 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2893 "is not a supported type.");
2894 }
2895 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2896 reasonIfUnsupported,
2897 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2898 "is not a supported type.");
2899 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2900 reasonIfUnsupported,
2901 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2902 "is not a supported type.");
2903 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2904 reasonIfUnsupported,
2905 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2906 "is not a supported type.");
2907 }
2908
2909 return supported;
2910}
2911
arovir011c7c81b2018-10-08 11:34:28 +01002912} // namespace armnn