blob: a5015a737678d317377fb41e08aa9b5ffab9487d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Nikhil Raj369d8fc2022-11-24 13:12:36 +00002// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
/// Builds the error message reported when a tensor has the wrong number of
/// dimensions for a reference workload.
///
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name, e.g. "batchToSpaceNd".
/// @param tensorName Name of the offending tensor, e.g. "input".
/// @return "Reference <layer>: Expected <E> dimensions but got <A> dimensions
///         instead, for the '<tensor>' tensor."
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // The strings are only read, so take them by const reference; this also
    // allows callers to pass temporaries or const strings.
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
103 case LayerType::Comparison:
104 return IsComparisonSupported(infos[0],
105 infos[1],
106 infos[2],
107 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108 reasonIfUnsupported);
109 case LayerType::Concat:
110 {
111 std::vector<const TensorInfo*> inputInfos;
112 for (uint32_t i = 0; i < (infos.size() - 1); i++)
113 {
114 inputInfos.push_back(&infos[i]);
115 }
116 return IsConcatSupported(inputInfos,
117 infos[infos.size() - 1],
118 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119 reasonIfUnsupported);
120 }
121 case LayerType::Constant:
122 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000123 case LayerType::ConvertFp16ToFp32:
124 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000125 case LayerType::ConvertFp32ToFp16:
126 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127 case LayerType::Convolution2d:
128 {
129 if (infos.size() != 4)
130 {
131 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132 "TensorInfos should be of format: {input, output, weights, biases}.");
133 }
134
135 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136 if (infos[3] == TensorInfo())
137 {
138 return IsConvolution2dSupported(infos[0],
139 infos[1],
140 desc,
141 infos[2],
142 EmptyOptional(),
143 reasonIfUnsupported);
144 }
145 else
146 {
147 return IsConvolution2dSupported(infos[0],
148 infos[1],
149 desc,
150 infos[2],
151 infos[3],
152 reasonIfUnsupported);
153 }
154 }
155 case LayerType::DepthToSpace:
156 return IsDepthToSpaceSupported(infos[0],
157 infos[1],
158 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159 reasonIfUnsupported);
160 case LayerType::DepthwiseConvolution2d:
161 {
162 if (infos.size() != 4)
163 {
164 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165 "TensorInfos should be of format: {input, output, weights, biases}.");
166 }
167
168 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169 if (infos[3] == TensorInfo())
170 {
171 return IsDepthwiseConvolutionSupported(infos[0],
172 infos[1],
173 desc,
174 infos[2],
175 EmptyOptional(),
176 reasonIfUnsupported);
177 }
178 else
179 {
180 return IsDepthwiseConvolutionSupported(infos[0],
181 infos[1],
182 desc,
183 infos[2],
184 infos[3],
185 reasonIfUnsupported);
186 }
187 }
188 case LayerType::Dequantize:
189 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190 case LayerType::Division:
191 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
192 case LayerType::ElementwiseUnary:
193 return IsElementwiseUnarySupported(infos[0],
194 infos[1],
195 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
196 reasonIfUnsupported);
197 case LayerType::Fill:
198 return IsFillSupported(infos[0],
199 infos[1],
200 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
201 reasonIfUnsupported);
202 case LayerType::Floor:
203 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
204 case LayerType::FullyConnected:
205 return IsFullyConnectedSupported(infos[0],
206 infos[1],
207 infos[2],
208 infos[3],
209 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
210 reasonIfUnsupported);
211 case LayerType::Gather:
212 return IsGatherSupported(infos[0],
213 infos[1],
214 infos[2],
215 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
216 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100217 case LayerType::GatherNd:
218 return IsGatherNdSupported(infos[0],
219 infos[1],
220 infos[2],
221 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000222 case LayerType::Input:
223 return IsInputSupported(infos[0], reasonIfUnsupported);
224 case LayerType::InstanceNormalization:
225 return IsInstanceNormalizationSupported(infos[0],
226 infos[1],
227 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
228 (&descriptor)),
229 reasonIfUnsupported);
230 case LayerType::L2Normalization:
231 return IsL2NormalizationSupported(infos[0],
232 infos[1],
233 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
234 reasonIfUnsupported);
235 case LayerType::LogicalBinary:
236 return IsLogicalBinarySupported(infos[0],
237 infos[1],
238 infos[2],
239 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
240 reasonIfUnsupported);
241 case LayerType::LogSoftmax:
242 return IsLogSoftmaxSupported(infos[0],
243 infos[1],
244 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Lstm:
247 return IsLstmSupported(infos[0],
248 infos[1],
249 infos[2],
250 infos[3],
251 infos[4],
252 infos[5],
253 infos[6],
254 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
255 lstmParamsInfo.value(),
256 reasonIfUnsupported);
257 case LayerType::QLstm:
258 return IsQLstmSupported(infos[0],
259 infos[1],
260 infos[2],
261 infos[3],
262 infos[4],
263 infos[5],
264 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
265 lstmParamsInfo.value(),
266 reasonIfUnsupported);
267 case LayerType::Maximum:
268 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
269 case LayerType::Mean:
270 return IsMeanSupported(infos[0],
271 infos[1],
272 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
273 reasonIfUnsupported);
274 case LayerType::Minimum:
275 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
276 case LayerType::Multiplication:
277 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
278 case LayerType::Normalization:
279 return IsNormalizationSupported(infos[0],
280 infos[1],
281 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
282 reasonIfUnsupported);
283 case LayerType::Output:
284 return IsOutputSupported(infos[0], reasonIfUnsupported);
285 case LayerType::Pad:
286 return IsPadSupported(infos[0],
287 infos[1],
288 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
289 reasonIfUnsupported);
290 case LayerType::Permute:
291 return IsPermuteSupported(infos[0],
292 infos[1],
293 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
294 reasonIfUnsupported);
295 case LayerType::Pooling2d:
296 return IsPooling2dSupported(infos[0],
297 infos[1],
298 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
299 reasonIfUnsupported);
300 case LayerType::Prelu:
301 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
302 case LayerType::Quantize:
303 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
304 case LayerType::Reshape:
305 return IsReshapeSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Resize:
310 return IsResizeSupported(infos[0],
311 infos[1],
312 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
313 reasonIfUnsupported);
314 case LayerType::Reduce:
315 return IsReduceSupported(infos[0],
316 infos[1],
317 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
318 reasonIfUnsupported);
319 case LayerType::Slice:
320 return IsSliceSupported(infos[0],
321 infos[1],
322 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
323 reasonIfUnsupported);
324 case LayerType::Softmax:
325 return IsSoftmaxSupported(infos[0],
326 infos[1],
327 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
328 reasonIfUnsupported);
329 case LayerType::SpaceToBatchNd:
330 return IsSpaceToBatchNdSupported(infos[0],
331 infos[1],
332 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
333 reasonIfUnsupported);
334 case LayerType::SpaceToDepth:
335 return IsSpaceToDepthSupported(infos[0],
336 infos[1],
337 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
338 reasonIfUnsupported);
339 case LayerType::Splitter:
340 {
341 std::vector<TensorInfo> outputInfos;
342 for (uint32_t i = 1; i < infos.size(); i++)
343 {
344 outputInfos.push_back(infos[i]);
345 }
346 return IsSplitterSupported(infos[0],
347 {outputInfos.begin(), outputInfos.end()},
348 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
349 reasonIfUnsupported);
350 }
351 case LayerType::Stack:
352 {
353 std::vector<const TensorInfo*> inputInfos;
354 for (uint32_t i = 0; i < infos.size() - 1; i++)
355 {
356 inputInfos.push_back(&infos[i]);
357 }
358 return IsStackSupported(inputInfos,
359 infos[infos.size() - 1],
360 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
361 reasonIfUnsupported);
362 }
363 case LayerType::StridedSlice:
364 return IsStridedSliceSupported(infos[0],
365 infos[1],
366 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
367 reasonIfUnsupported);
368 case LayerType::Subtraction:
369 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
370 case LayerType::Transpose:
371 return IsTransposeSupported(infos[0],
372 infos[1],
373 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
374 reasonIfUnsupported);
375 case LayerType::TransposeConvolution2d:
376 {
377 if (infos.size() != 4)
378 {
379 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
380 "TensorInfos should be of format: {input, output, weights, biases}.");
381 }
382
383 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
384 if (infos[3] == TensorInfo())
385 {
386 return IsTransposeConvolution2dSupported(infos[0],
387 infos[1],
388 desc,
389 infos[2],
390 EmptyOptional(),
391 reasonIfUnsupported);
392 }
393 else
394 {
395 return IsTransposeConvolution2dSupported(infos[0],
396 infos[1],
397 desc,
398 infos[2],
399 infos[3],
400 reasonIfUnsupported);
401 }
402 }
403 case LayerType::Cast:
404 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
405 case LayerType::ChannelShuffle:
406 return IsChannelShuffleSupported(infos[0],
407 infos[1],
408 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
409 reasonIfUnsupported);
410 case LayerType::Convolution3d:
411 {
412 if (infos.size() != 4)
413 {
414 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
415 "TensorInfos should be of format: {input, output, weights, biases}.");
416 }
417
418 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
419 if (infos[3] == TensorInfo())
420 {
421 return IsConvolution3dSupported(infos[0],
422 infos[1],
423 desc,
424 infos[2],
425 EmptyOptional(),
426 reasonIfUnsupported);
427 }
428 else
429 {
430 return IsConvolution3dSupported(infos[0],
431 infos[1],
432 desc,
433 infos[2],
434 infos[3],
435 reasonIfUnsupported);
436 }
437 }
438 case LayerType::Debug:
439 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
440 case LayerType::DetectionPostProcess:
441 return IsDetectionPostProcessSupported(infos[0],
442 infos[1],
443 infos[2],
444 infos[3],
445 infos[4],
446 infos[5],
447 infos[6],
448 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
449 (&descriptor)),
450 reasonIfUnsupported);
451 case LayerType::FakeQuantization:
452 return IsFakeQuantizationSupported(infos[0],
453 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
454 reasonIfUnsupported);
455 case LayerType::MemCopy:
456 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
457 case LayerType::Rank:
458 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
459 case LayerType::Shape:
460 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
461 case LayerType::UnidirectionalSequenceLstm:
462 {
463 if (infos.size() != 6)
464 {
465 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
466 "should be of format: {input, outputStateIn, cellStateIn, "
467 "hiddenStateOutputVal, cellStateOutputVal, output}");
468 }
469 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100470 return IsUnidirectionalSequenceLstmSupported(infos[0],
471 infos[1],
472 infos[2],
473 infos[3],
474 infos[4],
475 infos[5],
476 desc,
477 lstmParamsInfo.value(),
478 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000479 }
480 case LayerType::Pooling3d:
481 return IsPooling3dSupported(infos[0],
482 infos[1],
483 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
484 reasonIfUnsupported);
485 case LayerType::Map:
486 return true;
487 case LayerType::Unmap:
488 return true;
489 case LayerType::MemImport:
490 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
491 case LayerType::Merge:
492 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
493 case LayerType::QuantizedLstm:
494 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
495 infos[1],
496 infos[2],
497 infos[3],
498 infos[4],
499 quantizedLstmInputParamsInfo.value(),
500 reasonIfUnsupported);
501 default:
502 // layers not supported in neon by default:
503 // precompiled, standin, switch
504 return false;
505 }
506}
507
arovir011c7c81b2018-10-08 11:34:28 +0100508bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
509 const TensorInfo& output,
510 const ActivationDescriptor& descriptor,
511 Optional<std::string&> reasonIfUnsupported) const
512{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000513 bool supported = true;
514
515 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000516 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000517 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100518 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000519 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000520 DataType::QAsymmU8,
521 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000522 };
523
524 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
525 "Reference activation: input type not supported.");
526
527 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
528 "Reference activation: output type not supported.");
529
530 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
531 "Reference activation: input and output types mismatched.");
532
533 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
534 "Reference activation: input and output shapes are of different rank.");
535
536
537 struct ActivationFunctionSupported : public Rule
538 {
539 ActivationFunctionSupported(const ActivationDescriptor& desc)
540 {
541 switch(desc.m_Function)
542 {
543 case ActivationFunction::Abs:
544 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000545 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000546 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000547 case ActivationFunction::LeakyReLu:
548 case ActivationFunction::Linear:
549 case ActivationFunction::ReLu:
550 case ActivationFunction::Sigmoid:
551 case ActivationFunction::SoftReLu:
552 case ActivationFunction::Sqrt:
553 case ActivationFunction::Square:
554 case ActivationFunction::TanH:
555 {
556 m_Res = true;
557 break;
558 }
559 default:
560 {
561 m_Res = false;
562 break;
563 }
564 }
565 }
566 };
567
568 // Function is supported
569 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
570 "Reference activation: function not supported.");
571
572 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100573}
574
575bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
576 const TensorInfo& input1,
577 const TensorInfo& output,
578 Optional<std::string&> reasonIfUnsupported) const
579{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000580 bool supported = true;
581
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100582 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000583 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100584 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000585 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000586 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100587 DataType::QSymmS16,
588 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000589 };
590
591 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
592 "Reference addition: input 0 is not a supported type.");
593
594 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
595 "Reference addition: input 1 is not a supported type.");
596
597 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
598 "Reference addition: output is not a supported type.");
599
600 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
601 "Reference addition: input 0 and Input 1 types are mismatched");
602
603 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
604 "Reference addition: input and output types are mismatched");
605
606 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
607 "Reference addition: shapes are not suitable for implicit broadcast.");
608
609 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100610}
611
Nikhil Raj68c2c902019-09-19 11:21:11 +0100612bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
613 const armnn::ArgMinMaxDescriptor &descriptor,
614 armnn::Optional<std::string &> reasonIfUnsupported) const
615{
Jan Eilers8eb25602020-03-09 12:13:48 +0000616 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100617
Mike Kelly1f140f72021-04-06 12:25:55 +0100618 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100619 {
Teresa Charline300b362020-05-25 10:01:03 +0100620 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100621 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000623 DataType::QAsymmU8,
624 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100625 DataType::Signed32,
626 DataType::Signed64
627 };
628
629 std::array<DataType,2> supportedOutputTypes = {
630 DataType::Signed32,
631 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100632 };
633
634 bool supported = true;
635
Mike Kelly1f140f72021-04-06 12:25:55 +0100636 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100637 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100638 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100639 "Reference ArgMinMax: output type not supported");
640
641 return supported;
642}
643
Samuel Yap6b478092022-07-06 15:36:03 +0100644bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
645 const TensorInfo& inputY,
646 const TensorInfo& output,
647 const BatchMatMulDescriptor& descriptor,
648 Optional<std::string &> reasonIfUnsupported) const
649{
650 IgnoreUnused(descriptor);
651
652 std::array<DataType, 6> supportedTypes =
653 {
Samuel Yap6b478092022-07-06 15:36:03 +0100654 DataType::Float16,
655 DataType::Float32,
656 DataType::QAsymmS8,
657 DataType::QAsymmU8,
658 DataType::QSymmS16
659 };
660
661 bool supported = true;
662
663 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
664 "Reference batch matrix multiplication: input X is not a supported type");
665
666 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
667 "Reference batch matrix multiplication: input Y is not a supported type");
668
669 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
670 "Reference batch matrix multiplication: output is not a supported type");
671
672 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
673 "Reference batch matrix multiplication: input X and input Y types are mismatched");
674
675 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
676 "Reference batch matrix multiplication: inputs and output types are mismatched");
677
678 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
679 reasonIfUnsupported,
680 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
681
682 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
683 reasonIfUnsupported,
684 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
685
686 return supported;
687}
688
arovir011c7c81b2018-10-08 11:34:28 +0100689bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
690 const TensorInfo& output,
691 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100692 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100693 const TensorInfo& beta,
694 const TensorInfo& gamma,
695 const BatchNormalizationDescriptor& descriptor,
696 Optional<std::string&> reasonIfUnsupported) const
697{
Jan Eilers8eb25602020-03-09 12:13:48 +0000698 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100699
Sadik Armagan303980c2020-04-17 12:45:14 +0100700 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100701 {
702 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100703 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100704 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000705 DataType::QAsymmU8,
706 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100707 };
708
709 bool supported = true;
710
711 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
712 "Reference batch normalization: input is not a supported type.");
713
714 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
715 "Reference batch normalization: output is not a supported type.");
716
717 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
718 "Reference batch normalization: input and output types are mismatched");
719
720 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
721 "Reference batch normalization: mean is not a supported type.");
722
723 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
724 "Reference batch normalization: variance is not a supported type.");
725
726 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
727 "Reference batch normalization: beta is not a supported type.");
728
729 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
730 "Reference batch normalization: gamma is not a supported type.");
731
732 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100733}
734
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000735bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
736 const TensorInfo& output,
737 const BatchToSpaceNdDescriptor& descriptor,
738 Optional<std::string&> reasonIfUnsupported) const
739{
Jan Eilers8eb25602020-03-09 12:13:48 +0000740 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100741
742 bool supported = true;
743
744 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
745 std::string inputTensorStr = "input";
746 std::string outputTensorStr = "output";
747
748 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100749 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100750 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000751 DataType::Float32,
752 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100753 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000754 DataType::QAsymmU8,
755 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100756 };
757
758 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
759 "Reference BatchToSpaceNd: input type not supported.");
760
761 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
762 "Reference BatchToSpaceNd: output type not supported.");
763
764 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
765 "Reference BatchToSpaceNd: input and output types mismatched.");
766
767 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
768 reasonIfUnsupported,
769 CreateIncorrectDimensionsErrorMsg(4,
770 output.GetNumDimensions(),
771 batchToSpaceNdLayerStr,
772 outputTensorStr).data());
773
774 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
775 reasonIfUnsupported,
776 CreateIncorrectDimensionsErrorMsg(4,
777 input.GetNumDimensions(),
778 batchToSpaceNdLayerStr,
779 inputTensorStr).data());
780
781 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000782}
783
mathad01b392e982021-04-07 12:07:30 +0100784bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
785 const TensorInfo& output,
786 Optional<std::string&> reasonIfUnsupported) const
787{
788 std::array<DataType, 9> supportedInputTypes =
789 {
mathad01b392e982021-04-07 12:07:30 +0100790 DataType::Float32,
791 DataType::Float16,
792 DataType::QSymmS8,
793 DataType::QAsymmS8,
794 DataType::QAsymmU8,
795 DataType::QSymmS16,
796 DataType::Signed32
797 };
798
799 bool supported = true;
800 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
801 "Reference cast: input is not a supported type");
802
803
804 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
805 "Reference cast: output is not a supported type");
806
807 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
808 "Reference cast: input and output shapes have different number of total elements");
809
810 return supported;
811}
812
Simon Obute51f67772021-09-03 15:50:13 +0100813bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
814 const TensorInfo& output,
815 const ChannelShuffleDescriptor& descriptor,
816 Optional<std::string&> reasonIfUnsupported) const
817{
818 IgnoreUnused(descriptor);
819 bool supported = true;
820
821 // Define supported output and inputs types.
822 std::array<DataType, 7> supportedTypes =
823 {
Simon Obute51f67772021-09-03 15:50:13 +0100824 DataType::Float32,
825 DataType::Float16,
826 DataType::QAsymmS8,
827 DataType::QAsymmU8,
828 DataType::QSymmS8,
829 DataType::QSymmS16
830 };
831
832 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
833 "Reference ChannelShuffle: input is not a supported type.");
834
835 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
836 "Reference ChannelShuffle: output is not a supported type.");
837
838 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
839 "Reference ChannelShuffle: input and output types are mismatched.");
840
841 return supported;
842}
843
844
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100845bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
846 const TensorInfo& input1,
847 const TensorInfo& output,
848 const ComparisonDescriptor& descriptor,
849 Optional<std::string&> reasonIfUnsupported) const
850{
Jan Eilers8eb25602020-03-09 12:13:48 +0000851 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100852 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100853 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000854 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100855 DataType::Float32,
856 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100857 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000858 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000859 DataType::QSymmS16,
860 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100861 };
862
863 bool supported = true;
864 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
865 "Reference comparison: input 0 is not a supported type");
866
867 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
868 "Reference comparison: input 0 and Input 1 types are mismatched");
869
870 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
871 "Reference comparison: output is not of type Boolean");
872
873 return supported;
874}
875
Jim Flynn906f9462019-05-10 13:55:21 +0100876bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
877 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000878 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100879 Optional<std::string&> reasonIfUnsupported) const
880{
Jan Eilers8eb25602020-03-09 12:13:48 +0000881 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100882
883 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000884 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100885 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000886 DataType::Float32,
887 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000888 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100889 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000890 DataType::QSymmS16,
891 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100892 };
893
894 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
895 "Reference concatenation: output type not supported");
896 for (const TensorInfo* input : inputs)
897 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100898 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100899 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
900 "Reference concatenation: input type not supported");
901
902 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
903 "Reference concatenation: input and output types mismatched.");
904 }
905
906 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100907}
908
arovir011c7c81b2018-10-08 11:34:28 +0100909bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
910 Optional<std::string&> reasonIfUnsupported) const
911{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100912 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100913 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100914 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100915 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000916 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100917 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000918 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100919 DataType::QSymmS16,
920 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100921 };
922
923 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
924 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100925}
926
927bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
928 const TensorInfo& output,
929 Optional<std::string&> reasonIfUnsupported) const
930{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100931 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
932 input.GetDataType(),
933 &TrueFunc<>,
934 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000935 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000936 &FalseFuncI32<>,
937 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100938 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
939 output.GetDataType(),
940 &FalseOutputFuncF16<>,
941 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000942 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000943 &FalseFuncI32<>,
944 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100945}
946
947bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
948 const TensorInfo& output,
949 Optional<std::string&> reasonIfUnsupported) const
950{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100951 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
952 input.GetDataType(),
953 &FalseInputFuncF16<>,
954 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000955 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000956 &FalseFuncI32<>,
957 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100958 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
959 output.GetDataType(),
960 &TrueFunc<>,
961 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000962 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000963 &FalseFuncI32<>,
964 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100965}
966
967bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
968 const TensorInfo& output,
969 const Convolution2dDescriptor& descriptor,
970 const TensorInfo& weights,
971 const Optional<TensorInfo>& biases,
972 Optional<std::string&> reasonIfUnsupported) const
973{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100974 bool supported = true;
975
976 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000977 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000978 {
979 DataType::Float32,
980 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000981 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100982 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000983 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000984 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100985 };
986
987 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000988 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100989
990 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000991 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100992
Ryan OShea31441592022-11-07 16:20:48 +0000993 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
994 "Reference Convolution2d: input and output types mismatched.");
995
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100996
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000997 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000998 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000999 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001000 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001001 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001002 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001003 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001004 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001005 };
1006
1007 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001008 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001009 }
1010 else
1011 {
1012 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001013 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001014
1015 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001016 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001017 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001018
1019 if (biases.has_value())
1020 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001021 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001022 {
1023 DataType::Float32,
1024 DataType::Float16,
1025 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001026 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001027
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001028 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001029 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001030 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001031 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001032
1033 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001034}
1035
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001036bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1037 const TensorInfo& output,
1038 const Convolution3dDescriptor& descriptor,
1039 const TensorInfo& weights,
1040 const Optional<TensorInfo>& biases,
1041 Optional<std::string&> reasonIfUnsupported) const
1042{
1043 bool supported = true;
1044
1045 // Define supported types.
1046 std::array<DataType,7> supportedTypes =
1047 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001048 DataType::Float32,
1049 DataType::Float16,
1050 DataType::QAsymmS8,
1051 DataType::QAsymmU8,
1052 DataType::QSymmS8,
1053 DataType::QSymmS16
1054 };
1055
1056 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1057 "Reference Convolution3d: input is not a supported type.");
1058
1059 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1060 "Reference Convolution3d: output is not a supported type.");
1061
1062 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1063 "Reference Convolution3d: input and output types mismatched.");
1064
1065 const DataType inputType = input.GetDataType();
1066 if (IsQuantized8BitType(inputType))
1067 {
1068 std::array<DataType, 3> supportedWeightTypes =
1069 {
1070 DataType::QAsymmS8,
1071 DataType::QAsymmU8,
1072 DataType::QSymmS8
1073 };
1074
1075 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1076 "Reference Convolution3d: weights type not supported for quantized input.");
1077 }
1078 else
1079 {
1080 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1081 "Reference Convolution3d: weights is not a supported type.");
1082
1083 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1084 "Reference Convolution3d: input and weights types mismatched.");
1085 }
1086
1087 if (biases.has_value())
1088 {
1089 std::array<DataType,4> biasesSupportedTypes =
1090 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001091 DataType::Float32,
1092 DataType::Float16,
1093 DataType::Signed32
1094 };
1095
1096 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1097 "Reference Convolution3d: biases is not a supported type.");
1098 }
1099 IgnoreUnused(descriptor);
1100
1101 return supported;
1102}
1103
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001104bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1105 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001106 Optional<std::string&> reasonIfUnsupported) const
1107{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001108 bool supported = true;
1109
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001110 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001111 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001112 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001113 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001114 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001115 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001116 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001117 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001118 DataType::QSymmS16,
1119 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001120 };
1121
1122 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001123 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001124
1125 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001126 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001127
1128 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001129 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001130
1131 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001132}
1133
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001134bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1135 const TensorInfo& output,
1136 const DepthToSpaceDescriptor& descriptor,
1137 Optional<std::string&> reasonIfUnsupported) const
1138{
Jan Eilers8eb25602020-03-09 12:13:48 +00001139 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001140 bool supported = true;
1141
Sadik Armagan303980c2020-04-17 12:45:14 +01001142 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001143 {
1144 DataType::Float32,
1145 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001146 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001147 DataType::QAsymmU8,
1148 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001149 };
1150
1151 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1152 "Reference DepthToSpace: input type not supported");
1153
1154 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1155 "Reference DepthToSpace: output type not supported");
1156
1157 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1158 "Reference DepthToSpace: input and output types are mismatched");
1159
1160 return supported;
1161}
1162
arovir011c7c81b2018-10-08 11:34:28 +01001163bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1164 const TensorInfo& output,
1165 const DepthwiseConvolution2dDescriptor& descriptor,
1166 const TensorInfo& weights,
1167 const Optional<TensorInfo>& biases,
1168 Optional<std::string&> reasonIfUnsupported) const
1169{
Sadik Armagan303980c2020-04-17 12:45:14 +01001170 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001171 bool supported = true;
1172
1173 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001174 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001175 {
1176 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001177 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001178 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001179 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001180 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001181 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001182 };
1183
1184 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1185 "Reference DepthwiseConvolution2d: input is not a supported type.");
1186
1187 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1188 "Reference DepthwiseConvolution2d: output is not a supported type.");
1189
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001190 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1191 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1192
Teresa Charlind8df0262019-11-11 12:28:15 +00001193 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001194 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001195 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001196 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001197 {
1198 DataType::QAsymmS8,
1199 DataType::QAsymmU8,
1200 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001201 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001202
1203 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001204 "Reference DepthwiseConvolution2d: weights type not supported for "
1205 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001206 }
1207 else
1208 {
1209 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1210 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1211
1212 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1213 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1214 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001215
1216 if (biases.has_value())
1217 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001218 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001219 {
1220 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001221 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001222 DataType::Signed32
1223 };
1224 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1225 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1226 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001227
1228 return supported;
1229
arovir011c7c81b2018-10-08 11:34:28 +01001230}
1231
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001232bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1233 const TensorInfo& output,
1234 Optional<std::string&> reasonIfUnsupported) const
1235{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001236 bool supported = true;
1237
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001238 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001239 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001240 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001241 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001242 DataType::QSymmS16,
1243 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001244 };
1245
1246 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001247 "Reference for Dequantize layer: input type not supported.");
1248
Derek Lambertid466a542020-01-22 15:37:29 +00001249 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001250 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001251
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001252 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001253 DataType::Float32,
1254 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001255 };
1256
1257 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001258 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001259
1260 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001261 "Reference for Dequantize layer: input/output shapes have different num total "
1262 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001263
1264 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001265}
1266
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001267bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1268 const TensorInfo& scores,
1269 const TensorInfo& anchors,
1270 const TensorInfo& detectionBoxes,
1271 const TensorInfo& detectionClasses,
1272 const TensorInfo& detectionScores,
1273 const TensorInfo& numDetections,
1274 const DetectionPostProcessDescriptor& descriptor,
1275 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001276{
Jan Eilers8eb25602020-03-09 12:13:48 +00001277 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001278
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001279 bool supported = true;
1280
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001281 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001282 {
1283 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001284 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001285 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001286 DataType::QAsymmU8,
1287 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001288 };
1289
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001290 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001291 "Reference DetectionPostProcess: input 0 is not a supported type.");
1292
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001293 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001294 "Reference DetectionPostProcess: input 1 is not a supported type.");
1295
1296 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001297}
1298
Pablo Tellof0bd6832019-04-26 17:58:13 +01001299bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1300 const TensorInfo& output,
1301 const DepthwiseConvolution2dDescriptor& descriptor,
1302 const TensorInfo& weights,
1303 const Optional<TensorInfo>& biases,
1304 Optional<std::string&> reasonIfUnsupported) const
1305{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001306 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +01001307}
1308
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001309bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001310 const TensorInfo& input1,
1311 const TensorInfo& output,
1312 Optional<std::string&> reasonIfUnsupported) const
1313{
Sadik Armagan2999a022019-04-09 14:20:12 +01001314 bool supported = true;
1315
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001316 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001317 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001318 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001319 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001320 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001321 DataType::QSymmS16,
1322 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001323 };
1324
1325 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1326 "Reference division: input 0 is not a supported type.");
1327
1328 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1329 "Reference division: input 1 is not a supported type.");
1330
1331 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1332 "Reference division: output is not a supported type.");
1333
1334 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1335 "Reference division: input 0 and Input 1 types are mismatched");
1336
1337 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1338 "Reference division: input and output types are mismatched");
1339
1340 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1341 "Reference division: shapes are not suitable for implicit broadcast.");
1342
1343 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001344}
1345
josh minor4a3c6102020-01-06 16:40:46 -06001346bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1347 const TensorInfo& output,
1348 const ElementwiseUnaryDescriptor& descriptor,
1349 Optional<std::string&> reasonIfUnsupported) const
1350{
Jan Eilers8eb25602020-03-09 12:13:48 +00001351 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001352
Sadik Armagan303980c2020-04-17 12:45:14 +01001353 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001354 {
1355 DataType::Float32,
1356 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001357 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001358 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001359 DataType::QSymmS16,
1360 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001361 };
1362
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001363 std::array<DataType, 1> logicalSupportedTypes =
1364 {
1365 DataType::Boolean
1366 };
1367
josh minor4a3c6102020-01-06 16:40:46 -06001368 bool supported = true;
1369
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001370 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1371 {
1372 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1373 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001374
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001375 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1376 "Reference elementwise unary: output type not supported");
1377 }
1378 else
1379 {
1380 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1381 "Reference elementwise unary: input type not supported");
1382
1383 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1384 "Reference elementwise unary: output type not supported");
1385 }
josh minor4a3c6102020-01-06 16:40:46 -06001386
1387 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1388 "Reference elementwise unary: input and output types not matching");
1389
1390 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1391 "Reference elementwise unary: input and output shapes"
1392 "have different number of total elements");
1393
1394 return supported;
1395}
1396
arovir011c7c81b2018-10-08 11:34:28 +01001397bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1398 const FakeQuantizationDescriptor& descriptor,
1399 Optional<std::string&> reasonIfUnsupported) const
1400{
Jan Eilers8eb25602020-03-09 12:13:48 +00001401 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001402 bool supported = true;
1403
1404 std::array<DataType,1> supportedTypes =
1405 {
1406 DataType::Float32
1407 };
1408
1409 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1410 "Reference fake quantization: input type not supported.");
1411
1412 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001413}
1414
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001415bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1416 const TensorInfo& output,
1417 const FillDescriptor& descriptor,
1418 Optional<std::string&> reasonIfUnsupported) const
1419{
1420 IgnoreUnused(descriptor);
1421 IgnoreUnused(output);
1422
1423 bool supported = true;
1424
Sadik Armagana792a052020-06-23 16:22:23 +01001425 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001426 {
1427 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001428 DataType::Float16,
1429 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001430 };
1431
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001432 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001433 "Reference Fill: input type not supported.");
1434
Teresa Charlin44088502020-07-27 11:27:19 +01001435 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1436 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001437 return supported;
1438}
1439
arovir011c7c81b2018-10-08 11:34:28 +01001440bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1441 const TensorInfo& output,
1442 Optional<std::string&> reasonIfUnsupported) const
1443{
Jan Eilers8eb25602020-03-09 12:13:48 +00001444 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001445 bool supported = true;
1446
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001447 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001448 {
James Conroyb40d7102019-06-04 12:32:09 +01001449 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001450 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001451 };
1452
1453 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1454 "Reference Floor: input type not supported.");
1455
1456 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1457 "Reference Floor: output type not supported.");
1458
1459 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001460}
1461
1462bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1463 const TensorInfo& output,
1464 const TensorInfo& weights,
1465 const TensorInfo& biases,
1466 const FullyConnectedDescriptor& descriptor,
1467 Optional<std::string&> reasonIfUnsupported) const
1468{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001469 bool supported = true;
1470
1471 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001472 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001473 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001474 DataType::Float32,
1475 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001476 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001477 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001478 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001479 };
1480
1481 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1482 "Reference Fully Connected: input type not supported.");
1483
1484 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1485 "Reference Fully Connected: output type not supported.");
1486
Francis Murtagh46c09d02019-05-28 08:15:28 +01001487 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1488 "Reference Fully Connected: weights type not supported.");
1489
Ryan OShea31441592022-11-07 16:20:48 +00001490 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1491 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001492
Jan Eilers1f45dc32020-06-15 11:43:03 +01001493 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1494 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001495
Jan Eilers1f45dc32020-06-15 11:43:03 +01001496 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1497 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001498
1499 if (descriptor.m_BiasEnabled)
1500 {
1501 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001502 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001503 supportedBiasTypes =
1504 {
1505 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001506 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001507 DataType::Signed32,
1508 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001509 };
1510
1511 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1512 "Reference Fully Connected: bias type not supported.");
1513
1514 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1515 "Reference Fully Connected: bias and weight types mismatch.");
1516
1517 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1518 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1519
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001520 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1521 "Reference Fully Connected: bias must have 1 dimension.");
1522
Francis Murtagh46c09d02019-05-28 08:15:28 +01001523 }
1524
1525 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001526}
1527
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001528bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1529 const armnn::TensorInfo& input1,
1530 const armnn::TensorInfo& output,
1531 armnn::Optional<std::string&> reasonIfUnsupported) const
1532{
1533 bool supported = true;
1534 std::array<DataType,7> supportedTypes =
1535 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001536 DataType::Float32,
1537 DataType::Float16,
1538 DataType::QAsymmS8,
1539 DataType::QAsymmU8,
1540 DataType::QSymmS16,
1541 DataType::Signed32
1542 };
1543
1544 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1545 "Reference GatherNd: input type not supported");
1546
1547 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1548 "Reference GatherNd: output type not supported");
1549
1550 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1551 "Reference GatherNd: indices (input1) type not supported");
1552
1553 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1554 "Reference GatherNd: input and output types not matching");
1555
1556 return supported;
1557}
1558
narpra014951d842019-01-18 16:53:53 +00001559bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1560 const armnn::TensorInfo& input1,
1561 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001562 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001563 armnn::Optional<std::string&> reasonIfUnsupported) const
1564{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001565 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001566 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001567 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001568 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001569 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001570 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001571 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001572 DataType::QSymmS16,
1573 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001574 };
1575
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001576 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001577 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1578 "Reference Gather: input type not supported");
1579
1580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1581 "Reference Gather: output type not supported");
1582
1583 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1584 "Reference Gather: indices (input1) type not supported");
1585
1586 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1587 "Reference Gather: input and output types not matching");
1588
1589 return supported;
narpra014951d842019-01-18 16:53:53 +00001590}
1591
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // Input layers carry no computation; the reference backend accepts them
    // unconditionally, so both parameters are deliberately unused.
    return true;
}
1597
Kevin May09ca49c2019-10-09 12:37:34 +01001598bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1599 const TensorInfo& output,
1600 const InstanceNormalizationDescriptor& descriptor,
1601 Optional<std::string&> reasonIfUnsupported) const
1602{
Jan Eilers8eb25602020-03-09 12:13:48 +00001603 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001604 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001605 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001606 {
1607 DataType::Float32,
1608 DataType::Float16
1609 };
1610
1611 bool supported = true;
1612
1613 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1614 "Reference Instance Normalization: input type not supported.");
1615
1616 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1617 "Reference Instance Normalization: output type not supported.");
1618
1619 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1620 "Reference Instance Normalization: input and output types mismatched.");
1621
1622 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1623 "Reference Instance Normalization: input and output shapes have different "
1624 "num total elements.");
1625
1626 return supported;
1627}
1628
arovir011c7c81b2018-10-08 11:34:28 +01001629bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1630 const TensorInfo& output,
1631 const L2NormalizationDescriptor& descriptor,
1632 Optional<std::string&> reasonIfUnsupported) const
1633{
Jan Eilers8eb25602020-03-09 12:13:48 +00001634 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001635 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001636 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001637 {
1638 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001639 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001640 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001641 DataType::QAsymmU8,
1642 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001643 };
1644
1645 bool supported = true;
1646
1647 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1648 "Reference L2normalization: input type not supported.");
1649
1650 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1651 "Reference L2normalization: output type not supported.");
1652
1653 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1654 "Reference L2normalization: input and output types mismatched.");
1655
1656 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1657 "Reference L2normalization: input and output shapes have different "
1658 "num total elements.");
1659
1660 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001661}
1662
James Conroyaba90cd2020-11-06 16:28:18 +00001663bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1664 const TensorInfo& input1,
1665 const TensorInfo& output,
1666 const LogicalBinaryDescriptor& descriptor,
1667 Optional<std::string&> reasonIfUnsupported) const
1668{
1669 IgnoreUnused(descriptor);
1670
1671 std::array<DataType, 1> supportedTypes =
1672 {
1673 DataType::Boolean
1674 };
1675
1676 bool supported = true;
1677 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1678 "Reference LogicalBinary: input 0 type not supported");
1679 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1680 "Reference LogicalBinary: input 1 type not supported");
1681
1682 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1683 "Reference LogicalBinary: input and output types do not match");
1684
1685 return supported;
1686}
1687
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001688bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1689 const TensorInfo& output,
1690 const LogSoftmaxDescriptor& descriptor,
1691 Optional<std::string&> reasonIfUnsupported) const
1692{
Jan Eilers8eb25602020-03-09 12:13:48 +00001693 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001694
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001695 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001696 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001697 DataType::Float32,
1698 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001699 };
1700
1701 bool supported = true;
1702 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1703 "Reference LogSoftmax: input type not supported");
1704
1705 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1706 "Reference LogSoftmax: output type not supported");
1707
1708 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1709 "Reference LogSoftmax: input and output types do not match");
1710
1711 return supported;
1712}
1713
arovir011c7c81b2018-10-08 11:34:28 +01001714bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1715 const TensorInfo& outputStateIn,
1716 const TensorInfo& cellStateIn,
1717 const TensorInfo& scratchBuffer,
1718 const TensorInfo& outputStateOut,
1719 const TensorInfo& cellStateOut,
1720 const TensorInfo& output,
1721 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001722 const LstmInputParamsInfo& paramsInfo,
1723 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001724{
Jan Eilers8eb25602020-03-09 12:13:48 +00001725 IgnoreUnused(descriptor);
1726 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001727
1728 bool supported = true;
1729
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001730 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001731 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001732 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001733 };
1734
Jan Eilersd01a83c2019-07-03 18:20:40 +01001735 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001736 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1737 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001738 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1739 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001740 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1741 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001742 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1743 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001744 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1745 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001746 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1747 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001748
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001749 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1750 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001751 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001752 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001753 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001754 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001755 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001756 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001757 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001758 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001759 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001760 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001761 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001762 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001763 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001764 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001765 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001766 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001767 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001768 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001769 "Reference Lstm: input and OutputGateBias types are mismatched");
1770 if (!descriptor.m_CifgEnabled)
1771 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001772 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001773 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001774 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001775 reasonIfUnsupported,
1776 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001777 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001778 "Reference Lstm: input and InputGateBias types are mismatched");
1779 if (descriptor.m_PeepholeEnabled)
1780 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001781 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001782 reasonIfUnsupported,
1783 "Reference Lstm: input and CellToInputWeights types are mismatched");
1784 }
1785 }
1786 if (descriptor.m_PeepholeEnabled)
1787 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001788 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001789 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001790 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001791 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1792 }
1793 if (descriptor.m_ProjectionEnabled)
1794 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001795 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001796 "Reference Lstm: input and mProjectionWeights types are mismatched");
1797 if (paramsInfo.m_ProjectionBias != nullptr)
1798 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001799 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001800 "Reference Lstm: input and ProjectionBias types are mismatched");
1801 }
1802 }
1803 if (descriptor.m_LayerNormEnabled)
1804 {
1805 if (!descriptor.m_CifgEnabled)
1806 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001808 reasonIfUnsupported,
1809 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1810 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001812 reasonIfUnsupported,
1813 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001814 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001815 reasonIfUnsupported,
1816 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001817 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001818 reasonIfUnsupported,
1819 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1820 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001821
1822 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001823}
1824
saoste012df12b32018-11-28 16:57:20 +00001825bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1826 const TensorInfo& input1,
1827 const TensorInfo& output,
1828 Optional<std::string&> reasonIfUnsupported) const
1829{
Sadik Armagan2999a022019-04-09 14:20:12 +01001830 bool supported = true;
1831
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001832 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001833 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001834 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001835 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001836 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001837 DataType::QSymmS16,
1838 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001839 };
1840
1841 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1842 "Reference maximum: input 0 is not a supported type.");
1843
1844 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1845 "Reference maximum: input 1 is not a supported type.");
1846
1847 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1848 "Reference maximum: output is not a supported type.");
1849
1850 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1851 "Reference maximum: input 0 and Input 1 types are mismatched");
1852
1853 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1854 "Reference maximum: input and output types are mismatched");
1855
1856 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1857 "Reference maximum: shapes are not suitable for implicit broadcast.");
1858
1859 return supported;
saoste012df12b32018-11-28 16:57:20 +00001860}
1861
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001862bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1863 const TensorInfo& output,
1864 const MeanDescriptor& descriptor,
1865 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001866{
James Conroy4d1ff582019-06-10 17:06:39 +01001867 bool supported = true;
1868 std::string meanLayerStr = "Mean";
1869 std::string outputTensorStr = "output";
1870
Sadik Armagan303980c2020-04-17 12:45:14 +01001871 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001872 {
1873 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001874 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001875 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001876 DataType::QAsymmU8,
1877 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001878 };
1879
1880 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1881 "Reference Mean: input type not supported.");
1882
1883 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1884 "Reference Mean: input and output types are mismatched");
1885
1886 if (descriptor.m_KeepDims)
1887 {
1888 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1889 reasonIfUnsupported,
1890 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1891 output.GetNumDimensions(),
1892 meanLayerStr, outputTensorStr).data());
1893 }
1894 else if (descriptor.m_Axis.empty())
1895 {
1896 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1897 reasonIfUnsupported,
1898 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1899 meanLayerStr, outputTensorStr).data());
1900 }
1901 else
1902 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001903 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001904
1905 if (outputDim > 0)
1906 {
1907 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1908 reasonIfUnsupported,
1909 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1910 meanLayerStr, outputTensorStr).data());
1911 }
1912 else
1913 {
1914 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1915 reasonIfUnsupported,
1916 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1917 meanLayerStr, outputTensorStr).data());
1918 }
1919 }
1920
1921 return supported;
narpra0132b90462018-09-13 11:07:48 +01001922}
1923
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001924bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1925 const TensorInfo &output,
1926 Optional<std::string &> reasonIfUnsupported) const
1927{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001928 bool supported = true;
1929
Sadik Armagan303980c2020-04-17 12:45:14 +01001930 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001931 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001932 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001933 DataType::Float32,
1934 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001935 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001936 DataType::QAsymmU8,
1937 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001938 DataType::Boolean
1939 };
1940
1941 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1942 "Reference MemCopy: input type not supported");
1943
1944 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1945 "Reference MemCopy: output type not supported");
1946
1947 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1948 "Reference MemCopy: input and output types are mismatched");
1949
1950 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001951}
1952
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001953bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1954 const TensorInfo& input1,
1955 const TensorInfo& output,
1956 Optional<std::string&> reasonIfUnsupported) const
1957{
Sadik Armagan2999a022019-04-09 14:20:12 +01001958 bool supported = true;
1959
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001960 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001961 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001962 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001963 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001964 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001965 DataType::QSymmS16,
1966 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001967 };
1968
1969 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1970 "Reference minimum: input 0 is not a supported type.");
1971
1972 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1973 "Reference minimum: input 1 is not a supported type.");
1974
1975 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1976 "Reference minimum: output is not a supported type.");
1977
1978 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1979 "Reference minimum: input 0 and Input 1 types are mismatched");
1980
1981 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1982 "Reference minimum: input and output types are mismatched");
1983
1984 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1985 "Reference minimum: shapes are not suitable for implicit broadcast.");
1986
1987 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001988}
1989
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001990bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1991 const TensorInfo& input1,
1992 const TensorInfo& output,
1993 Optional<std::string&> reasonIfUnsupported) const
1994{
Sadik Armagan2999a022019-04-09 14:20:12 +01001995 bool supported = true;
1996
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001997 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001998 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001999 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002000 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002001 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002002 DataType::QSymmS16,
2003 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002004 };
2005
2006 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2007 "Reference multiplication: input 0 is not a supported type.");
2008
2009 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2010 "Reference multiplication: input 1 is not a supported type.");
2011
2012 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2013 "Reference multiplication: output is not a supported type.");
2014
2015 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2016 "Reference multiplication: input 0 and Input 1 types are mismatched");
2017
2018 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2019 "Reference multiplication: input and output types are mismatched");
2020
2021 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2022 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2023
2024 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002025}
2026
2027bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2028 const TensorInfo& output,
2029 const NormalizationDescriptor& descriptor,
2030 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002031{
Jan Eilers8eb25602020-03-09 12:13:48 +00002032 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002033
2034 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002035 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002036 {
2037 DataType::Float16,
2038 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002039 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002040 DataType::QAsymmU8,
2041 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002042 };
2043
2044 bool supported = true;
2045
2046 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2047 "Reference normalization: input type not supported.");
2048
2049 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2050 "Reference normalization: output type not supported.");
2051
2052 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2053 "Reference normalization: input and output shapes have different "
2054 "num total elements.");
2055
2056 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002057}
2058
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // Output layers carry no computation; the reference backend accepts them
    // unconditionally, so both parameters are deliberately unused.
    return true;
}
2064
2065bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2066 const TensorInfo& output,
2067 const PadDescriptor& descriptor,
2068 Optional<std::string&> reasonIfUnsupported) const
2069{
Jan Eilers8eb25602020-03-09 12:13:48 +00002070 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002071 bool supported = true;
2072
2073 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002074 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002075 {
2076 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002077 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002078 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002079 DataType::QAsymmU8,
2080 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002081 };
2082
2083 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2084 "Reference pad: input is not a supported type.");
2085
2086 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2087 "Reference pad: output is not a supported type.");
2088
2089 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2090 "Reference pad: input and output types are mismatched.");
2091
2092 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002093}
2094
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002095bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2096 const TensorInfo& output,
2097 const PermuteDescriptor& descriptor,
2098 Optional<std::string&> reasonIfUnsupported) const
2099{
Jan Eilers8eb25602020-03-09 12:13:48 +00002100 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002101 bool supported = true;
2102
2103 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002104 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002105 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002106 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002107 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002108 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002109 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002110 DataType::QAsymmU8,
2111 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002112 };
2113
2114 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2115 "Reference permute: input is not a supported type.");
2116
2117 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2118 "Reference permute: output is not a supported type.");
2119
2120 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2121 "Reference permute: input and output types are mismatched.");
2122
2123 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002124}
2125
2126bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2127 const TensorInfo& output,
2128 const Pooling2dDescriptor& descriptor,
2129 Optional<std::string&> reasonIfUnsupported) const
2130{
Jan Eilers8eb25602020-03-09 12:13:48 +00002131 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002132 bool supported = true;
2133
2134 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002135 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002136 {
2137 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002138 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002139 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002140 DataType::QAsymmU8,
2141 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002142 };
2143
2144 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145 "Reference poolind2d: input is not a supported type.");
2146
2147 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148 "Reference poolind2d: output is not a supported type.");
2149
2150 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151 "Reference poolind2d: input and output types are mismatched.");
2152
2153 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002154}
2155
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002156bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2157 const TensorInfo& output,
2158 const Pooling3dDescriptor& descriptor,
2159 Optional<std::string&> reasonIfUnsupported) const
2160{
2161 IgnoreUnused(descriptor);
2162 bool supported = true;
2163
2164 // Define supported output and inputs types.
2165 std::array<DataType,6> supportedTypes =
2166 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002167 DataType::Float32,
2168 DataType::Float16,
2169 DataType::QAsymmS8,
2170 DataType::QAsymmU8,
2171 DataType::QSymmS16
2172 };
2173
2174 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2175 "Reference poolind3d: input is not a supported type.");
2176
2177 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2178 "Reference poolind3d: output is not a supported type.");
2179
2180 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2181 "Reference poolind3d: input and output types are mismatched.");
2182
2183 return supported;
2184}
2185
2186
James Conroy4f1f8992020-04-29 20:01:10 +01002187bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2188 const TensorInfo& previousOutputIn,
2189 const TensorInfo& previousCellStateIn,
2190 const TensorInfo& outputStateOut,
2191 const TensorInfo& cellStateOut,
2192 const TensorInfo& output,
2193 const QLstmDescriptor& descriptor,
2194 const LstmInputParamsInfo& paramsInfo,
2195 Optional<std::string&> reasonIfUnsupported) const
2196{
2197 IgnoreUnused(input);
2198 IgnoreUnused(previousOutputIn);
2199 IgnoreUnused(previousCellStateIn);
2200 IgnoreUnused(outputStateOut);
2201 IgnoreUnused(cellStateOut);
2202 IgnoreUnused(output);
2203 IgnoreUnused(descriptor);
2204 IgnoreUnused(paramsInfo);
2205
2206 IgnoreUnused(reasonIfUnsupported);
2207
2208 return true;
2209}
2210
Derek Lamberti5f400d62019-03-25 15:41:58 +00002211bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2212 const TensorInfo& output,
2213 Optional<std::string&> reasonIfUnsupported) const
2214{
2215 bool supported = true;
2216
Finn Williamsfd271062019-12-04 14:27:27 +00002217 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002218 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002219 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002220 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002221 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002222 DataType::QAsymmU8,
2223 DataType::QSymmS8,
2224 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002225 };
2226
2227 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2228 "Reference quantize: input type not supported.");
2229
2230 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002231 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002232 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002233 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002234 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002235 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002236 };
2237 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2238 "Reference quantize: output type not supported.");
2239
2240 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2241 "Reference quantize: input and output shapes have different num total elements.");
2242
2243 return supported;
2244}
2245
Finn Williams2605b232020-06-10 15:53:46 +01002246bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2247 const TensorInfo& output,
2248 Optional<std::string&> reasonIfUnsupported) const
2249{
2250 IgnoreUnused(input);
2251 // Define supported output types.
2252 std::array<DataType,1> supportedOutputTypes =
2253 {
2254 DataType::Signed32,
2255 };
2256
2257 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2258 "Reference rank: input type not supported.");
2259}
2260
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002261bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2262 const TensorInfo& output,
2263 const ReduceDescriptor& descriptor,
2264 Optional<std::string&> reasonIfUnsupported) const
2265{
2266 IgnoreUnused(descriptor);
2267 bool supported = true;
2268 std::array<DataType,7> supportedTypes =
2269 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002270 DataType::Float32,
2271 DataType::Float16,
2272 DataType::QAsymmS8,
2273 DataType::QAsymmU8,
2274 DataType::QSymmS16,
2275 DataType::Signed32
2276 };
2277
2278 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2279 "Reference Reduce: input type not supported");
2280
2281 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2282 "Reference Reduce: output type not supported");
2283
2284 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2285 "Reference Reduce: input and output types not matching");
2286
2287 return supported;
2288}
2289
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002290bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002291 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002292 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002293 Optional<std::string&> reasonIfUnsupported) const
2294{
Jan Eilers8eb25602020-03-09 12:13:48 +00002295 IgnoreUnused(output);
2296 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002297 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002298 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002299 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002300 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002301 DataType::Float32,
2302 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002303 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002304 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002305 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002306 DataType::QSymmS16,
2307 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002308 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002309
Nina Drozd2f2778f2019-05-27 10:37:05 +01002310 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2311 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002312}
2313
Teresa Charlin970f43b2019-07-01 13:51:07 +01002314bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2315 const TensorInfo& output,
2316 const ResizeDescriptor& descriptor,
2317 Optional<std::string&> reasonIfUnsupported) const
2318{
Jan Eilers8eb25602020-03-09 12:13:48 +00002319 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002320 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002321 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002322 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002323 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002324 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002325 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002326 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002327 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002328 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002329 };
2330
2331 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2332 "Reference Resize: input type not supported");
2333
2334 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2335 "Reference Resize: output type not supported");
2336
2337 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2338 "Reference Resize: input and output types not matching");
2339
2340 return supported;
2341}
2342
Keith Davis3ae3f972021-05-21 16:33:48 +01002343bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2344 const TensorInfo& output,
2345 Optional<std::string&> reasonIfUnsupported) const
2346{
2347 IgnoreUnused(input);
2348 bool supported = true;
2349
2350 std::array<DataType, 1> supportedTypes =
2351 {
2352 DataType::Signed32
2353 };
2354
2355 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2356 "Reference Shape: output type not supported");
2357
2358 return supported;
2359}
2360
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002361bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2362 const TensorInfo& output,
2363 const SliceDescriptor& descriptor,
2364 Optional<std::string&> reasonIfUnsupported) const
2365{
Jan Eilers8eb25602020-03-09 12:13:48 +00002366 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002367 bool supported = true;
2368
Sadik Armagan303980c2020-04-17 12:45:14 +01002369 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002370 {
2371 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002372 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002373 DataType::QAsymmU8,
2374 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002375 };
2376
2377 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2378 "Reference Slice: input type not supported");
2379
2380 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2381 "Reference Slice: output type not supported");
2382
2383 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2384 "Reference Slice: input and output types are mismatched");
2385
2386 return supported;
2387}
2388
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002389bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2390 const TensorInfo& output,
2391 const SoftmaxDescriptor& descriptor,
2392 Optional<std::string&> reasonIfUnsupported) const
2393{
Jan Eilers8eb25602020-03-09 12:13:48 +00002394 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002395 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002396 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002397 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002398 DataType::Float32,
2399 DataType::Float16,
2400 DataType::QSymmS8,
2401 DataType::QAsymmS8,
2402 DataType::QAsymmU8,
2403 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002404 };
2405
2406 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002407 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002408
2409 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002410 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002411
2412 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002413 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002414
2415 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002416}
2417
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002418bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2419 const TensorInfo& output,
2420 const SpaceToBatchNdDescriptor& descriptor,
2421 Optional<std::string&> reasonIfUnsupported) const
2422{
Jan Eilers8eb25602020-03-09 12:13:48 +00002423 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002424 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002425 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002426 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002427 DataType::Float32,
2428 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002429 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002430 DataType::QAsymmU8,
2431 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002432 };
2433
2434 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2435 "Reference SpaceToBatchNd: input type not supported");
2436
2437 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2438 "Reference SpaceToBatchNd: output type not supported");
2439
2440 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2441 "Reference SpaceToBatchNd: input and output types are mismatched");
2442
2443 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002444}
2445
Keith Davisa57eccb2019-06-14 17:33:22 +01002446bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002447 const TensorInfo& output,
2448 const SpaceToDepthDescriptor& descriptor,
2449 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002450{
2451
Jan Eilers8eb25602020-03-09 12:13:48 +00002452 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002453 bool supported = true;
2454
Sadik Armagan303980c2020-04-17 12:45:14 +01002455 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002456 {
2457 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002458 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002459 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002460 DataType::QAsymmU8,
2461 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002462 };
2463
2464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2465 "Reference SpaceToDepth: input type not supported");
2466
2467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2468 "Reference SpaceToDepth: output type not supported");
2469
2470 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2471 "Reference SpaceToDepth: input and output types are mismatched");
2472
2473 return supported;
2474}
2475
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002476bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002477 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2478 const ViewsDescriptor& descriptor,
2479 Optional<std::string&> reasonIfUnsupported) const
2480{
Jan Eilers8eb25602020-03-09 12:13:48 +00002481 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002482 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002483 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002484 {
2485 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002486 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002487 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002488 DataType::QAsymmU8,
2489 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002490 };
2491
2492 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2493 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002494 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002495 {
2496 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2497 "Reference splitter: input type not supported");
2498
2499 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2500 "Reference splitter: input and output types mismatched.");
2501 }
2502
2503 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002504}
2505
Matthew Jackson81e601c2019-07-11 12:07:09 +01002506bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2507 const TensorInfo& output,
2508 const StackDescriptor& descriptor,
2509 Optional<std::string&> reasonIfUnsupported) const
2510{
Jan Eilers8eb25602020-03-09 12:13:48 +00002511 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002512
2513 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002514 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002515 {
2516 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002517 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002518 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002519 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002520 DataType::QSymmS16,
2521 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002522 };
2523
2524 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2525 "Reference stack: output type not supported");
2526 for (const TensorInfo* input : inputs)
2527 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002528 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002529 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2530 "Reference stack: input type not supported");
2531
2532 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2533 "Reference stack: input and output types mismatched.");
2534 }
2535
2536 return supported;
2537}
2538
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002539bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2540 const TensorInfo& output,
2541 const StridedSliceDescriptor& descriptor,
2542 Optional<std::string&> reasonIfUnsupported) const
2543{
Jan Eilers8eb25602020-03-09 12:13:48 +00002544 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002545 bool supported = true;
2546
Sadik Armagan303980c2020-04-17 12:45:14 +01002547 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002548 {
2549 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002550 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002551 DataType::QAsymmU8,
2552 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002553 };
2554
2555 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2556 "Reference StridedSlice: input type not supported");
2557
2558 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2559 "Reference StridedSlice: output type not supported");
2560
2561 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2562 "Reference StridedSlice: input and output types are mismatched");
2563
2564 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002565}
2566
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002567bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2568 const TensorInfo& input1,
2569 const TensorInfo& output,
2570 Optional<std::string&> reasonIfUnsupported) const
2571{
Sadik Armagan2999a022019-04-09 14:20:12 +01002572 bool supported = true;
2573
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002574 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002575 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002576 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002577 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002578 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002579 DataType::QSymmS16,
2580 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002581 };
2582
2583 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2584 "Reference subtraction: input 0 is not a supported type.");
2585
2586 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2587 "Reference subtraction: input 1 is not a supported type.");
2588
2589 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2590 "Reference subtraction: output is not a supported type.");
2591
2592 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2593 "Reference subtraction: input 0 and Input 1 types are mismatched");
2594
2595 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2596 "Reference subtraction: input and output types are mismatched");
2597
2598 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2599 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2600
2601 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002602}
2603
Matteo Martincighab9e5252019-06-13 17:27:46 +01002604bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2605 const TensorInfo& alpha,
2606 const TensorInfo& output,
2607 Optional<std::string&> reasonIfUnsupported) const
2608{
2609 bool supported = true;
2610
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002611 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002612 {
2613 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002614 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002615 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002616 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002617 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002618 };
2619
2620 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2621 "PReLU: input is not a supported type.");
2622
2623 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2624 "PReLU: alpha is not a supported type.");
2625
2626 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2627 "PReLU: output is not a supported type.");
2628
2629 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2630 "PReLU: input, alpha and output types are mismatched");
2631
2632 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2633 "PReLU: shapes are not suitable for implicit broadcast");
2634
2635 return supported;
2636}
2637
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002638bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2639 const TensorInfo& output,
2640 const TransposeConvolution2dDescriptor& descriptor,
2641 const TensorInfo& weights,
2642 const Optional<TensorInfo>& biases,
2643 Optional<std::string&> reasonIfUnsupported) const
2644{
Jan Eilers8eb25602020-03-09 12:13:48 +00002645 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002646 bool supported = true;
2647
Sadik Armagan303980c2020-04-17 12:45:14 +01002648 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002649 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002650 DataType::Float32,
2651 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002652 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002653 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002654 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002655 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002656 };
2657
2658 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2659 "Reference TransposeConvolution2d: input is not a supported type.");
2660
2661 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2662 "Reference TransposeConvolution2d: output is not a supported type.");
2663
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002664 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2665 "Reference TransposeConvolution2d: input and output types mismatched.");
2666
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002667
2668 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002669 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002670 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002671 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002672 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002673 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002674 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002675 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002676 };
2677
2678 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2679 "Reference TransposeConvolution2d: weights type not supported for "
2680 "quantized input.");
2681 }
2682 else
2683 {
2684 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2685 "Reference TransposeConvolution2d: weights is not a supported type.");
2686
2687 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2688 "Reference TransposeConvolution2d: input and weights types mismatched.");
2689 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002690
2691 if (biases.has_value())
2692 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002693 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002694 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002695 DataType::Float32,
2696 DataType::Float16,
2697 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002698 };
2699 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2700 "Reference TransposeConvolution2d: biases is not a supported type.");
2701 }
2702
2703 return supported;
2704}
2705
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002706bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2707 const TensorInfo& output,
2708 const TransposeDescriptor& descriptor,
2709 Optional<std::string&> reasonIfUnsupported) const
2710{
Jan Eilers8eb25602020-03-09 12:13:48 +00002711 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002712 bool supported = true;
2713
2714 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002715 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002716 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002717 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002718 DataType::Float32,
2719 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002720 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002721 DataType::QAsymmU8,
2722 DataType::QSymmS16
2723 };
2724
2725 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2726 "Reference transpose: input is not a supported type.");
2727
2728 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2729 "Reference transpose: output is not a supported type.");
2730
2731 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2732 "Reference transpose: input and output types are mismatched.");
2733
2734 return supported;
2735}
2736
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002737bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2738 const TensorInfo& input,
2739 const TensorInfo& outputStateIn,
2740 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002741 const TensorInfo& outputStateOut,
2742 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002743 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002744 const UnidirectionalSequenceLstmDescriptor& descriptor,
2745 const LstmInputParamsInfo& paramsInfo,
2746 Optional<std::string&> reasonIfUnsupported) const
2747{
2748 IgnoreUnused(descriptor);
2749 IgnoreUnused(paramsInfo);
2750 IgnoreUnused(outputStateIn);
2751 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002752 IgnoreUnused(outputStateOut);
2753 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002754 bool supported = true;
2755
Mike Kelly12994962022-04-21 11:57:09 +01002756 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002757 {
Mike Kelly12994962022-04-21 11:57:09 +01002758 DataType::Float32,
2759 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002760 };
2761
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002762 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002763 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002764 DataType::Float32,
2765 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002766 };
2767
Mike Kelly12994962022-04-21 11:57:09 +01002768 std::array<DataType, 3> supportedBiasTypes =
2769 {
2770 DataType::Float32,
2771 DataType::QAsymmS8,
2772 DataType::Signed32
2773 };
2774
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002775 // check inputs and outputs
2776 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2777 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002778 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2779 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002780
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002781 // check layer parameters
2782 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2783 reasonIfUnsupported,
2784 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2785 "is not a supported type.");
2786 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2787 reasonIfUnsupported,
2788 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2789 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2790 reasonIfUnsupported,
2791 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2792 "is not a supported type.");
2793 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2794 reasonIfUnsupported,
2795 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2796 "is not a supported type.");
2797 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2798 reasonIfUnsupported,
2799 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2800 "is not a supported type.");
2801 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2802 reasonIfUnsupported,
2803 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2804 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002805
2806 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2807 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2808 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2809 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2810 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2811 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002812 if (!descriptor.m_CifgEnabled)
2813 {
2814 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2815 reasonIfUnsupported,
2816 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2817 "is not a supported type.");
2818 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2819 reasonIfUnsupported,
2820 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2821 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002822 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2823 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002824 if (descriptor.m_PeepholeEnabled)
2825 {
2826 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2827 reasonIfUnsupported,
2828 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2829 "is not a supported type.");
2830 }
2831 }
2832 if (descriptor.m_PeepholeEnabled)
2833 {
2834 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2835 reasonIfUnsupported,
2836 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2837 "is not a supported type.");
2838 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2839 reasonIfUnsupported,
2840 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2841 "is not a supported type.");
2842 }
2843 if (descriptor.m_ProjectionEnabled)
2844 {
2845 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2846 reasonIfUnsupported,
2847 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2848 "is not a supported type.");
2849 if (paramsInfo.m_ProjectionBias != nullptr)
2850 {
2851 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2852 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2853 "are mismatched");
2854 }
2855 }
2856 if (descriptor.m_LayerNormEnabled)
2857 {
2858 if (!descriptor.m_CifgEnabled)
2859 {
2860 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2861 reasonIfUnsupported,
2862 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2863 "is not a supported type.");
2864 }
2865 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2866 reasonIfUnsupported,
2867 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2868 "is not a supported type.");
2869 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2870 reasonIfUnsupported,
2871 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2872 "is not a supported type.");
2873 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2874 reasonIfUnsupported,
2875 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2876 "is not a supported type.");
2877 }
2878
2879 return supported;
2880}
2881
arovir011c7c81b2018-10-08 11:34:28 +01002882} // namespace armnn