blob: 1d5fab1adc966c48d89dae897c43d8d431b1d780 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
// Dispatches a support query to the handler matching the tensor's data type.
// The reference backend only supplies handlers for Float32 and QAsymmU8 here;
// every other data-type slot of IsSupportedForDataTypeGeneric is filled with
// &FalseFunc so those types are reported as unsupported.
//
// reasonIfUnsupported - optional out-string describing why the layer is rejected.
// dataType            - data type of the tensor being queried.
// floatFuncPtr        - handler invoked when dataType is Float32.
// uint8FuncPtr        - handler invoked when dataType is QAsymmU8.
// params              - extra arguments perfect-forwarded to the chosen handler.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: unsupported
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>, // Signed32: unsupported
                                         &FalseFunc<Params...>, // Boolean: unsupported
                                         std::forward<Params>(params)...);
}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
/// Builds the diagnostic message reported when a tensor's rank does not match
/// what a reference layer expects.
///
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name inserted into the message.
/// @param tensorName Name of the offending tensor.
/// @return The formatted error message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // The strings are read-only here, so take them by const reference: this is
    // backward compatible for existing lvalue callers and additionally accepts
    // temporaries. The message text itself is unchanged.
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
103 case LayerType::Comparison:
104 return IsComparisonSupported(infos[0],
105 infos[1],
106 infos[2],
107 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108 reasonIfUnsupported);
109 case LayerType::Concat:
110 {
111 std::vector<const TensorInfo*> inputInfos;
112 for (uint32_t i = 0; i < (infos.size() - 1); i++)
113 {
114 inputInfos.push_back(&infos[i]);
115 }
116 return IsConcatSupported(inputInfos,
117 infos[infos.size() - 1],
118 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119 reasonIfUnsupported);
120 }
121 case LayerType::Constant:
122 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000123 case LayerType::ConvertFp16ToFp32:
124 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000125 case LayerType::ConvertFp32ToFp16:
126 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127 case LayerType::Convolution2d:
128 {
129 if (infos.size() != 4)
130 {
131 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132 "TensorInfos should be of format: {input, output, weights, biases}.");
133 }
134
135 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136 if (infos[3] == TensorInfo())
137 {
138 return IsConvolution2dSupported(infos[0],
139 infos[1],
140 desc,
141 infos[2],
142 EmptyOptional(),
143 reasonIfUnsupported);
144 }
145 else
146 {
147 return IsConvolution2dSupported(infos[0],
148 infos[1],
149 desc,
150 infos[2],
151 infos[3],
152 reasonIfUnsupported);
153 }
154 }
155 case LayerType::DepthToSpace:
156 return IsDepthToSpaceSupported(infos[0],
157 infos[1],
158 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159 reasonIfUnsupported);
160 case LayerType::DepthwiseConvolution2d:
161 {
162 if (infos.size() != 4)
163 {
164 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165 "TensorInfos should be of format: {input, output, weights, biases}.");
166 }
167
168 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169 if (infos[3] == TensorInfo())
170 {
171 return IsDepthwiseConvolutionSupported(infos[0],
172 infos[1],
173 desc,
174 infos[2],
175 EmptyOptional(),
176 reasonIfUnsupported);
177 }
178 else
179 {
180 return IsDepthwiseConvolutionSupported(infos[0],
181 infos[1],
182 desc,
183 infos[2],
184 infos[3],
185 reasonIfUnsupported);
186 }
187 }
188 case LayerType::Dequantize:
189 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190 case LayerType::Division:
191 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000192 case LayerType::ElementwiseBinary:
193 {
194 std::array<DataType, 7> supportedTypes =
195 {
196 DataType::Float32,
197 DataType::Float16,
198 DataType::QAsymmS8,
199 DataType::QAsymmU8,
200 DataType::QSymmS16,
201 DataType::Signed32
202 };
203
204 bool supported = true;
205 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206 "Reference elementwise unary: input type not supported");
207
208 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209 "Reference elementwise unary: input type not supported");
210
211 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212 "Reference elementwise unary: output type not supported");
213
214 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215 "Reference elementwise unary: input types not matching");
216
217 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218 "Reference elementwise unary: input and output types not matching");
219
220 return supported;
221 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000222 case LayerType::ElementwiseUnary:
223 return IsElementwiseUnarySupported(infos[0],
224 infos[1],
225 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226 reasonIfUnsupported);
227 case LayerType::Fill:
228 return IsFillSupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Floor:
233 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234 case LayerType::FullyConnected:
235 return IsFullyConnectedSupported(infos[0],
236 infos[1],
237 infos[2],
238 infos[3],
239 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240 reasonIfUnsupported);
241 case LayerType::Gather:
242 return IsGatherSupported(infos[0],
243 infos[1],
244 infos[2],
245 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100247 case LayerType::GatherNd:
248 return IsGatherNdSupported(infos[0],
249 infos[1],
250 infos[2],
251 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000252 case LayerType::Input:
253 return IsInputSupported(infos[0], reasonIfUnsupported);
254 case LayerType::InstanceNormalization:
255 return IsInstanceNormalizationSupported(infos[0],
256 infos[1],
257 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258 (&descriptor)),
259 reasonIfUnsupported);
260 case LayerType::L2Normalization:
261 return IsL2NormalizationSupported(infos[0],
262 infos[1],
263 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::LogicalBinary:
266 return IsLogicalBinarySupported(infos[0],
267 infos[1],
268 infos[2],
269 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270 reasonIfUnsupported);
271 case LayerType::LogSoftmax:
272 return IsLogSoftmaxSupported(infos[0],
273 infos[1],
274 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::Lstm:
277 return IsLstmSupported(infos[0],
278 infos[1],
279 infos[2],
280 infos[3],
281 infos[4],
282 infos[5],
283 infos[6],
284 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285 lstmParamsInfo.value(),
286 reasonIfUnsupported);
287 case LayerType::QLstm:
288 return IsQLstmSupported(infos[0],
289 infos[1],
290 infos[2],
291 infos[3],
292 infos[4],
293 infos[5],
294 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295 lstmParamsInfo.value(),
296 reasonIfUnsupported);
297 case LayerType::Maximum:
298 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299 case LayerType::Mean:
300 return IsMeanSupported(infos[0],
301 infos[1],
302 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303 reasonIfUnsupported);
304 case LayerType::Minimum:
305 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306 case LayerType::Multiplication:
307 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308 case LayerType::Normalization:
309 return IsNormalizationSupported(infos[0],
310 infos[1],
311 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312 reasonIfUnsupported);
313 case LayerType::Output:
314 return IsOutputSupported(infos[0], reasonIfUnsupported);
315 case LayerType::Pad:
316 return IsPadSupported(infos[0],
317 infos[1],
318 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319 reasonIfUnsupported);
320 case LayerType::Permute:
321 return IsPermuteSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Pooling2d:
326 return IsPooling2dSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Prelu:
331 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332 case LayerType::Quantize:
333 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334 case LayerType::Reshape:
335 return IsReshapeSupported(infos[0],
336 infos[1],
337 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338 reasonIfUnsupported);
339 case LayerType::Resize:
340 return IsResizeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100344 case LayerType::ReverseV2:
345 return IsReverseV2Supported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ReverseV2Descriptor*>(&descriptor)),
348 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000349 case LayerType::Reduce:
350 return IsReduceSupported(infos[0],
351 infos[1],
352 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
353 reasonIfUnsupported);
354 case LayerType::Slice:
355 return IsSliceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::Softmax:
360 return IsSoftmaxSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::SpaceToBatchNd:
365 return IsSpaceToBatchNdSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::SpaceToDepth:
370 return IsSpaceToDepthSupported(infos[0],
371 infos[1],
372 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
373 reasonIfUnsupported);
374 case LayerType::Splitter:
375 {
376 std::vector<TensorInfo> outputInfos;
377 for (uint32_t i = 1; i < infos.size(); i++)
378 {
379 outputInfos.push_back(infos[i]);
380 }
381 return IsSplitterSupported(infos[0],
382 {outputInfos.begin(), outputInfos.end()},
383 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
384 reasonIfUnsupported);
385 }
386 case LayerType::Stack:
387 {
388 std::vector<const TensorInfo*> inputInfos;
389 for (uint32_t i = 0; i < infos.size() - 1; i++)
390 {
391 inputInfos.push_back(&infos[i]);
392 }
393 return IsStackSupported(inputInfos,
394 infos[infos.size() - 1],
395 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
396 reasonIfUnsupported);
397 }
398 case LayerType::StridedSlice:
399 return IsStridedSliceSupported(infos[0],
400 infos[1],
401 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
402 reasonIfUnsupported);
403 case LayerType::Subtraction:
404 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
405 case LayerType::Transpose:
406 return IsTransposeSupported(infos[0],
407 infos[1],
408 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
409 reasonIfUnsupported);
410 case LayerType::TransposeConvolution2d:
411 {
412 if (infos.size() != 4)
413 {
414 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
415 "TensorInfos should be of format: {input, output, weights, biases}.");
416 }
417
418 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
419 if (infos[3] == TensorInfo())
420 {
421 return IsTransposeConvolution2dSupported(infos[0],
422 infos[1],
423 desc,
424 infos[2],
425 EmptyOptional(),
426 reasonIfUnsupported);
427 }
428 else
429 {
430 return IsTransposeConvolution2dSupported(infos[0],
431 infos[1],
432 desc,
433 infos[2],
434 infos[3],
435 reasonIfUnsupported);
436 }
437 }
438 case LayerType::Cast:
439 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
440 case LayerType::ChannelShuffle:
441 return IsChannelShuffleSupported(infos[0],
442 infos[1],
443 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
444 reasonIfUnsupported);
445 case LayerType::Convolution3d:
446 {
447 if (infos.size() != 4)
448 {
449 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
450 "TensorInfos should be of format: {input, output, weights, biases}.");
451 }
452
453 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
454 if (infos[3] == TensorInfo())
455 {
456 return IsConvolution3dSupported(infos[0],
457 infos[1],
458 desc,
459 infos[2],
460 EmptyOptional(),
461 reasonIfUnsupported);
462 }
463 else
464 {
465 return IsConvolution3dSupported(infos[0],
466 infos[1],
467 desc,
468 infos[2],
469 infos[3],
470 reasonIfUnsupported);
471 }
472 }
473 case LayerType::Debug:
474 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
475 case LayerType::DetectionPostProcess:
476 return IsDetectionPostProcessSupported(infos[0],
477 infos[1],
478 infos[2],
479 infos[3],
480 infos[4],
481 infos[5],
482 infos[6],
483 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
484 (&descriptor)),
485 reasonIfUnsupported);
486 case LayerType::FakeQuantization:
487 return IsFakeQuantizationSupported(infos[0],
488 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
489 reasonIfUnsupported);
490 case LayerType::MemCopy:
491 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
492 case LayerType::Rank:
493 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
494 case LayerType::Shape:
495 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
496 case LayerType::UnidirectionalSequenceLstm:
497 {
498 if (infos.size() != 6)
499 {
500 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
501 "should be of format: {input, outputStateIn, cellStateIn, "
502 "hiddenStateOutputVal, cellStateOutputVal, output}");
503 }
504 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100505 return IsUnidirectionalSequenceLstmSupported(infos[0],
506 infos[1],
507 infos[2],
508 infos[3],
509 infos[4],
510 infos[5],
511 desc,
512 lstmParamsInfo.value(),
513 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000514 }
515 case LayerType::Pooling3d:
516 return IsPooling3dSupported(infos[0],
517 infos[1],
518 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
519 reasonIfUnsupported);
520 case LayerType::Map:
521 return true;
522 case LayerType::Unmap:
523 return true;
524 case LayerType::MemImport:
525 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
526 case LayerType::Merge:
527 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
528 case LayerType::QuantizedLstm:
529 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
530 infos[1],
531 infos[2],
532 infos[3],
533 infos[4],
534 quantizedLstmInputParamsInfo.value(),
535 reasonIfUnsupported);
536 default:
537 // layers not supported in neon by default:
538 // precompiled, standin, switch
539 return false;
540 }
541}
542
arovir011c7c81b2018-10-08 11:34:28 +0100543bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
544 const TensorInfo& output,
545 const ActivationDescriptor& descriptor,
546 Optional<std::string&> reasonIfUnsupported) const
547{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000548 bool supported = true;
549
550 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000551 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000552 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100553 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000554 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000555 DataType::QAsymmU8,
556 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000557 };
558
559 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
560 "Reference activation: input type not supported.");
561
562 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
563 "Reference activation: output type not supported.");
564
565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
566 "Reference activation: input and output types mismatched.");
567
568 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
569 "Reference activation: input and output shapes are of different rank.");
570
571
572 struct ActivationFunctionSupported : public Rule
573 {
574 ActivationFunctionSupported(const ActivationDescriptor& desc)
575 {
576 switch(desc.m_Function)
577 {
578 case ActivationFunction::Abs:
579 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000580 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000581 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000582 case ActivationFunction::LeakyReLu:
583 case ActivationFunction::Linear:
584 case ActivationFunction::ReLu:
585 case ActivationFunction::Sigmoid:
586 case ActivationFunction::SoftReLu:
587 case ActivationFunction::Sqrt:
588 case ActivationFunction::Square:
589 case ActivationFunction::TanH:
590 {
591 m_Res = true;
592 break;
593 }
594 default:
595 {
596 m_Res = false;
597 break;
598 }
599 }
600 }
601 };
602
603 // Function is supported
604 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
605 "Reference activation: function not supported.");
606
607 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100608}
609
610bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
611 const TensorInfo& input1,
612 const TensorInfo& output,
613 Optional<std::string&> reasonIfUnsupported) const
614{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000615 bool supported = true;
616
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100617 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000618 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100619 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000620 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000621 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100622 DataType::QSymmS16,
623 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000624 };
625
626 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
627 "Reference addition: input 0 is not a supported type.");
628
629 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
630 "Reference addition: input 1 is not a supported type.");
631
632 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
633 "Reference addition: output is not a supported type.");
634
635 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
636 "Reference addition: input 0 and Input 1 types are mismatched");
637
638 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
639 "Reference addition: input and output types are mismatched");
640
641 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
642 "Reference addition: shapes are not suitable for implicit broadcast.");
643
644 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100645}
646
Nikhil Raj68c2c902019-09-19 11:21:11 +0100647bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
648 const armnn::ArgMinMaxDescriptor &descriptor,
649 armnn::Optional<std::string &> reasonIfUnsupported) const
650{
Jan Eilers8eb25602020-03-09 12:13:48 +0000651 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100652
Mike Kelly1f140f72021-04-06 12:25:55 +0100653 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100654 {
Teresa Charline300b362020-05-25 10:01:03 +0100655 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100656 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100657 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000658 DataType::QAsymmU8,
659 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100660 DataType::Signed32,
661 DataType::Signed64
662 };
663
664 std::array<DataType,2> supportedOutputTypes = {
665 DataType::Signed32,
666 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100667 };
668
669 bool supported = true;
670
Mike Kelly1f140f72021-04-06 12:25:55 +0100671 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100672 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100673 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100674 "Reference ArgMinMax: output type not supported");
675
676 return supported;
677}
678
Samuel Yap6b478092022-07-06 15:36:03 +0100679bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
680 const TensorInfo& inputY,
681 const TensorInfo& output,
682 const BatchMatMulDescriptor& descriptor,
683 Optional<std::string &> reasonIfUnsupported) const
684{
685 IgnoreUnused(descriptor);
686
687 std::array<DataType, 6> supportedTypes =
688 {
Samuel Yap6b478092022-07-06 15:36:03 +0100689 DataType::Float16,
690 DataType::Float32,
691 DataType::QAsymmS8,
692 DataType::QAsymmU8,
693 DataType::QSymmS16
694 };
695
696 bool supported = true;
697
698 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
699 "Reference batch matrix multiplication: input X is not a supported type");
700
701 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
702 "Reference batch matrix multiplication: input Y is not a supported type");
703
704 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
705 "Reference batch matrix multiplication: output is not a supported type");
706
707 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
708 "Reference batch matrix multiplication: input X and input Y types are mismatched");
709
710 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
711 "Reference batch matrix multiplication: inputs and output types are mismatched");
712
713 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
714 reasonIfUnsupported,
715 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
716
717 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
718 reasonIfUnsupported,
719 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
720
721 return supported;
722}
723
arovir011c7c81b2018-10-08 11:34:28 +0100724bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
725 const TensorInfo& output,
726 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100727 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100728 const TensorInfo& beta,
729 const TensorInfo& gamma,
730 const BatchNormalizationDescriptor& descriptor,
731 Optional<std::string&> reasonIfUnsupported) const
732{
Jan Eilers8eb25602020-03-09 12:13:48 +0000733 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100734
Sadik Armagan303980c2020-04-17 12:45:14 +0100735 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100736 {
737 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100738 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100739 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000740 DataType::QAsymmU8,
741 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100742 };
743
744 bool supported = true;
745
746 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
747 "Reference batch normalization: input is not a supported type.");
748
749 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
750 "Reference batch normalization: output is not a supported type.");
751
752 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
753 "Reference batch normalization: input and output types are mismatched");
754
755 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
756 "Reference batch normalization: mean is not a supported type.");
757
758 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
759 "Reference batch normalization: variance is not a supported type.");
760
761 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
762 "Reference batch normalization: beta is not a supported type.");
763
764 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
765 "Reference batch normalization: gamma is not a supported type.");
766
767 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100768}
769
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000770bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
771 const TensorInfo& output,
772 const BatchToSpaceNdDescriptor& descriptor,
773 Optional<std::string&> reasonIfUnsupported) const
774{
Jan Eilers8eb25602020-03-09 12:13:48 +0000775 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100776
777 bool supported = true;
778
779 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
780 std::string inputTensorStr = "input";
781 std::string outputTensorStr = "output";
782
783 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100784 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100785 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000786 DataType::Float32,
787 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100788 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000789 DataType::QAsymmU8,
790 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100791 };
792
793 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
794 "Reference BatchToSpaceNd: input type not supported.");
795
796 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
797 "Reference BatchToSpaceNd: output type not supported.");
798
799 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
800 "Reference BatchToSpaceNd: input and output types mismatched.");
801
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100802 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000803}
804
mathad01b392e982021-04-07 12:07:30 +0100805bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
806 const TensorInfo& output,
807 Optional<std::string&> reasonIfUnsupported) const
808{
809 std::array<DataType, 9> supportedInputTypes =
810 {
mathad01b392e982021-04-07 12:07:30 +0100811 DataType::Float32,
812 DataType::Float16,
813 DataType::QSymmS8,
814 DataType::QAsymmS8,
815 DataType::QAsymmU8,
816 DataType::QSymmS16,
817 DataType::Signed32
818 };
819
820 bool supported = true;
821 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
822 "Reference cast: input is not a supported type");
823
824
825 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
826 "Reference cast: output is not a supported type");
827
828 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
829 "Reference cast: input and output shapes have different number of total elements");
830
831 return supported;
832}
833
Simon Obute51f67772021-09-03 15:50:13 +0100834bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
835 const TensorInfo& output,
836 const ChannelShuffleDescriptor& descriptor,
837 Optional<std::string&> reasonIfUnsupported) const
838{
839 IgnoreUnused(descriptor);
840 bool supported = true;
841
842 // Define supported output and inputs types.
843 std::array<DataType, 7> supportedTypes =
844 {
Simon Obute51f67772021-09-03 15:50:13 +0100845 DataType::Float32,
846 DataType::Float16,
847 DataType::QAsymmS8,
848 DataType::QAsymmU8,
849 DataType::QSymmS8,
850 DataType::QSymmS16
851 };
852
853 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
854 "Reference ChannelShuffle: input is not a supported type.");
855
856 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
857 "Reference ChannelShuffle: output is not a supported type.");
858
859 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
860 "Reference ChannelShuffle: input and output types are mismatched.");
861
862 return supported;
863}
864
865
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100866bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
867 const TensorInfo& input1,
868 const TensorInfo& output,
869 const ComparisonDescriptor& descriptor,
870 Optional<std::string&> reasonIfUnsupported) const
871{
Jan Eilers8eb25602020-03-09 12:13:48 +0000872 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100873 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100874 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000875 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100876 DataType::Float32,
877 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100878 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000879 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000880 DataType::QSymmS16,
881 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100882 };
883
884 bool supported = true;
885 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
886 "Reference comparison: input 0 is not a supported type");
887
888 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
889 "Reference comparison: input 0 and Input 1 types are mismatched");
890
891 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
892 "Reference comparison: output is not of type Boolean");
893
894 return supported;
895}
896
Jim Flynn906f9462019-05-10 13:55:21 +0100897bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
898 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000899 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100900 Optional<std::string&> reasonIfUnsupported) const
901{
Jan Eilers8eb25602020-03-09 12:13:48 +0000902 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100903
904 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000905 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100906 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000907 DataType::Float32,
908 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000909 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100910 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000911 DataType::QSymmS16,
912 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100913 };
914
915 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
916 "Reference concatenation: output type not supported");
917 for (const TensorInfo* input : inputs)
918 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100919 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100920 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
921 "Reference concatenation: input type not supported");
922
923 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
924 "Reference concatenation: input and output types mismatched.");
925 }
926
927 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100928}
929
arovir011c7c81b2018-10-08 11:34:28 +0100930bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
931 Optional<std::string&> reasonIfUnsupported) const
932{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100933 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100934 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100935 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100936 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000937 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100938 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000939 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100940 DataType::QSymmS16,
941 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100942 };
943
944 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
945 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100946}
947
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only for Float16 input and Float32 output. The function
    // pointers are positional: (float16, float32, uint8, int32, boolean) in
    // that order, so &TrueFunc<> in slot one accepts a Float16 input and the
    // FalseInputFuncF32<>/FalseOutputFuncF16<> entries emit type-specific
    // failure reasons for the disallowed direction.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
967
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror image of IsConvertFp16ToFp32Supported: only a Float32 input and
    // Float16 output are accepted. The function pointers are positional
    // (float16, float32, uint8, int32, boolean); &TrueFunc<> in slot two
    // accepts Float32 input, and the False*Func entries produce the
    // appropriate failure reason for every other type.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
987
988bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
989 const TensorInfo& output,
990 const Convolution2dDescriptor& descriptor,
991 const TensorInfo& weights,
992 const Optional<TensorInfo>& biases,
993 Optional<std::string&> reasonIfUnsupported) const
994{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100995 bool supported = true;
996
997 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000998 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000999 {
1000 DataType::Float32,
1001 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001002 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001003 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001004 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001005 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001006 };
1007
1008 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001009 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001010
1011 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001012 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001013
Ryan OShea31441592022-11-07 16:20:48 +00001014 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1015 "Reference Convolution2d: input and output types mismatched.");
1016
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001017
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001018 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001019 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001020 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001021 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001022 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001023 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001024 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001025 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001026 };
1027
1028 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001029 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001030 }
1031 else
1032 {
1033 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001034 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001035
1036 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001037 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001038 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001039
1040 if (biases.has_value())
1041 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001042 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001043 {
1044 DataType::Float32,
1045 DataType::Float16,
1046 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001047 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001048
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001049 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001050 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001051 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001052 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001053
1054 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001055}
1056
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001057bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1058 const TensorInfo& output,
1059 const Convolution3dDescriptor& descriptor,
1060 const TensorInfo& weights,
1061 const Optional<TensorInfo>& biases,
1062 Optional<std::string&> reasonIfUnsupported) const
1063{
1064 bool supported = true;
1065
1066 // Define supported types.
1067 std::array<DataType,7> supportedTypes =
1068 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001069 DataType::Float32,
1070 DataType::Float16,
1071 DataType::QAsymmS8,
1072 DataType::QAsymmU8,
1073 DataType::QSymmS8,
1074 DataType::QSymmS16
1075 };
1076
1077 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1078 "Reference Convolution3d: input is not a supported type.");
1079
1080 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1081 "Reference Convolution3d: output is not a supported type.");
1082
1083 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1084 "Reference Convolution3d: input and output types mismatched.");
1085
1086 const DataType inputType = input.GetDataType();
1087 if (IsQuantized8BitType(inputType))
1088 {
1089 std::array<DataType, 3> supportedWeightTypes =
1090 {
1091 DataType::QAsymmS8,
1092 DataType::QAsymmU8,
1093 DataType::QSymmS8
1094 };
1095
1096 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1097 "Reference Convolution3d: weights type not supported for quantized input.");
1098 }
1099 else
1100 {
1101 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1102 "Reference Convolution3d: weights is not a supported type.");
1103
1104 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1105 "Reference Convolution3d: input and weights types mismatched.");
1106 }
1107
1108 if (biases.has_value())
1109 {
1110 std::array<DataType,4> biasesSupportedTypes =
1111 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001112 DataType::Float32,
1113 DataType::Float16,
1114 DataType::Signed32
1115 };
1116
1117 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1118 "Reference Convolution3d: biases is not a supported type.");
1119 }
1120 IgnoreUnused(descriptor);
1121
1122 return supported;
1123}
1124
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001125bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1126 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001127 Optional<std::string&> reasonIfUnsupported) const
1128{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001129 bool supported = true;
1130
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001131 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001132 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001133 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001134 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001135 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001136 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001137 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001138 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001139 DataType::QSymmS16,
1140 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001141 };
1142
1143 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001144 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001145
1146 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001147 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001148
1149 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001150 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001151
1152 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001153}
1154
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001155bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1156 const TensorInfo& output,
1157 const DepthToSpaceDescriptor& descriptor,
1158 Optional<std::string&> reasonIfUnsupported) const
1159{
Jan Eilers8eb25602020-03-09 12:13:48 +00001160 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001161 bool supported = true;
1162
Sadik Armagan303980c2020-04-17 12:45:14 +01001163 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001164 {
1165 DataType::Float32,
1166 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001167 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001168 DataType::QAsymmU8,
1169 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001170 };
1171
1172 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1173 "Reference DepthToSpace: input type not supported");
1174
1175 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1176 "Reference DepthToSpace: output type not supported");
1177
1178 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1179 "Reference DepthToSpace: input and output types are mismatched");
1180
1181 return supported;
1182}
1183
arovir011c7c81b2018-10-08 11:34:28 +01001184bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1185 const TensorInfo& output,
1186 const DepthwiseConvolution2dDescriptor& descriptor,
1187 const TensorInfo& weights,
1188 const Optional<TensorInfo>& biases,
1189 Optional<std::string&> reasonIfUnsupported) const
1190{
Sadik Armagan303980c2020-04-17 12:45:14 +01001191 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001192 bool supported = true;
1193
1194 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001195 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001196 {
1197 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001198 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001199 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001200 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001201 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001202 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001203 };
1204
1205 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1206 "Reference DepthwiseConvolution2d: input is not a supported type.");
1207
1208 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1209 "Reference DepthwiseConvolution2d: output is not a supported type.");
1210
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001211 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1212 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1213
Teresa Charlind8df0262019-11-11 12:28:15 +00001214 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001215 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001216 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001217 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001218 {
1219 DataType::QAsymmS8,
1220 DataType::QAsymmU8,
1221 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001222 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001223
1224 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001225 "Reference DepthwiseConvolution2d: weights type not supported for "
1226 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001227 }
1228 else
1229 {
1230 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1231 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1232
1233 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1234 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1235 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001236
1237 if (biases.has_value())
1238 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001239 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001240 {
1241 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001242 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001243 DataType::Signed32
1244 };
1245 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1246 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1247 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001248
1249 return supported;
1250
arovir011c7c81b2018-10-08 11:34:28 +01001251}
1252
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001253bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1254 const TensorInfo& output,
1255 Optional<std::string&> reasonIfUnsupported) const
1256{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001257 bool supported = true;
1258
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001259 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001260 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001261 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001262 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001263 DataType::QSymmS16,
1264 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001265 };
1266
1267 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001268 "Reference for Dequantize layer: input type not supported.");
1269
Derek Lambertid466a542020-01-22 15:37:29 +00001270 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001271 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001272
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001273 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001274 DataType::Float32,
1275 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001276 };
1277
1278 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001279 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001280
1281 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001282 "Reference for Dequantize layer: input/output shapes have different num total "
1283 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001284
1285 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001286}
1287
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001288bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1289 const TensorInfo& scores,
1290 const TensorInfo& anchors,
1291 const TensorInfo& detectionBoxes,
1292 const TensorInfo& detectionClasses,
1293 const TensorInfo& detectionScores,
1294 const TensorInfo& numDetections,
1295 const DetectionPostProcessDescriptor& descriptor,
1296 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001297{
Jan Eilers8eb25602020-03-09 12:13:48 +00001298 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001299
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001300 bool supported = true;
1301
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001302 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001303 {
1304 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001305 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001306 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001307 DataType::QAsymmU8,
1308 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001309 };
1310
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001311 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001312 "Reference DetectionPostProcess: input 0 is not a supported type.");
1313
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001314 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001315 "Reference DetectionPostProcess: input 1 is not a supported type.");
1316
1317 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001318}
1319
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // Dilation is carried in the shared DepthwiseConvolution2dDescriptor, so
    // the dilated variant delegates to the standard depthwise support check.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1329
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001330bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001331 const TensorInfo& input1,
1332 const TensorInfo& output,
1333 Optional<std::string&> reasonIfUnsupported) const
1334{
Sadik Armagan2999a022019-04-09 14:20:12 +01001335 bool supported = true;
1336
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001337 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001338 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001339 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001340 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001341 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001342 DataType::QSymmS16,
1343 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001344 };
1345
1346 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1347 "Reference division: input 0 is not a supported type.");
1348
1349 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1350 "Reference division: input 1 is not a supported type.");
1351
1352 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1353 "Reference division: output is not a supported type.");
1354
1355 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1356 "Reference division: input 0 and Input 1 types are mismatched");
1357
1358 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1359 "Reference division: input and output types are mismatched");
1360
1361 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1362 "Reference division: shapes are not suitable for implicit broadcast.");
1363
1364 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001365}
1366
josh minor4a3c6102020-01-06 16:40:46 -06001367bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1368 const TensorInfo& output,
1369 const ElementwiseUnaryDescriptor& descriptor,
1370 Optional<std::string&> reasonIfUnsupported) const
1371{
Jan Eilers8eb25602020-03-09 12:13:48 +00001372 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001373
Sadik Armagan303980c2020-04-17 12:45:14 +01001374 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001375 {
1376 DataType::Float32,
1377 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001378 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001379 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001380 DataType::QSymmS16,
1381 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001382 };
1383
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001384 std::array<DataType, 1> logicalSupportedTypes =
1385 {
1386 DataType::Boolean
1387 };
1388
josh minor4a3c6102020-01-06 16:40:46 -06001389 bool supported = true;
1390
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001391 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1392 {
1393 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1394 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001395
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001396 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1397 "Reference elementwise unary: output type not supported");
1398 }
1399 else
1400 {
1401 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1402 "Reference elementwise unary: input type not supported");
1403
1404 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1405 "Reference elementwise unary: output type not supported");
1406 }
josh minor4a3c6102020-01-06 16:40:46 -06001407
1408 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1409 "Reference elementwise unary: input and output types not matching");
1410
1411 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1412 "Reference elementwise unary: input and output shapes"
1413 "have different number of total elements");
1414
1415 return supported;
1416}
1417
arovir011c7c81b2018-10-08 11:34:28 +01001418bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1419 const FakeQuantizationDescriptor& descriptor,
1420 Optional<std::string&> reasonIfUnsupported) const
1421{
Jan Eilers8eb25602020-03-09 12:13:48 +00001422 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001423 bool supported = true;
1424
1425 std::array<DataType,1> supportedTypes =
1426 {
1427 DataType::Float32
1428 };
1429
1430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431 "Reference fake quantization: input type not supported.");
1432
1433 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001434}
1435
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001436bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1437 const TensorInfo& output,
1438 const FillDescriptor& descriptor,
1439 Optional<std::string&> reasonIfUnsupported) const
1440{
1441 IgnoreUnused(descriptor);
1442 IgnoreUnused(output);
1443
1444 bool supported = true;
1445
Sadik Armagana792a052020-06-23 16:22:23 +01001446 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001447 {
1448 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001449 DataType::Float16,
1450 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001451 };
1452
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001453 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001454 "Reference Fill: input type not supported.");
1455
Teresa Charlin44088502020-07-27 11:27:19 +01001456 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1457 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001458 return supported;
1459}
1460
arovir011c7c81b2018-10-08 11:34:28 +01001461bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1462 const TensorInfo& output,
1463 Optional<std::string&> reasonIfUnsupported) const
1464{
Jan Eilers8eb25602020-03-09 12:13:48 +00001465 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001466 bool supported = true;
1467
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001468 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001469 {
James Conroyb40d7102019-06-04 12:32:09 +01001470 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001471 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001472 };
1473
1474 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1475 "Reference Floor: input type not supported.");
1476
1477 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1478 "Reference Floor: output type not supported.");
1479
1480 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001481}
1482
1483bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1484 const TensorInfo& output,
1485 const TensorInfo& weights,
1486 const TensorInfo& biases,
1487 const FullyConnectedDescriptor& descriptor,
1488 Optional<std::string&> reasonIfUnsupported) const
1489{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001490 bool supported = true;
1491
1492 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001493 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001494 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001495 DataType::Float32,
1496 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001497 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001498 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001499 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001500 };
1501
1502 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1503 "Reference Fully Connected: input type not supported.");
1504
1505 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1506 "Reference Fully Connected: output type not supported.");
1507
Francis Murtagh46c09d02019-05-28 08:15:28 +01001508 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1509 "Reference Fully Connected: weights type not supported.");
1510
Ryan OShea31441592022-11-07 16:20:48 +00001511 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1512 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001513
Jan Eilers1f45dc32020-06-15 11:43:03 +01001514 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1515 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001516
Jan Eilers1f45dc32020-06-15 11:43:03 +01001517 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1518 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001519
1520 if (descriptor.m_BiasEnabled)
1521 {
1522 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001523 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001524 supportedBiasTypes =
1525 {
1526 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001527 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001528 DataType::Signed32,
1529 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001530 };
1531
1532 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1533 "Reference Fully Connected: bias type not supported.");
1534
1535 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1536 "Reference Fully Connected: bias and weight types mismatch.");
1537
1538 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1539 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1540
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001541 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1542 "Reference Fully Connected: bias must have 1 dimension.");
1543
Francis Murtagh46c09d02019-05-28 08:15:28 +01001544 }
1545
1546 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001547}
1548
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001549bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1550 const armnn::TensorInfo& input1,
1551 const armnn::TensorInfo& output,
1552 armnn::Optional<std::string&> reasonIfUnsupported) const
1553{
1554 bool supported = true;
1555 std::array<DataType,7> supportedTypes =
1556 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001557 DataType::Float32,
1558 DataType::Float16,
1559 DataType::QAsymmS8,
1560 DataType::QAsymmU8,
1561 DataType::QSymmS16,
1562 DataType::Signed32
1563 };
1564
1565 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1566 "Reference GatherNd: input type not supported");
1567
1568 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1569 "Reference GatherNd: output type not supported");
1570
1571 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1572 "Reference GatherNd: indices (input1) type not supported");
1573
1574 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1575 "Reference GatherNd: input and output types not matching");
1576
1577 return supported;
1578}
1579
narpra014951d842019-01-18 16:53:53 +00001580bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1581 const armnn::TensorInfo& input1,
1582 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001583 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001584 armnn::Optional<std::string&> reasonIfUnsupported) const
1585{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001586 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001587 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001588 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001589 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001590 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001591 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001592 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001593 DataType::QSymmS16,
1594 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001595 };
1596
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001597 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001598 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1599 "Reference Gather: input type not supported");
1600
1601 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1602 "Reference Gather: output type not supported");
1603
1604 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1605 "Reference Gather: indices (input1) type not supported");
1606
1607 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1608 "Reference Gather: input and output types not matching");
1609
1610 return supported;
narpra014951d842019-01-18 16:53:53 +00001611}
1612
// An Input layer performs no computation, so the reference backend accepts it
// unconditionally; both parameters are intentionally unused.
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
1618
Kevin May09ca49c2019-10-09 12:37:34 +01001619bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1620 const TensorInfo& output,
1621 const InstanceNormalizationDescriptor& descriptor,
1622 Optional<std::string&> reasonIfUnsupported) const
1623{
Jan Eilers8eb25602020-03-09 12:13:48 +00001624 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001625 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001626 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001627 {
1628 DataType::Float32,
1629 DataType::Float16
1630 };
1631
1632 bool supported = true;
1633
1634 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1635 "Reference Instance Normalization: input type not supported.");
1636
1637 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1638 "Reference Instance Normalization: output type not supported.");
1639
1640 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1641 "Reference Instance Normalization: input and output types mismatched.");
1642
1643 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1644 "Reference Instance Normalization: input and output shapes have different "
1645 "num total elements.");
1646
1647 return supported;
1648}
1649
arovir011c7c81b2018-10-08 11:34:28 +01001650bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1651 const TensorInfo& output,
1652 const L2NormalizationDescriptor& descriptor,
1653 Optional<std::string&> reasonIfUnsupported) const
1654{
Jan Eilers8eb25602020-03-09 12:13:48 +00001655 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001656 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001657 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001658 {
1659 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001660 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001661 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001662 DataType::QAsymmU8,
1663 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001664 };
1665
1666 bool supported = true;
1667
1668 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1669 "Reference L2normalization: input type not supported.");
1670
1671 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1672 "Reference L2normalization: output type not supported.");
1673
1674 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1675 "Reference L2normalization: input and output types mismatched.");
1676
1677 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1678 "Reference L2normalization: input and output shapes have different "
1679 "num total elements.");
1680
1681 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001682}
1683
James Conroyaba90cd2020-11-06 16:28:18 +00001684bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1685 const TensorInfo& input1,
1686 const TensorInfo& output,
1687 const LogicalBinaryDescriptor& descriptor,
1688 Optional<std::string&> reasonIfUnsupported) const
1689{
1690 IgnoreUnused(descriptor);
1691
1692 std::array<DataType, 1> supportedTypes =
1693 {
1694 DataType::Boolean
1695 };
1696
1697 bool supported = true;
1698 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1699 "Reference LogicalBinary: input 0 type not supported");
1700 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1701 "Reference LogicalBinary: input 1 type not supported");
1702
1703 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1704 "Reference LogicalBinary: input and output types do not match");
1705
1706 return supported;
1707}
1708
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001709bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1710 const TensorInfo& output,
1711 const LogSoftmaxDescriptor& descriptor,
1712 Optional<std::string&> reasonIfUnsupported) const
1713{
Jan Eilers8eb25602020-03-09 12:13:48 +00001714 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001715
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001716 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001717 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001718 DataType::Float32,
1719 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001720 };
1721
1722 bool supported = true;
1723 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1724 "Reference LogSoftmax: input type not supported");
1725
1726 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1727 "Reference LogSoftmax: output type not supported");
1728
1729 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1730 "Reference LogSoftmax: input and output types do not match");
1731
1732 return supported;
1733}
1734
arovir011c7c81b2018-10-08 11:34:28 +01001735bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1736 const TensorInfo& outputStateIn,
1737 const TensorInfo& cellStateIn,
1738 const TensorInfo& scratchBuffer,
1739 const TensorInfo& outputStateOut,
1740 const TensorInfo& cellStateOut,
1741 const TensorInfo& output,
1742 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001743 const LstmInputParamsInfo& paramsInfo,
1744 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001745{
Jan Eilers8eb25602020-03-09 12:13:48 +00001746 IgnoreUnused(descriptor);
1747 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001748
1749 bool supported = true;
1750
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001751 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001752 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001753 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001754 };
1755
Jan Eilersd01a83c2019-07-03 18:20:40 +01001756 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001757 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1758 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001759 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1760 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001761 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1762 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001763 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1764 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001765 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1766 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001767 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1768 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001769
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001770 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1771 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001772 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001773 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001774 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001775 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001776 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001777 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001778 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001779 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001780 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001781 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001782 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001783 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001784 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001785 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001786 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001787 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001788 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001789 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001790 "Reference Lstm: input and OutputGateBias types are mismatched");
1791 if (!descriptor.m_CifgEnabled)
1792 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001793 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001794 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001795 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001796 reasonIfUnsupported,
1797 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001798 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001799 "Reference Lstm: input and InputGateBias types are mismatched");
1800 if (descriptor.m_PeepholeEnabled)
1801 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001802 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001803 reasonIfUnsupported,
1804 "Reference Lstm: input and CellToInputWeights types are mismatched");
1805 }
1806 }
1807 if (descriptor.m_PeepholeEnabled)
1808 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001809 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001810 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001812 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1813 }
1814 if (descriptor.m_ProjectionEnabled)
1815 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001816 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001817 "Reference Lstm: input and mProjectionWeights types are mismatched");
1818 if (paramsInfo.m_ProjectionBias != nullptr)
1819 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001820 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001821 "Reference Lstm: input and ProjectionBias types are mismatched");
1822 }
1823 }
1824 if (descriptor.m_LayerNormEnabled)
1825 {
1826 if (!descriptor.m_CifgEnabled)
1827 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001828 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001829 reasonIfUnsupported,
1830 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1831 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001832 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001833 reasonIfUnsupported,
1834 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001835 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001836 reasonIfUnsupported,
1837 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001838 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001839 reasonIfUnsupported,
1840 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1841 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001842
1843 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001844}
1845
saoste012df12b32018-11-28 16:57:20 +00001846bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1847 const TensorInfo& input1,
1848 const TensorInfo& output,
1849 Optional<std::string&> reasonIfUnsupported) const
1850{
Sadik Armagan2999a022019-04-09 14:20:12 +01001851 bool supported = true;
1852
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001853 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001854 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001855 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001856 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001857 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001858 DataType::QSymmS16,
1859 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001860 };
1861
1862 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1863 "Reference maximum: input 0 is not a supported type.");
1864
1865 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1866 "Reference maximum: input 1 is not a supported type.");
1867
1868 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1869 "Reference maximum: output is not a supported type.");
1870
1871 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1872 "Reference maximum: input 0 and Input 1 types are mismatched");
1873
1874 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1875 "Reference maximum: input and output types are mismatched");
1876
1877 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1878 "Reference maximum: shapes are not suitable for implicit broadcast.");
1879
1880 return supported;
saoste012df12b32018-11-28 16:57:20 +00001881}
1882
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001883bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1884 const TensorInfo& output,
1885 const MeanDescriptor& descriptor,
1886 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001887{
James Conroy4d1ff582019-06-10 17:06:39 +01001888 bool supported = true;
1889 std::string meanLayerStr = "Mean";
1890 std::string outputTensorStr = "output";
1891
Sadik Armagan303980c2020-04-17 12:45:14 +01001892 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001893 {
1894 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001895 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001896 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001897 DataType::QAsymmU8,
1898 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001899 };
1900
1901 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1902 "Reference Mean: input type not supported.");
1903
1904 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1905 "Reference Mean: input and output types are mismatched");
1906
1907 if (descriptor.m_KeepDims)
1908 {
1909 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1910 reasonIfUnsupported,
1911 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1912 output.GetNumDimensions(),
1913 meanLayerStr, outputTensorStr).data());
1914 }
1915 else if (descriptor.m_Axis.empty())
1916 {
1917 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1918 reasonIfUnsupported,
1919 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1920 meanLayerStr, outputTensorStr).data());
1921 }
1922 else
1923 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001924 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001925
1926 if (outputDim > 0)
1927 {
1928 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1929 reasonIfUnsupported,
1930 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1931 meanLayerStr, outputTensorStr).data());
1932 }
1933 else
1934 {
1935 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1936 reasonIfUnsupported,
1937 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1938 meanLayerStr, outputTensorStr).data());
1939 }
1940 }
1941
1942 return supported;
narpra0132b90462018-09-13 11:07:48 +01001943}
1944
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001945bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1946 const TensorInfo &output,
1947 Optional<std::string &> reasonIfUnsupported) const
1948{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001949 bool supported = true;
1950
Sadik Armagan303980c2020-04-17 12:45:14 +01001951 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001952 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001953 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001954 DataType::Float32,
1955 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001956 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001957 DataType::QAsymmU8,
1958 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001959 DataType::Boolean
1960 };
1961
1962 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1963 "Reference MemCopy: input type not supported");
1964
1965 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1966 "Reference MemCopy: output type not supported");
1967
1968 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1969 "Reference MemCopy: input and output types are mismatched");
1970
1971 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001972}
1973
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001974bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1975 const TensorInfo& input1,
1976 const TensorInfo& output,
1977 Optional<std::string&> reasonIfUnsupported) const
1978{
Sadik Armagan2999a022019-04-09 14:20:12 +01001979 bool supported = true;
1980
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001981 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001982 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001983 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001984 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001985 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001986 DataType::QSymmS16,
1987 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001988 };
1989
1990 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1991 "Reference minimum: input 0 is not a supported type.");
1992
1993 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1994 "Reference minimum: input 1 is not a supported type.");
1995
1996 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1997 "Reference minimum: output is not a supported type.");
1998
1999 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2000 "Reference minimum: input 0 and Input 1 types are mismatched");
2001
2002 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2003 "Reference minimum: input and output types are mismatched");
2004
2005 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2006 "Reference minimum: shapes are not suitable for implicit broadcast.");
2007
2008 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002009}
2010
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002011bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2012 const TensorInfo& input1,
2013 const TensorInfo& output,
2014 Optional<std::string&> reasonIfUnsupported) const
2015{
Sadik Armagan2999a022019-04-09 14:20:12 +01002016 bool supported = true;
2017
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002018 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002019 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002020 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002021 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002022 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002023 DataType::QSymmS16,
2024 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002025 };
2026
2027 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2028 "Reference multiplication: input 0 is not a supported type.");
2029
2030 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2031 "Reference multiplication: input 1 is not a supported type.");
2032
2033 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2034 "Reference multiplication: output is not a supported type.");
2035
2036 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2037 "Reference multiplication: input 0 and Input 1 types are mismatched");
2038
2039 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2040 "Reference multiplication: input and output types are mismatched");
2041
2042 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2043 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2044
2045 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002046}
2047
2048bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2049 const TensorInfo& output,
2050 const NormalizationDescriptor& descriptor,
2051 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002052{
Jan Eilers8eb25602020-03-09 12:13:48 +00002053 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002054
2055 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002056 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002057 {
2058 DataType::Float16,
2059 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002060 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002061 DataType::QAsymmU8,
2062 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002063 };
2064
2065 bool supported = true;
2066
2067 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2068 "Reference normalization: input type not supported.");
2069
2070 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2071 "Reference normalization: output type not supported.");
2072
2073 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2074 "Reference normalization: input and output shapes have different "
2075 "num total elements.");
2076
2077 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002078}
2079
Derek Lamberti901ea112019-12-10 22:07:09 +00002080bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2081 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002082{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002083 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002084}
2085
2086bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2087 const TensorInfo& output,
2088 const PadDescriptor& descriptor,
2089 Optional<std::string&> reasonIfUnsupported) const
2090{
Jan Eilers8eb25602020-03-09 12:13:48 +00002091 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002092 bool supported = true;
2093
2094 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002095 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002096 {
2097 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002098 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002099 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002100 DataType::QAsymmU8,
2101 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002102 };
2103
2104 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2105 "Reference pad: input is not a supported type.");
2106
2107 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2108 "Reference pad: output is not a supported type.");
2109
2110 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2111 "Reference pad: input and output types are mismatched.");
2112
2113 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002114}
2115
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002116bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2117 const TensorInfo& output,
2118 const PermuteDescriptor& descriptor,
2119 Optional<std::string&> reasonIfUnsupported) const
2120{
Jan Eilers8eb25602020-03-09 12:13:48 +00002121 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002122 bool supported = true;
2123
2124 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002125 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002126 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002127 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002128 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002129 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002130 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002131 DataType::QAsymmU8,
2132 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002133 };
2134
2135 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2136 "Reference permute: input is not a supported type.");
2137
2138 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2139 "Reference permute: output is not a supported type.");
2140
2141 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2142 "Reference permute: input and output types are mismatched.");
2143
2144 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002145}
2146
2147bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2148 const TensorInfo& output,
2149 const Pooling2dDescriptor& descriptor,
2150 Optional<std::string&> reasonIfUnsupported) const
2151{
Jan Eilers8eb25602020-03-09 12:13:48 +00002152 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002153 bool supported = true;
2154
2155 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002156 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002157 {
2158 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002159 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002160 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002161 DataType::QAsymmU8,
2162 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002163 };
2164
2165 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2166 "Reference poolind2d: input is not a supported type.");
2167
2168 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2169 "Reference poolind2d: output is not a supported type.");
2170
2171 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2172 "Reference poolind2d: input and output types are mismatched.");
2173
2174 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002175}
2176
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002177bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2178 const TensorInfo& output,
2179 const Pooling3dDescriptor& descriptor,
2180 Optional<std::string&> reasonIfUnsupported) const
2181{
2182 IgnoreUnused(descriptor);
2183 bool supported = true;
2184
2185 // Define supported output and inputs types.
2186 std::array<DataType,6> supportedTypes =
2187 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002188 DataType::Float32,
2189 DataType::Float16,
2190 DataType::QAsymmS8,
2191 DataType::QAsymmU8,
2192 DataType::QSymmS16
2193 };
2194
2195 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2196 "Reference poolind3d: input is not a supported type.");
2197
2198 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2199 "Reference poolind3d: output is not a supported type.");
2200
2201 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2202 "Reference poolind3d: input and output types are mismatched.");
2203
2204 return supported;
2205}
2206
2207
James Conroy4f1f8992020-04-29 20:01:10 +01002208bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2209 const TensorInfo& previousOutputIn,
2210 const TensorInfo& previousCellStateIn,
2211 const TensorInfo& outputStateOut,
2212 const TensorInfo& cellStateOut,
2213 const TensorInfo& output,
2214 const QLstmDescriptor& descriptor,
2215 const LstmInputParamsInfo& paramsInfo,
2216 Optional<std::string&> reasonIfUnsupported) const
2217{
2218 IgnoreUnused(input);
2219 IgnoreUnused(previousOutputIn);
2220 IgnoreUnused(previousCellStateIn);
2221 IgnoreUnused(outputStateOut);
2222 IgnoreUnused(cellStateOut);
2223 IgnoreUnused(output);
2224 IgnoreUnused(descriptor);
2225 IgnoreUnused(paramsInfo);
2226
2227 IgnoreUnused(reasonIfUnsupported);
2228
2229 return true;
2230}
2231
Derek Lamberti5f400d62019-03-25 15:41:58 +00002232bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2233 const TensorInfo& output,
2234 Optional<std::string&> reasonIfUnsupported) const
2235{
2236 bool supported = true;
2237
Finn Williamsfd271062019-12-04 14:27:27 +00002238 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002239 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002240 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002241 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002242 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002243 DataType::QAsymmU8,
2244 DataType::QSymmS8,
2245 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002246 };
2247
2248 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2249 "Reference quantize: input type not supported.");
2250
2251 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002252 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002253 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002254 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002255 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002256 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002257 };
2258 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2259 "Reference quantize: output type not supported.");
2260
2261 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2262 "Reference quantize: input and output shapes have different num total elements.");
2263
2264 return supported;
2265}
2266
Finn Williams2605b232020-06-10 15:53:46 +01002267bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2268 const TensorInfo& output,
2269 Optional<std::string&> reasonIfUnsupported) const
2270{
2271 IgnoreUnused(input);
2272 // Define supported output types.
2273 std::array<DataType,1> supportedOutputTypes =
2274 {
2275 DataType::Signed32,
2276 };
2277
2278 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2279 "Reference rank: input type not supported.");
2280}
2281
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002282bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2283 const TensorInfo& output,
2284 const ReduceDescriptor& descriptor,
2285 Optional<std::string&> reasonIfUnsupported) const
2286{
2287 IgnoreUnused(descriptor);
2288 bool supported = true;
2289 std::array<DataType,7> supportedTypes =
2290 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002291 DataType::Float32,
2292 DataType::Float16,
2293 DataType::QAsymmS8,
2294 DataType::QAsymmU8,
2295 DataType::QSymmS16,
2296 DataType::Signed32
2297 };
2298
2299 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2300 "Reference Reduce: input type not supported");
2301
2302 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2303 "Reference Reduce: output type not supported");
2304
2305 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2306 "Reference Reduce: input and output types not matching");
2307
2308 return supported;
2309}
2310
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002311bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002312 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002313 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002314 Optional<std::string&> reasonIfUnsupported) const
2315{
Jan Eilers8eb25602020-03-09 12:13:48 +00002316 IgnoreUnused(output);
2317 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002318 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002319 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002320 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002321 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002322 DataType::Float32,
2323 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002324 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002325 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002326 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002327 DataType::QSymmS16,
2328 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002329 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002330
Nina Drozd2f2778f2019-05-27 10:37:05 +01002331 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2332 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002333}
2334
Teresa Charlin970f43b2019-07-01 13:51:07 +01002335bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2336 const TensorInfo& output,
2337 const ResizeDescriptor& descriptor,
2338 Optional<std::string&> reasonIfUnsupported) const
2339{
Jan Eilers8eb25602020-03-09 12:13:48 +00002340 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002341 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002342 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002343 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002344 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002345 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002346 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002347 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002348 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002349 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002350 };
2351
2352 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2353 "Reference Resize: input type not supported");
2354
2355 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2356 "Reference Resize: output type not supported");
2357
2358 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2359 "Reference Resize: input and output types not matching");
2360
2361 return supported;
2362}
2363
Tianle Cheng988354d2023-06-28 13:20:47 +01002364bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input,
2365 const TensorInfo& output,
2366 const ReverseV2Descriptor& descriptor,
2367 Optional<std::string&> reasonIfUnsupported) const
2368{
2369 IgnoreUnused(descriptor);
2370 bool supported = true;
2371 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
2372 std::array<DataType,6> supportedTypes =
2373 {
2374 DataType::BFloat16,
2375 DataType::Float32,
2376 DataType::Float16,
2377 DataType::QAsymmS8,
2378 DataType::QAsymmU8,
2379 DataType::QSymmS16
2380 };
2381
2382 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2383 "Reference ReverseV2: input type not supported");
2384
2385 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2386 "Reference ReverseV2: output type not supported");
2387
2388 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2389 "Reference ReverseV2: input and output types not matching");
2390
2391 return supported;
2392}
2393
Keith Davis3ae3f972021-05-21 16:33:48 +01002394bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2395 const TensorInfo& output,
2396 Optional<std::string&> reasonIfUnsupported) const
2397{
2398 IgnoreUnused(input);
2399 bool supported = true;
2400
2401 std::array<DataType, 1> supportedTypes =
2402 {
2403 DataType::Signed32
2404 };
2405
2406 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2407 "Reference Shape: output type not supported");
2408
2409 return supported;
2410}
2411
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002412bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2413 const TensorInfo& output,
2414 const SliceDescriptor& descriptor,
2415 Optional<std::string&> reasonIfUnsupported) const
2416{
Jan Eilers8eb25602020-03-09 12:13:48 +00002417 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002418 bool supported = true;
2419
Sadik Armagan303980c2020-04-17 12:45:14 +01002420 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002421 {
2422 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002423 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002424 DataType::QAsymmU8,
2425 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002426 };
2427
2428 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2429 "Reference Slice: input type not supported");
2430
2431 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2432 "Reference Slice: output type not supported");
2433
2434 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2435 "Reference Slice: input and output types are mismatched");
2436
2437 return supported;
2438}
2439
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002440bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2441 const TensorInfo& output,
2442 const SoftmaxDescriptor& descriptor,
2443 Optional<std::string&> reasonIfUnsupported) const
2444{
Jan Eilers8eb25602020-03-09 12:13:48 +00002445 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002446 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002447 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002448 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002449 DataType::Float32,
2450 DataType::Float16,
2451 DataType::QSymmS8,
2452 DataType::QAsymmS8,
2453 DataType::QAsymmU8,
2454 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002455 };
2456
2457 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002458 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002459
2460 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002461 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002462
2463 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002464 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002465
2466 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002467}
2468
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002469bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2470 const TensorInfo& output,
2471 const SpaceToBatchNdDescriptor& descriptor,
2472 Optional<std::string&> reasonIfUnsupported) const
2473{
Jan Eilers8eb25602020-03-09 12:13:48 +00002474 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002475 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002476 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002477 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002478 DataType::Float32,
2479 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002480 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002481 DataType::QAsymmU8,
2482 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002483 };
2484
2485 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2486 "Reference SpaceToBatchNd: input type not supported");
2487
2488 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2489 "Reference SpaceToBatchNd: output type not supported");
2490
2491 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2492 "Reference SpaceToBatchNd: input and output types are mismatched");
2493
2494 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002495}
2496
Keith Davisa57eccb2019-06-14 17:33:22 +01002497bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002498 const TensorInfo& output,
2499 const SpaceToDepthDescriptor& descriptor,
2500 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002501{
2502
Jan Eilers8eb25602020-03-09 12:13:48 +00002503 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002504 bool supported = true;
2505
Sadik Armagan303980c2020-04-17 12:45:14 +01002506 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002507 {
2508 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002509 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002510 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002511 DataType::QAsymmU8,
2512 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002513 };
2514
2515 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2516 "Reference SpaceToDepth: input type not supported");
2517
2518 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2519 "Reference SpaceToDepth: output type not supported");
2520
2521 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2522 "Reference SpaceToDepth: input and output types are mismatched");
2523
2524 return supported;
2525}
2526
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002527bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002528 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2529 const ViewsDescriptor& descriptor,
2530 Optional<std::string&> reasonIfUnsupported) const
2531{
Jan Eilers8eb25602020-03-09 12:13:48 +00002532 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002533 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002534 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002535 {
2536 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002537 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002538 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002539 DataType::QAsymmU8,
2540 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002541 };
2542
2543 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2544 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002545 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002546 {
2547 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2548 "Reference splitter: input type not supported");
2549
2550 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2551 "Reference splitter: input and output types mismatched.");
2552 }
2553
2554 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002555}
2556
Matthew Jackson81e601c2019-07-11 12:07:09 +01002557bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2558 const TensorInfo& output,
2559 const StackDescriptor& descriptor,
2560 Optional<std::string&> reasonIfUnsupported) const
2561{
Jan Eilers8eb25602020-03-09 12:13:48 +00002562 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002563
2564 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002565 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002566 {
2567 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002568 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002569 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002570 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002571 DataType::QSymmS16,
2572 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002573 };
2574
2575 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2576 "Reference stack: output type not supported");
2577 for (const TensorInfo* input : inputs)
2578 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002579 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002580 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2581 "Reference stack: input type not supported");
2582
2583 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2584 "Reference stack: input and output types mismatched.");
2585 }
2586
2587 return supported;
2588}
2589
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002590bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2591 const TensorInfo& output,
2592 const StridedSliceDescriptor& descriptor,
2593 Optional<std::string&> reasonIfUnsupported) const
2594{
Jan Eilers8eb25602020-03-09 12:13:48 +00002595 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002596 bool supported = true;
2597
Sadik Armagan303980c2020-04-17 12:45:14 +01002598 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002599 {
2600 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002601 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002602 DataType::QAsymmU8,
2603 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002604 };
2605
2606 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2607 "Reference StridedSlice: input type not supported");
2608
2609 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2610 "Reference StridedSlice: output type not supported");
2611
2612 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2613 "Reference StridedSlice: input and output types are mismatched");
2614
2615 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002616}
2617
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002618bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2619 const TensorInfo& input1,
2620 const TensorInfo& output,
2621 Optional<std::string&> reasonIfUnsupported) const
2622{
Sadik Armagan2999a022019-04-09 14:20:12 +01002623 bool supported = true;
2624
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002625 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002626 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002627 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002628 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002629 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002630 DataType::QSymmS16,
2631 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002632 };
2633
2634 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2635 "Reference subtraction: input 0 is not a supported type.");
2636
2637 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2638 "Reference subtraction: input 1 is not a supported type.");
2639
2640 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2641 "Reference subtraction: output is not a supported type.");
2642
2643 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2644 "Reference subtraction: input 0 and Input 1 types are mismatched");
2645
2646 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2647 "Reference subtraction: input and output types are mismatched");
2648
2649 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2650 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2651
2652 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002653}
2654
Matteo Martincighab9e5252019-06-13 17:27:46 +01002655bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2656 const TensorInfo& alpha,
2657 const TensorInfo& output,
2658 Optional<std::string&> reasonIfUnsupported) const
2659{
2660 bool supported = true;
2661
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002662 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002663 {
2664 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002665 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002666 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002667 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002668 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002669 };
2670
2671 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2672 "PReLU: input is not a supported type.");
2673
2674 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2675 "PReLU: alpha is not a supported type.");
2676
2677 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2678 "PReLU: output is not a supported type.");
2679
2680 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2681 "PReLU: input, alpha and output types are mismatched");
2682
2683 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2684 "PReLU: shapes are not suitable for implicit broadcast");
2685
2686 return supported;
2687}
2688
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002689bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2690 const TensorInfo& output,
2691 const TransposeConvolution2dDescriptor& descriptor,
2692 const TensorInfo& weights,
2693 const Optional<TensorInfo>& biases,
2694 Optional<std::string&> reasonIfUnsupported) const
2695{
Jan Eilers8eb25602020-03-09 12:13:48 +00002696 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002697 bool supported = true;
2698
Sadik Armagan303980c2020-04-17 12:45:14 +01002699 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002700 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002701 DataType::Float32,
2702 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002703 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002704 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002705 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002706 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002707 };
2708
2709 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2710 "Reference TransposeConvolution2d: input is not a supported type.");
2711
2712 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2713 "Reference TransposeConvolution2d: output is not a supported type.");
2714
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002715 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2716 "Reference TransposeConvolution2d: input and output types mismatched.");
2717
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002718
2719 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002720 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002721 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002722 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002723 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002724 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002725 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002726 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002727 };
2728
2729 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2730 "Reference TransposeConvolution2d: weights type not supported for "
2731 "quantized input.");
2732 }
2733 else
2734 {
2735 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2736 "Reference TransposeConvolution2d: weights is not a supported type.");
2737
2738 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2739 "Reference TransposeConvolution2d: input and weights types mismatched.");
2740 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002741
2742 if (biases.has_value())
2743 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002744 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002745 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002746 DataType::Float32,
2747 DataType::Float16,
2748 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002749 };
2750 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2751 "Reference TransposeConvolution2d: biases is not a supported type.");
2752 }
2753
2754 return supported;
2755}
2756
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002757bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2758 const TensorInfo& output,
2759 const TransposeDescriptor& descriptor,
2760 Optional<std::string&> reasonIfUnsupported) const
2761{
Jan Eilers8eb25602020-03-09 12:13:48 +00002762 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002763 bool supported = true;
2764
2765 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002766 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002767 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002768 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002769 DataType::Float32,
2770 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002771 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002772 DataType::QAsymmU8,
2773 DataType::QSymmS16
2774 };
2775
2776 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2777 "Reference transpose: input is not a supported type.");
2778
2779 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2780 "Reference transpose: output is not a supported type.");
2781
2782 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2783 "Reference transpose: input and output types are mismatched.");
2784
2785 return supported;
2786}
2787
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002788bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2789 const TensorInfo& input,
2790 const TensorInfo& outputStateIn,
2791 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002792 const TensorInfo& outputStateOut,
2793 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002794 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002795 const UnidirectionalSequenceLstmDescriptor& descriptor,
2796 const LstmInputParamsInfo& paramsInfo,
2797 Optional<std::string&> reasonIfUnsupported) const
2798{
2799 IgnoreUnused(descriptor);
2800 IgnoreUnused(paramsInfo);
2801 IgnoreUnused(outputStateIn);
2802 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002803 IgnoreUnused(outputStateOut);
2804 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002805 bool supported = true;
2806
Mike Kelly12994962022-04-21 11:57:09 +01002807 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002808 {
Mike Kelly12994962022-04-21 11:57:09 +01002809 DataType::Float32,
2810 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002811 };
2812
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002813 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002814 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002815 DataType::Float32,
2816 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002817 };
2818
Mike Kelly12994962022-04-21 11:57:09 +01002819 std::array<DataType, 3> supportedBiasTypes =
2820 {
2821 DataType::Float32,
2822 DataType::QAsymmS8,
2823 DataType::Signed32
2824 };
2825
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002826 // check inputs and outputs
2827 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2828 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002829 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2830 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002831
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002832 // check layer parameters
2833 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2834 reasonIfUnsupported,
2835 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2836 "is not a supported type.");
2837 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2838 reasonIfUnsupported,
2839 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2840 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2841 reasonIfUnsupported,
2842 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2843 "is not a supported type.");
2844 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2845 reasonIfUnsupported,
2846 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2847 "is not a supported type.");
2848 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2849 reasonIfUnsupported,
2850 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2851 "is not a supported type.");
2852 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2853 reasonIfUnsupported,
2854 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2855 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002856
2857 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2858 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2859 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2860 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2861 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2862 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002863 if (!descriptor.m_CifgEnabled)
2864 {
2865 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2866 reasonIfUnsupported,
2867 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2868 "is not a supported type.");
2869 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2870 reasonIfUnsupported,
2871 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2872 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002873 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2874 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002875 if (descriptor.m_PeepholeEnabled)
2876 {
2877 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2878 reasonIfUnsupported,
2879 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2880 "is not a supported type.");
2881 }
2882 }
2883 if (descriptor.m_PeepholeEnabled)
2884 {
2885 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2886 reasonIfUnsupported,
2887 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2888 "is not a supported type.");
2889 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2890 reasonIfUnsupported,
2891 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2892 "is not a supported type.");
2893 }
2894 if (descriptor.m_ProjectionEnabled)
2895 {
2896 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2897 reasonIfUnsupported,
2898 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2899 "is not a supported type.");
2900 if (paramsInfo.m_ProjectionBias != nullptr)
2901 {
2902 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2903 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2904 "are mismatched");
2905 }
2906 }
2907 if (descriptor.m_LayerNormEnabled)
2908 {
2909 if (!descriptor.m_CifgEnabled)
2910 {
2911 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2912 reasonIfUnsupported,
2913 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2914 "is not a supported type.");
2915 }
2916 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2917 reasonIfUnsupported,
2918 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2919 "is not a supported type.");
2920 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2921 reasonIfUnsupported,
2922 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2923 "is not a supported type.");
2924 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2925 reasonIfUnsupported,
2926 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2927 "is not a supported type.");
2928 }
2929
2930 return supported;
2931}
2932
arovir011c7c81b2018-10-08 11:34:28 +01002933} // namespace armnn