blob: 40909019ba04b238c692a037b6f5ef37dd6c3911 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
48std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
49 unsigned int actual,
50 std::string& layerStr,
51 std::string& tensorName)
52{
53 std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
54 " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
55
56 return errorMsg;
57}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
// Single entry point used by the backend framework to query whether the
// reference backend supports a given layer configuration.
//
// 'infos' holds the layer's TensorInfos; by convention inputs come first and
// outputs last (variadic layers such as Concat/Stack/Splitter derive the
// split point from infos.size()). 'descriptor' is downcast to the concrete
// descriptor type for the given layer.
//
// NOTE(review): the Lstm, QLstm, UnidirectionalSequenceLstm and QuantizedLstm
// cases call .value() on the optional params-info arguments without checking
// has_value() — callers are presumably required to populate them for those
// layer types; confirm before relying on it.
bool RefLayerSupport::IsLayerSupported(const LayerType& type,
                                       const std::vector<TensorInfo>& infos,
                                       const BaseDescriptor& descriptor,
                                       const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                       const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    switch (type)
    {
        case LayerType::Activation:
            return IsActivationSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Addition:
            return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ArgMinMax:
            return IsArgMinMaxSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::BatchMatMul:
            return IsBatchMatMulSupported(infos[0],
                                          infos[1],
                                          infos[2],
                                          *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
                                          reasonIfUnsupported);
        case LayerType::BatchNormalization:
            // infos: {input, output, mean, variance, beta, gamma}
            return IsBatchNormalizationSupported(infos[0],
                                                 infos[1],
                                                 infos[2],
                                                 infos[3],
                                                 infos[4],
                                                 infos[5],
                                                 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
                                                     (&descriptor)),
                                                 reasonIfUnsupported);
        case LayerType::BatchToSpaceNd:
            return IsBatchToSpaceNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Comparison:
            return IsComparisonSupported(infos[0],
                                         infos[1],
                                         infos[2],
                                         *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Concat:
        {
            // All entries except the last are inputs; the last is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < (infos.size() - 1); i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsConcatSupported(inputInfos,
                                     infos[infos.size() - 1],
                                     *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        }
        case LayerType::Constant:
            return IsConstantSupported(infos[0], reasonIfUnsupported);
        case LayerType::ConvertBf16ToFp32:
            return IsConvertBf16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp16ToFp32:
            return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToBf16:
            return IsConvertFp32ToBf16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToFp16:
            return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Convolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the bias slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::DepthToSpace:
            return IsDepthToSpaceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::DepthwiseConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the bias slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       EmptyOptional(),
                                                       reasonIfUnsupported);
            }
            else
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       infos[3],
                                                       reasonIfUnsupported);
            }
        }
        case LayerType::Dequantize:
            return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Division:
            return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ElementwiseUnary:
            return IsElementwiseUnarySupported(infos[0],
                                               infos[1],
                                               *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::Fill:
            return IsFillSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Floor:
            return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::FullyConnected:
            // infos: {input, output, weights, biases}
            return IsFullyConnectedSupported(infos[0],
                                             infos[1],
                                             infos[2],
                                             infos[3],
                                             *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Gather:
            return IsGatherSupported(infos[0],
                                     infos[1],
                                     infos[2],
                                     *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::GatherNd:
            return IsGatherNdSupported(infos[0],
                                       infos[1],
                                       infos[2],
                                       reasonIfUnsupported);
        case LayerType::Input:
            return IsInputSupported(infos[0], reasonIfUnsupported);
        case LayerType::InstanceNormalization:
            return IsInstanceNormalizationSupported(infos[0],
                                                    infos[1],
                                                    *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
                                                        (&descriptor)),
                                                    reasonIfUnsupported);
        case LayerType::L2Normalization:
            return IsL2NormalizationSupported(infos[0],
                                              infos[1],
                                              *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
                                              reasonIfUnsupported);
        case LayerType::LogicalBinary:
            return IsLogicalBinarySupported(infos[0],
                                            infos[1],
                                            infos[2],
                                            *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::LogSoftmax:
            return IsLogSoftmaxSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Lstm:
            // Requires lstmParamsInfo to be populated (see NOTE above).
            return IsLstmSupported(infos[0],
                                   infos[1],
                                   infos[2],
                                   infos[3],
                                   infos[4],
                                   infos[5],
                                   infos[6],
                                   *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
                                   lstmParamsInfo.value(),
                                   reasonIfUnsupported);
        case LayerType::QLstm:
            // Requires lstmParamsInfo to be populated (see NOTE above).
            return IsQLstmSupported(infos[0],
                                    infos[1],
                                    infos[2],
                                    infos[3],
                                    infos[4],
                                    infos[5],
                                    *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
                                    lstmParamsInfo.value(),
                                    reasonIfUnsupported);
        case LayerType::Maximum:
            return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Mean:
            return IsMeanSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Minimum:
            return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Multiplication:
            return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Normalization:
            return IsNormalizationSupported(infos[0],
                                            infos[1],
                                            *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::Output:
            return IsOutputSupported(infos[0], reasonIfUnsupported);
        case LayerType::Pad:
            return IsPadSupported(infos[0],
                                  infos[1],
                                  *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
                                  reasonIfUnsupported);
        case LayerType::Permute:
            return IsPermuteSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Pooling2d:
            return IsPooling2dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Prelu:
            return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Quantize:
            return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Reshape:
            return IsReshapeSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Resize:
            return IsResizeSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Reduce:
            return IsReduceSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Slice:
            return IsSliceSupported(infos[0],
                                    infos[1],
                                    *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        case LayerType::Softmax:
            return IsSoftmaxSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::SpaceToBatchNd:
            return IsSpaceToBatchNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::SpaceToDepth:
            return IsSpaceToDepthSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Splitter:
        {
            // First entry is the input; the rest are the split outputs.
            std::vector<TensorInfo> outputInfos;
            for (uint32_t i = 1; i < infos.size(); i++)
            {
                outputInfos.push_back(infos[i]);
            }
            return IsSplitterSupported(infos[0],
                                       {outputInfos.begin(), outputInfos.end()},
                                       *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
                                       reasonIfUnsupported);
        }
        case LayerType::Stack:
        {
            // All entries except the last are inputs; the last is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < infos.size() - 1; i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsStackSupported(inputInfos,
                                    infos[infos.size() - 1],
                                    *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        }
        case LayerType::StridedSlice:
            return IsStridedSliceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Subtraction:
            return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Transpose:
            return IsTransposeSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::TransposeConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the bias slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         EmptyOptional(),
                                                         reasonIfUnsupported);
            }
            else
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         infos[3],
                                                         reasonIfUnsupported);
            }
        }
        case LayerType::Cast:
            return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ChannelShuffle:
            return IsChannelShuffleSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Convolution3d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the bias slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::Debug:
            return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::DetectionPostProcess:
            return IsDetectionPostProcessSupported(infos[0],
                                                   infos[1],
                                                   infos[2],
                                                   infos[3],
                                                   infos[4],
                                                   infos[5],
                                                   infos[6],
                                                   *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
                                                       (&descriptor)),
                                                   reasonIfUnsupported);
        case LayerType::FakeQuantization:
            return IsFakeQuantizationSupported(infos[0],
                                               *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::MemCopy:
            return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Rank:
            return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Shape:
            return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::UnidirectionalSequenceLstm:
        {
            if (infos.size() != 6)
            {
                throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
                                               "should be of format: {input, outputStateIn, cellStateIn, "
                                               "hiddenStateOutputVal, cellStateOutputVal, output}");
            }
            auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
            // Requires lstmParamsInfo to be populated (see NOTE above).
            return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                         infos[1],
                                                         infos[2],
                                                         infos[3],
                                                         infos[4],
                                                         infos[5],
                                                         desc,
                                                         lstmParamsInfo.value(),
                                                         reasonIfUnsupported);
        }
        case LayerType::Pooling3d:
            return IsPooling3dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Map:
            return true;
        case LayerType::Unmap:
            return true;
        case LayerType::MemImport:
            return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Merge:
            return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::QuantizedLstm:
            // Requires quantizedLstmInputParamsInfo to be populated (see NOTE above).
            return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
                                                              infos[1],
                                                              infos[2],
                                                              infos[3],
                                                              infos[4],
                                                              quantizedLstmInputParamsInfo.value(),
                                                              reasonIfUnsupported);
        default:
            // layers not supported in the reference backend by default:
            // precompiled, standin, switch
            return false;
    }
}
511
arovir011c7c81b2018-10-08 11:34:28 +0100512bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
513 const TensorInfo& output,
514 const ActivationDescriptor& descriptor,
515 Optional<std::string&> reasonIfUnsupported) const
516{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000517 bool supported = true;
518
519 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000520 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000521 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000522 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100523 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000524 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000525 DataType::QAsymmU8,
526 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000527 };
528
529 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
530 "Reference activation: input type not supported.");
531
532 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
533 "Reference activation: output type not supported.");
534
535 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
536 "Reference activation: input and output types mismatched.");
537
538 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
539 "Reference activation: input and output shapes are of different rank.");
540
541
542 struct ActivationFunctionSupported : public Rule
543 {
544 ActivationFunctionSupported(const ActivationDescriptor& desc)
545 {
546 switch(desc.m_Function)
547 {
548 case ActivationFunction::Abs:
549 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000550 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000551 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000552 case ActivationFunction::LeakyReLu:
553 case ActivationFunction::Linear:
554 case ActivationFunction::ReLu:
555 case ActivationFunction::Sigmoid:
556 case ActivationFunction::SoftReLu:
557 case ActivationFunction::Sqrt:
558 case ActivationFunction::Square:
559 case ActivationFunction::TanH:
560 {
561 m_Res = true;
562 break;
563 }
564 default:
565 {
566 m_Res = false;
567 break;
568 }
569 }
570 }
571 };
572
573 // Function is supported
574 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
575 "Reference activation: function not supported.");
576
577 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100578}
579
580bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
581 const TensorInfo& input1,
582 const TensorInfo& output,
583 Optional<std::string&> reasonIfUnsupported) const
584{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000585 bool supported = true;
586
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100587 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000588 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000589 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100590 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000591 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000592 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100593 DataType::QSymmS16,
594 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000595 };
596
597 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
598 "Reference addition: input 0 is not a supported type.");
599
600 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
601 "Reference addition: input 1 is not a supported type.");
602
603 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
604 "Reference addition: output is not a supported type.");
605
606 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
607 "Reference addition: input 0 and Input 1 types are mismatched");
608
609 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
610 "Reference addition: input and output types are mismatched");
611
612 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
613 "Reference addition: shapes are not suitable for implicit broadcast.");
614
615 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100616}
617
Nikhil Raj68c2c902019-09-19 11:21:11 +0100618bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
619 const armnn::ArgMinMaxDescriptor &descriptor,
620 armnn::Optional<std::string &> reasonIfUnsupported) const
621{
Jan Eilers8eb25602020-03-09 12:13:48 +0000622 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100623
Mike Kelly1f140f72021-04-06 12:25:55 +0100624 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100625 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000626 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100627 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100628 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100629 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000630 DataType::QAsymmU8,
631 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100632 DataType::Signed32,
633 DataType::Signed64
634 };
635
636 std::array<DataType,2> supportedOutputTypes = {
637 DataType::Signed32,
638 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100639 };
640
641 bool supported = true;
642
Mike Kelly1f140f72021-04-06 12:25:55 +0100643 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100644 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100645 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100646 "Reference ArgMinMax: output type not supported");
647
648 return supported;
649}
650
Samuel Yap6b478092022-07-06 15:36:03 +0100651bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
652 const TensorInfo& inputY,
653 const TensorInfo& output,
654 const BatchMatMulDescriptor& descriptor,
655 Optional<std::string &> reasonIfUnsupported) const
656{
657 IgnoreUnused(descriptor);
658
659 std::array<DataType, 6> supportedTypes =
660 {
661 DataType::BFloat16,
662 DataType::Float16,
663 DataType::Float32,
664 DataType::QAsymmS8,
665 DataType::QAsymmU8,
666 DataType::QSymmS16
667 };
668
669 bool supported = true;
670
671 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
672 "Reference batch matrix multiplication: input X is not a supported type");
673
674 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
675 "Reference batch matrix multiplication: input Y is not a supported type");
676
677 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
678 "Reference batch matrix multiplication: output is not a supported type");
679
680 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
681 "Reference batch matrix multiplication: input X and input Y types are mismatched");
682
683 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
684 "Reference batch matrix multiplication: inputs and output types are mismatched");
685
686 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
687 reasonIfUnsupported,
688 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
689
690 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
691 reasonIfUnsupported,
692 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
693
694 return supported;
695}
696
arovir011c7c81b2018-10-08 11:34:28 +0100697bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
698 const TensorInfo& output,
699 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100700 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100701 const TensorInfo& beta,
702 const TensorInfo& gamma,
703 const BatchNormalizationDescriptor& descriptor,
704 Optional<std::string&> reasonIfUnsupported) const
705{
Jan Eilers8eb25602020-03-09 12:13:48 +0000706 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100707
Sadik Armagan303980c2020-04-17 12:45:14 +0100708 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100709 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000710 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100711 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100712 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100713 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000714 DataType::QAsymmU8,
715 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100716 };
717
718 bool supported = true;
719
720 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
721 "Reference batch normalization: input is not a supported type.");
722
723 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
724 "Reference batch normalization: output is not a supported type.");
725
726 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
727 "Reference batch normalization: input and output types are mismatched");
728
729 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
730 "Reference batch normalization: mean is not a supported type.");
731
732 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
733 "Reference batch normalization: variance is not a supported type.");
734
735 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
736 "Reference batch normalization: beta is not a supported type.");
737
738 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
739 "Reference batch normalization: gamma is not a supported type.");
740
741 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100742}
743
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000744bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
745 const TensorInfo& output,
746 const BatchToSpaceNdDescriptor& descriptor,
747 Optional<std::string&> reasonIfUnsupported) const
748{
Jan Eilers8eb25602020-03-09 12:13:48 +0000749 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100750
751 bool supported = true;
752
753 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
754 std::string inputTensorStr = "input";
755 std::string outputTensorStr = "output";
756
757 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100758 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100759 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000760 DataType::BFloat16,
761 DataType::Float32,
762 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100763 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000764 DataType::QAsymmU8,
765 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100766 };
767
768 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
769 "Reference BatchToSpaceNd: input type not supported.");
770
771 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
772 "Reference BatchToSpaceNd: output type not supported.");
773
774 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
775 "Reference BatchToSpaceNd: input and output types mismatched.");
776
777 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
778 reasonIfUnsupported,
779 CreateIncorrectDimensionsErrorMsg(4,
780 output.GetNumDimensions(),
781 batchToSpaceNdLayerStr,
782 outputTensorStr).data());
783
784 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
785 reasonIfUnsupported,
786 CreateIncorrectDimensionsErrorMsg(4,
787 input.GetNumDimensions(),
788 batchToSpaceNdLayerStr,
789 inputTensorStr).data());
790
791 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000792}
793
mathad01b392e982021-04-07 12:07:30 +0100794bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
795 const TensorInfo& output,
796 Optional<std::string&> reasonIfUnsupported) const
797{
798 std::array<DataType, 9> supportedInputTypes =
799 {
800 DataType::BFloat16,
801 DataType::Float32,
802 DataType::Float16,
803 DataType::QSymmS8,
804 DataType::QAsymmS8,
805 DataType::QAsymmU8,
806 DataType::QSymmS16,
807 DataType::Signed32
808 };
809
810 bool supported = true;
811 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
812 "Reference cast: input is not a supported type");
813
814
815 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
816 "Reference cast: output is not a supported type");
817
818 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
819 "Reference cast: input and output shapes have different number of total elements");
820
821 return supported;
822}
823
Simon Obute51f67772021-09-03 15:50:13 +0100824bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
825 const TensorInfo& output,
826 const ChannelShuffleDescriptor& descriptor,
827 Optional<std::string&> reasonIfUnsupported) const
828{
829 IgnoreUnused(descriptor);
830 bool supported = true;
831
832 // Define supported output and inputs types.
833 std::array<DataType, 7> supportedTypes =
834 {
835 DataType::BFloat16,
836 DataType::Float32,
837 DataType::Float16,
838 DataType::QAsymmS8,
839 DataType::QAsymmU8,
840 DataType::QSymmS8,
841 DataType::QSymmS16
842 };
843
844 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
845 "Reference ChannelShuffle: input is not a supported type.");
846
847 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
848 "Reference ChannelShuffle: output is not a supported type.");
849
850 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
851 "Reference ChannelShuffle: input and output types are mismatched.");
852
853 return supported;
854}
855
856
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100857bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
858 const TensorInfo& input1,
859 const TensorInfo& output,
860 const ComparisonDescriptor& descriptor,
861 Optional<std::string&> reasonIfUnsupported) const
862{
Jan Eilers8eb25602020-03-09 12:13:48 +0000863 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100864 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100865 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000866 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000867 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100868 DataType::Float32,
869 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100870 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000871 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000872 DataType::QSymmS16,
873 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100874 };
875
876 bool supported = true;
877 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
878 "Reference comparison: input 0 is not a supported type");
879
880 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
881 "Reference comparison: input 0 and Input 1 types are mismatched");
882
883 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
884 "Reference comparison: output is not of type Boolean");
885
886 return supported;
887}
888
Jim Flynn906f9462019-05-10 13:55:21 +0100889bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
890 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000891 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100892 Optional<std::string&> reasonIfUnsupported) const
893{
Jan Eilers8eb25602020-03-09 12:13:48 +0000894 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100895
896 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000897 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100898 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000899 DataType::BFloat16,
900 DataType::Float32,
901 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000902 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100903 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000904 DataType::QSymmS16,
905 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100906 };
907
908 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
909 "Reference concatenation: output type not supported");
910 for (const TensorInfo* input : inputs)
911 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100912 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100913 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
914 "Reference concatenation: input type not supported");
915
916 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
917 "Reference concatenation: input and output types mismatched.");
918 }
919
920 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100921}
922
arovir011c7c81b2018-10-08 11:34:28 +0100923bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
924 Optional<std::string&> reasonIfUnsupported) const
925{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100926 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100927 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000928 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100929 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100930 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000931 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100932 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000933 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100934 DataType::QSymmS16,
935 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100936 };
937
938 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
939 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100940}
941
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000942bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
943 const TensorInfo& output,
944 Optional<std::string&> reasonIfUnsupported) const
945{
946 bool supported = true;
947
948 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
949 "Reference for ConvertBf16ToFp32 layer: input type not supported");
950
951 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
952 "Reference for ConvertBf16ToFp32 layer: output type not supported");
953
954 return supported;
955}
956
arovir011c7c81b2018-10-08 11:34:28 +0100957bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
958 const TensorInfo& output,
959 Optional<std::string&> reasonIfUnsupported) const
960{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100961 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
962 input.GetDataType(),
963 &TrueFunc<>,
964 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000965 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000966 &FalseFuncI32<>,
967 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100968 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
969 output.GetDataType(),
970 &FalseOutputFuncF16<>,
971 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000972 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000973 &FalseFuncI32<>,
974 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100975}
976
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000977bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
978 const TensorInfo& output,
979 Optional<std::string&> reasonIfUnsupported) const
980{
981 bool supported = true;
982
983 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
984 "Reference for ConvertFp32ToBf16 layer: input type not supported");
985
986 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
987 "Reference for ConvertFp32ToBf16 layer: output type not supported");
988
989 return supported;
990}
991
arovir011c7c81b2018-10-08 11:34:28 +0100992bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
993 const TensorInfo& output,
994 Optional<std::string&> reasonIfUnsupported) const
995{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100996 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
997 input.GetDataType(),
998 &FalseInputFuncF16<>,
999 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001000 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001001 &FalseFuncI32<>,
1002 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001003 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1004 output.GetDataType(),
1005 &TrueFunc<>,
1006 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001007 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001008 &FalseFuncI32<>,
1009 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001010}
1011
1012bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1013 const TensorInfo& output,
1014 const Convolution2dDescriptor& descriptor,
1015 const TensorInfo& weights,
1016 const Optional<TensorInfo>& biases,
1017 Optional<std::string&> reasonIfUnsupported) const
1018{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001019 bool supported = true;
1020
1021 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001022 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001023 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001024 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001025 DataType::Float32,
1026 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001027 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001028 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001029 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001030 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001031 };
1032
1033 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001034 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001035
1036 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001037 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001038
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001039 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
1040 if (input.GetDataType() == DataType::BFloat16)
1041 {
1042 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1043 {
1044 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1045 supported = false;
1046 }
1047 }
1048 else
1049 {
1050 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001051 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001052 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001053
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001054 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001055 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001056 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001057 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001058 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001059 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001060 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001061 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001062 };
1063
1064 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001065 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001066 }
1067 else
1068 {
1069 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001070 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001071
1072 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001073 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001074 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001075
1076 if (biases.has_value())
1077 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001078 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001079 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001080 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001081 DataType::Float32,
1082 DataType::Float16,
1083 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001084 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001085
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001086 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001087 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001088 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001089 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001090
1091 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001092}
1093
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001094bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1095 const TensorInfo& output,
1096 const Convolution3dDescriptor& descriptor,
1097 const TensorInfo& weights,
1098 const Optional<TensorInfo>& biases,
1099 Optional<std::string&> reasonIfUnsupported) const
1100{
1101 bool supported = true;
1102
1103 // Define supported types.
1104 std::array<DataType,7> supportedTypes =
1105 {
1106 DataType::BFloat16,
1107 DataType::Float32,
1108 DataType::Float16,
1109 DataType::QAsymmS8,
1110 DataType::QAsymmU8,
1111 DataType::QSymmS8,
1112 DataType::QSymmS16
1113 };
1114
1115 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1116 "Reference Convolution3d: input is not a supported type.");
1117
1118 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1119 "Reference Convolution3d: output is not a supported type.");
1120
1121 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1122 "Reference Convolution3d: input and output types mismatched.");
1123
1124 const DataType inputType = input.GetDataType();
1125 if (IsQuantized8BitType(inputType))
1126 {
1127 std::array<DataType, 3> supportedWeightTypes =
1128 {
1129 DataType::QAsymmS8,
1130 DataType::QAsymmU8,
1131 DataType::QSymmS8
1132 };
1133
1134 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1135 "Reference Convolution3d: weights type not supported for quantized input.");
1136 }
1137 else
1138 {
1139 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1140 "Reference Convolution3d: weights is not a supported type.");
1141
1142 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1143 "Reference Convolution3d: input and weights types mismatched.");
1144 }
1145
1146 if (biases.has_value())
1147 {
1148 std::array<DataType,4> biasesSupportedTypes =
1149 {
1150 DataType::BFloat16,
1151 DataType::Float32,
1152 DataType::Float16,
1153 DataType::Signed32
1154 };
1155
1156 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1157 "Reference Convolution3d: biases is not a supported type.");
1158 }
1159 IgnoreUnused(descriptor);
1160
1161 return supported;
1162}
1163
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001164bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1165 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001166 Optional<std::string&> reasonIfUnsupported) const
1167{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001168 bool supported = true;
1169
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001170 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001171 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001172 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001173 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001174 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001175 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001176 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001177 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001178 DataType::QSymmS16,
1179 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001180 };
1181
1182 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001183 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001184
1185 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001186 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001187
1188 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001189 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001190
1191 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001192}
1193
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001194bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1195 const TensorInfo& output,
1196 const DepthToSpaceDescriptor& descriptor,
1197 Optional<std::string&> reasonIfUnsupported) const
1198{
Jan Eilers8eb25602020-03-09 12:13:48 +00001199 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001200 bool supported = true;
1201
Sadik Armagan303980c2020-04-17 12:45:14 +01001202 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001203 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001204 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001205 DataType::Float32,
1206 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001207 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001208 DataType::QAsymmU8,
1209 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001210 };
1211
1212 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1213 "Reference DepthToSpace: input type not supported");
1214
1215 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1216 "Reference DepthToSpace: output type not supported");
1217
1218 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1219 "Reference DepthToSpace: input and output types are mismatched");
1220
1221 return supported;
1222}
1223
arovir011c7c81b2018-10-08 11:34:28 +01001224bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1225 const TensorInfo& output,
1226 const DepthwiseConvolution2dDescriptor& descriptor,
1227 const TensorInfo& weights,
1228 const Optional<TensorInfo>& biases,
1229 Optional<std::string&> reasonIfUnsupported) const
1230{
Sadik Armagan303980c2020-04-17 12:45:14 +01001231 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001232 bool supported = true;
1233
1234 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001235 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001236 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001237 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001238 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001239 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001240 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001241 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001242 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001243 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001244 };
1245
1246 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1247 "Reference DepthwiseConvolution2d: input is not a supported type.");
1248
1249 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1250 "Reference DepthwiseConvolution2d: output is not a supported type.");
1251
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001252 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1253 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1254
Teresa Charlind8df0262019-11-11 12:28:15 +00001255 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001256 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001257 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001258 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001259 {
1260 DataType::QAsymmS8,
1261 DataType::QAsymmU8,
1262 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001263 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001264
1265 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001266 "Reference DepthwiseConvolution2d: weights type not supported for "
1267 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001268 }
1269 else
1270 {
1271 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1272 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1273
1274 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1275 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1276 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001277
1278 if (biases.has_value())
1279 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001280 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001281 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001282 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001283 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001284 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001285 DataType::Signed32
1286 };
1287 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1288 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1289 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001290
1291 return supported;
1292
arovir011c7c81b2018-10-08 11:34:28 +01001293}
1294
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001295bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1296 const TensorInfo& output,
1297 Optional<std::string&> reasonIfUnsupported) const
1298{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001299 bool supported = true;
1300
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001301 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001302 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001303 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001304 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001305 DataType::QSymmS16,
1306 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001307 };
1308
1309 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001310 "Reference for Dequantize layer: input type not supported.");
1311
Derek Lambertid466a542020-01-22 15:37:29 +00001312 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001313 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001314
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001315 std::array<DataType,3> supportedOutputTypes = {
1316 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +00001317 DataType::Float32,
1318 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001319 };
1320
1321 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001322 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001323
1324 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001325 "Reference for Dequantize layer: input/output shapes have different num total "
1326 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001327
1328 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001329}
1330
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001331bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1332 const TensorInfo& scores,
1333 const TensorInfo& anchors,
1334 const TensorInfo& detectionBoxes,
1335 const TensorInfo& detectionClasses,
1336 const TensorInfo& detectionScores,
1337 const TensorInfo& numDetections,
1338 const DetectionPostProcessDescriptor& descriptor,
1339 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001340{
Jan Eilers8eb25602020-03-09 12:13:48 +00001341 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001342
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001343 bool supported = true;
1344
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001345 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001346 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001347 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001348 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001349 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001350 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001351 DataType::QAsymmU8,
1352 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001353 };
1354
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001355 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001356 "Reference DetectionPostProcess: input 0 is not a supported type.");
1357
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001358 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001359 "Reference DetectionPostProcess: input 1 is not a supported type.");
1360
1361 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001362}
1363
Pablo Tellof0bd6832019-04-26 17:58:13 +01001364bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1365 const TensorInfo& output,
1366 const DepthwiseConvolution2dDescriptor& descriptor,
1367 const TensorInfo& weights,
1368 const Optional<TensorInfo>& biases,
1369 Optional<std::string&> reasonIfUnsupported) const
1370{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001371 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +01001372}
1373
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001374bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001375 const TensorInfo& input1,
1376 const TensorInfo& output,
1377 Optional<std::string&> reasonIfUnsupported) const
1378{
Sadik Armagan2999a022019-04-09 14:20:12 +01001379 bool supported = true;
1380
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001381 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001382 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001383 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001384 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001385 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001386 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001387 DataType::QSymmS16,
1388 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001389 };
1390
1391 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1392 "Reference division: input 0 is not a supported type.");
1393
1394 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1395 "Reference division: input 1 is not a supported type.");
1396
1397 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1398 "Reference division: output is not a supported type.");
1399
1400 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1401 "Reference division: input 0 and Input 1 types are mismatched");
1402
1403 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1404 "Reference division: input and output types are mismatched");
1405
1406 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1407 "Reference division: shapes are not suitable for implicit broadcast.");
1408
1409 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001410}
1411
josh minor4a3c6102020-01-06 16:40:46 -06001412bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1413 const TensorInfo& output,
1414 const ElementwiseUnaryDescriptor& descriptor,
1415 Optional<std::string&> reasonIfUnsupported) const
1416{
Jan Eilers8eb25602020-03-09 12:13:48 +00001417 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001418
Sadik Armagan303980c2020-04-17 12:45:14 +01001419 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001420 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001421 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06001422 DataType::Float32,
1423 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001424 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001425 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001426 DataType::QSymmS16,
1427 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001428 };
1429
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001430 std::array<DataType, 1> logicalSupportedTypes =
1431 {
1432 DataType::Boolean
1433 };
1434
josh minor4a3c6102020-01-06 16:40:46 -06001435 bool supported = true;
1436
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001437 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1438 {
1439 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1440 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001441
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001442 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1443 "Reference elementwise unary: output type not supported");
1444 }
1445 else
1446 {
1447 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1448 "Reference elementwise unary: input type not supported");
1449
1450 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1451 "Reference elementwise unary: output type not supported");
1452 }
josh minor4a3c6102020-01-06 16:40:46 -06001453
1454 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1455 "Reference elementwise unary: input and output types not matching");
1456
1457 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1458 "Reference elementwise unary: input and output shapes"
1459 "have different number of total elements");
1460
1461 return supported;
1462}
1463
arovir011c7c81b2018-10-08 11:34:28 +01001464bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1465 const FakeQuantizationDescriptor& descriptor,
1466 Optional<std::string&> reasonIfUnsupported) const
1467{
Jan Eilers8eb25602020-03-09 12:13:48 +00001468 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001469 bool supported = true;
1470
1471 std::array<DataType,1> supportedTypes =
1472 {
1473 DataType::Float32
1474 };
1475
1476 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1477 "Reference fake quantization: input type not supported.");
1478
1479 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001480}
1481
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001482bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1483 const TensorInfo& output,
1484 const FillDescriptor& descriptor,
1485 Optional<std::string&> reasonIfUnsupported) const
1486{
1487 IgnoreUnused(descriptor);
1488 IgnoreUnused(output);
1489
1490 bool supported = true;
1491
Sadik Armagana792a052020-06-23 16:22:23 +01001492 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001493 {
1494 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001495 DataType::Float16,
1496 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001497 };
1498
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001499 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001500 "Reference Fill: input type not supported.");
1501
Teresa Charlin44088502020-07-27 11:27:19 +01001502 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1503 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001504 return supported;
1505}
1506
arovir011c7c81b2018-10-08 11:34:28 +01001507bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1508 const TensorInfo& output,
1509 Optional<std::string&> reasonIfUnsupported) const
1510{
Jan Eilers8eb25602020-03-09 12:13:48 +00001511 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001512 bool supported = true;
1513
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001514 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001515 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001516 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +01001517 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001518 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001519 };
1520
1521 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1522 "Reference Floor: input type not supported.");
1523
1524 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1525 "Reference Floor: output type not supported.");
1526
1527 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001528}
1529
1530bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1531 const TensorInfo& output,
1532 const TensorInfo& weights,
1533 const TensorInfo& biases,
1534 const FullyConnectedDescriptor& descriptor,
1535 Optional<std::string&> reasonIfUnsupported) const
1536{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001537 bool supported = true;
1538
1539 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001540 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001541 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001542 DataType::BFloat16,
1543 DataType::Float32,
1544 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001545 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001546 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001547 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001548 };
1549
1550 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1551 "Reference Fully Connected: input type not supported.");
1552
1553 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1554 "Reference Fully Connected: output type not supported.");
1555
Francis Murtagh46c09d02019-05-28 08:15:28 +01001556 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1557 "Reference Fully Connected: weights type not supported.");
1558
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001559 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1560 if (input.GetDataType() == DataType::BFloat16)
1561 {
1562 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1563 {
1564 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1565 supported = false;
1566 }
1567 }
1568 else
1569 {
1570 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1571 "Reference Fully Connected: input and output types mismatched.");
1572 }
1573
Jan Eilers1f45dc32020-06-15 11:43:03 +01001574 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1575 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001576
Jan Eilers1f45dc32020-06-15 11:43:03 +01001577 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1578 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001579
1580 if (descriptor.m_BiasEnabled)
1581 {
1582 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001583 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001584 supportedBiasTypes =
1585 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001586 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001587 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001588 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001589 DataType::Signed32,
1590 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001591 };
1592
1593 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1594 "Reference Fully Connected: bias type not supported.");
1595
1596 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1597 "Reference Fully Connected: bias and weight types mismatch.");
1598
1599 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1600 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1601
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001602 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1603 "Reference Fully Connected: bias must have 1 dimension.");
1604
Francis Murtagh46c09d02019-05-28 08:15:28 +01001605 }
1606
1607 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001608}
1609
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001610bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1611 const armnn::TensorInfo& input1,
1612 const armnn::TensorInfo& output,
1613 armnn::Optional<std::string&> reasonIfUnsupported) const
1614{
1615 bool supported = true;
1616 std::array<DataType,7> supportedTypes =
1617 {
1618 DataType::BFloat16,
1619 DataType::Float32,
1620 DataType::Float16,
1621 DataType::QAsymmS8,
1622 DataType::QAsymmU8,
1623 DataType::QSymmS16,
1624 DataType::Signed32
1625 };
1626
1627 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1628 "Reference GatherNd: input type not supported");
1629
1630 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1631 "Reference GatherNd: output type not supported");
1632
1633 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1634 "Reference GatherNd: indices (input1) type not supported");
1635
1636 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1637 "Reference GatherNd: input and output types not matching");
1638
1639 return supported;
1640}
1641
narpra014951d842019-01-18 16:53:53 +00001642bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1643 const armnn::TensorInfo& input1,
1644 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001645 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001646 armnn::Optional<std::string&> reasonIfUnsupported) const
1647{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001648 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001649 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001650 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001651 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001652 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001653 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001654 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001655 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001656 DataType::QSymmS16,
1657 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001658 };
1659
Teresa Charlin52664732020-06-29 16:27:03 +01001660 if (descriptor.m_Axis != 0)
1661 {
1662 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1663 supported &= false;
1664 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001665 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1666 "Reference Gather: input type not supported");
1667
1668 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1669 "Reference Gather: output type not supported");
1670
1671 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1672 "Reference Gather: indices (input1) type not supported");
1673
1674 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1675 "Reference Gather: input and output types not matching");
1676
1677 return supported;
narpra014951d842019-01-18 16:53:53 +00001678}
1679
Derek Lamberti901ea112019-12-10 22:07:09 +00001680bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1681 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001682{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001683 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001684}
1685
Kevin May09ca49c2019-10-09 12:37:34 +01001686bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1687 const TensorInfo& output,
1688 const InstanceNormalizationDescriptor& descriptor,
1689 Optional<std::string&> reasonIfUnsupported) const
1690{
Jan Eilers8eb25602020-03-09 12:13:48 +00001691 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001692 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001693 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001694 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001695 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001696 DataType::Float32,
1697 DataType::Float16
1698 };
1699
1700 bool supported = true;
1701
1702 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1703 "Reference Instance Normalization: input type not supported.");
1704
1705 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1706 "Reference Instance Normalization: output type not supported.");
1707
1708 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1709 "Reference Instance Normalization: input and output types mismatched.");
1710
1711 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1712 "Reference Instance Normalization: input and output shapes have different "
1713 "num total elements.");
1714
1715 return supported;
1716}
1717
arovir011c7c81b2018-10-08 11:34:28 +01001718bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1719 const TensorInfo& output,
1720 const L2NormalizationDescriptor& descriptor,
1721 Optional<std::string&> reasonIfUnsupported) const
1722{
Jan Eilers8eb25602020-03-09 12:13:48 +00001723 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001724 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001725 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001726 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001727 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001728 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001729 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001730 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001731 DataType::QAsymmU8,
1732 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001733 };
1734
1735 bool supported = true;
1736
1737 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1738 "Reference L2normalization: input type not supported.");
1739
1740 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1741 "Reference L2normalization: output type not supported.");
1742
1743 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1744 "Reference L2normalization: input and output types mismatched.");
1745
1746 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1747 "Reference L2normalization: input and output shapes have different "
1748 "num total elements.");
1749
1750 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001751}
1752
James Conroyaba90cd2020-11-06 16:28:18 +00001753bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1754 const TensorInfo& input1,
1755 const TensorInfo& output,
1756 const LogicalBinaryDescriptor& descriptor,
1757 Optional<std::string&> reasonIfUnsupported) const
1758{
1759 IgnoreUnused(descriptor);
1760
1761 std::array<DataType, 1> supportedTypes =
1762 {
1763 DataType::Boolean
1764 };
1765
1766 bool supported = true;
1767 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1768 "Reference LogicalBinary: input 0 type not supported");
1769 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1770 "Reference LogicalBinary: input 1 type not supported");
1771
1772 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1773 "Reference LogicalBinary: input and output types do not match");
1774
1775 return supported;
1776}
1777
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001778bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1779 const TensorInfo& output,
1780 const LogSoftmaxDescriptor& descriptor,
1781 Optional<std::string&> reasonIfUnsupported) const
1782{
Jan Eilers8eb25602020-03-09 12:13:48 +00001783 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001784
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001785 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001786 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001787 DataType::BFloat16,
1788 DataType::Float32,
1789 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001790 };
1791
1792 bool supported = true;
1793 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1794 "Reference LogSoftmax: input type not supported");
1795
1796 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1797 "Reference LogSoftmax: output type not supported");
1798
1799 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1800 "Reference LogSoftmax: input and output types do not match");
1801
1802 return supported;
1803}
1804
arovir011c7c81b2018-10-08 11:34:28 +01001805bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1806 const TensorInfo& outputStateIn,
1807 const TensorInfo& cellStateIn,
1808 const TensorInfo& scratchBuffer,
1809 const TensorInfo& outputStateOut,
1810 const TensorInfo& cellStateOut,
1811 const TensorInfo& output,
1812 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001813 const LstmInputParamsInfo& paramsInfo,
1814 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001815{
Jan Eilers8eb25602020-03-09 12:13:48 +00001816 IgnoreUnused(descriptor);
1817 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001818
1819 bool supported = true;
1820
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001821 std::array<DataType,3> supportedTypes = {
1822 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001823 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001824 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001825 };
1826
Jan Eilersd01a83c2019-07-03 18:20:40 +01001827 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001828 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1829 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001830 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1831 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001832 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1833 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001834 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1835 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001836 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1837 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001838 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1839 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001840
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001841 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1842 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001843 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001844 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001845 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001846 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001847 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001848 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001849 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001850 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001851 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001852 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001853 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001854 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001855 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001856 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001857 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001858 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001859 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001860 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001861 "Reference Lstm: input and OutputGateBias types are mismatched");
1862 if (!descriptor.m_CifgEnabled)
1863 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001864 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001865 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001866 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001867 reasonIfUnsupported,
1868 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001869 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001870 "Reference Lstm: input and InputGateBias types are mismatched");
1871 if (descriptor.m_PeepholeEnabled)
1872 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001873 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001874 reasonIfUnsupported,
1875 "Reference Lstm: input and CellToInputWeights types are mismatched");
1876 }
1877 }
1878 if (descriptor.m_PeepholeEnabled)
1879 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001880 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001881 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001882 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001883 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1884 }
1885 if (descriptor.m_ProjectionEnabled)
1886 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001887 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001888 "Reference Lstm: input and mProjectionWeights types are mismatched");
1889 if (paramsInfo.m_ProjectionBias != nullptr)
1890 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001891 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001892 "Reference Lstm: input and ProjectionBias types are mismatched");
1893 }
1894 }
1895 if (descriptor.m_LayerNormEnabled)
1896 {
1897 if (!descriptor.m_CifgEnabled)
1898 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001899 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001900 reasonIfUnsupported,
1901 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1902 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001903 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001904 reasonIfUnsupported,
1905 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001906 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001907 reasonIfUnsupported,
1908 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001909 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001910 reasonIfUnsupported,
1911 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1912 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001913
1914 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001915}
1916
saoste012df12b32018-11-28 16:57:20 +00001917bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1918 const TensorInfo& input1,
1919 const TensorInfo& output,
1920 Optional<std::string&> reasonIfUnsupported) const
1921{
Sadik Armagan2999a022019-04-09 14:20:12 +01001922 bool supported = true;
1923
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001924 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001925 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001926 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001927 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001928 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001929 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001930 DataType::QSymmS16,
1931 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001932 };
1933
1934 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1935 "Reference maximum: input 0 is not a supported type.");
1936
1937 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1938 "Reference maximum: input 1 is not a supported type.");
1939
1940 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1941 "Reference maximum: output is not a supported type.");
1942
1943 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1944 "Reference maximum: input 0 and Input 1 types are mismatched");
1945
1946 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1947 "Reference maximum: input and output types are mismatched");
1948
1949 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1950 "Reference maximum: shapes are not suitable for implicit broadcast.");
1951
1952 return supported;
saoste012df12b32018-11-28 16:57:20 +00001953}
1954
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001955bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1956 const TensorInfo& output,
1957 const MeanDescriptor& descriptor,
1958 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001959{
James Conroy4d1ff582019-06-10 17:06:39 +01001960 bool supported = true;
1961 std::string meanLayerStr = "Mean";
1962 std::string outputTensorStr = "output";
1963
Sadik Armagan303980c2020-04-17 12:45:14 +01001964 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001965 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001966 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001967 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001968 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001969 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001970 DataType::QAsymmU8,
1971 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001972 };
1973
1974 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1975 "Reference Mean: input type not supported.");
1976
1977 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1978 "Reference Mean: input and output types are mismatched");
1979
1980 if (descriptor.m_KeepDims)
1981 {
1982 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1983 reasonIfUnsupported,
1984 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1985 output.GetNumDimensions(),
1986 meanLayerStr, outputTensorStr).data());
1987 }
1988 else if (descriptor.m_Axis.empty())
1989 {
1990 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1991 reasonIfUnsupported,
1992 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1993 meanLayerStr, outputTensorStr).data());
1994 }
1995 else
1996 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001997 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001998
1999 if (outputDim > 0)
2000 {
2001 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
2002 reasonIfUnsupported,
2003 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
2004 meanLayerStr, outputTensorStr).data());
2005 }
2006 else
2007 {
2008 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
2009 reasonIfUnsupported,
2010 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
2011 meanLayerStr, outputTensorStr).data());
2012 }
2013 }
2014
2015 return supported;
narpra0132b90462018-09-13 11:07:48 +01002016}
2017
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002018bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
2019 const TensorInfo &output,
2020 Optional<std::string &> reasonIfUnsupported) const
2021{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002022 bool supported = true;
2023
Sadik Armagan303980c2020-04-17 12:45:14 +01002024 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002025 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002026 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002027 DataType::Float32,
2028 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002029 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002030 DataType::QAsymmU8,
2031 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002032 DataType::Boolean
2033 };
2034
2035 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2036 "Reference MemCopy: input type not supported");
2037
2038 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2039 "Reference MemCopy: output type not supported");
2040
2041 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2042 "Reference MemCopy: input and output types are mismatched");
2043
2044 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002045}
2046
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002047bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2048 const TensorInfo& input1,
2049 const TensorInfo& output,
2050 Optional<std::string&> reasonIfUnsupported) const
2051{
Sadik Armagan2999a022019-04-09 14:20:12 +01002052 bool supported = true;
2053
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002054 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002055 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002056 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002057 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002058 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002059 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002060 DataType::QSymmS16,
2061 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002062 };
2063
2064 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2065 "Reference minimum: input 0 is not a supported type.");
2066
2067 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2068 "Reference minimum: input 1 is not a supported type.");
2069
2070 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2071 "Reference minimum: output is not a supported type.");
2072
2073 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2074 "Reference minimum: input 0 and Input 1 types are mismatched");
2075
2076 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2077 "Reference minimum: input and output types are mismatched");
2078
2079 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2080 "Reference minimum: shapes are not suitable for implicit broadcast.");
2081
2082 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002083}
2084
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002085bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2086 const TensorInfo& input1,
2087 const TensorInfo& output,
2088 Optional<std::string&> reasonIfUnsupported) const
2089{
Sadik Armagan2999a022019-04-09 14:20:12 +01002090 bool supported = true;
2091
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002092 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002093 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002094 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002095 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002096 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002097 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002098 DataType::QSymmS16,
2099 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002100 };
2101
2102 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2103 "Reference multiplication: input 0 is not a supported type.");
2104
2105 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2106 "Reference multiplication: input 1 is not a supported type.");
2107
2108 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2109 "Reference multiplication: output is not a supported type.");
2110
2111 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2112 "Reference multiplication: input 0 and Input 1 types are mismatched");
2113
2114 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2115 "Reference multiplication: input and output types are mismatched");
2116
2117 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2118 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2119
2120 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002121}
2122
2123bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2124 const TensorInfo& output,
2125 const NormalizationDescriptor& descriptor,
2126 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002127{
Jan Eilers8eb25602020-03-09 12:13:48 +00002128 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002129
2130 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002131 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002132 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002133 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002134 DataType::Float16,
2135 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002136 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002137 DataType::QAsymmU8,
2138 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002139 };
2140
2141 bool supported = true;
2142
2143 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2144 "Reference normalization: input type not supported.");
2145
2146 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2147 "Reference normalization: output type not supported.");
2148
2149 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2150 "Reference normalization: input and output shapes have different "
2151 "num total elements.");
2152
2153 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002154}
2155
Derek Lamberti901ea112019-12-10 22:07:09 +00002156bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2157 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002158{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002159 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002160}
2161
2162bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2163 const TensorInfo& output,
2164 const PadDescriptor& descriptor,
2165 Optional<std::string&> reasonIfUnsupported) const
2166{
Jan Eilers8eb25602020-03-09 12:13:48 +00002167 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002168 bool supported = true;
2169
2170 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002171 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002172 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002173 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002174 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002175 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002176 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002177 DataType::QAsymmU8,
2178 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002179 };
2180
2181 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2182 "Reference pad: input is not a supported type.");
2183
2184 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2185 "Reference pad: output is not a supported type.");
2186
2187 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2188 "Reference pad: input and output types are mismatched.");
2189
2190 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002191}
2192
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002193bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2194 const TensorInfo& output,
2195 const PermuteDescriptor& descriptor,
2196 Optional<std::string&> reasonIfUnsupported) const
2197{
Jan Eilers8eb25602020-03-09 12:13:48 +00002198 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002199 bool supported = true;
2200
2201 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002202 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002203 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002204 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002205 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002206 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002207 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002208 DataType::QAsymmU8,
2209 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002210 };
2211
2212 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2213 "Reference permute: input is not a supported type.");
2214
2215 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2216 "Reference permute: output is not a supported type.");
2217
2218 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2219 "Reference permute: input and output types are mismatched.");
2220
2221 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002222}
2223
2224bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2225 const TensorInfo& output,
2226 const Pooling2dDescriptor& descriptor,
2227 Optional<std::string&> reasonIfUnsupported) const
2228{
Jan Eilers8eb25602020-03-09 12:13:48 +00002229 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002230 bool supported = true;
2231
2232 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002233 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002234 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002235 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01002236 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002237 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002238 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002239 DataType::QAsymmU8,
2240 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002241 };
2242
2243 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2244 "Reference poolind2d: input is not a supported type.");
2245
2246 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2247 "Reference poolind2d: output is not a supported type.");
2248
2249 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2250 "Reference poolind2d: input and output types are mismatched.");
2251
2252 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002253}
2254
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002255bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2256 const TensorInfo& output,
2257 const Pooling3dDescriptor& descriptor,
2258 Optional<std::string&> reasonIfUnsupported) const
2259{
2260 IgnoreUnused(descriptor);
2261 bool supported = true;
2262
2263 // Define supported output and inputs types.
2264 std::array<DataType,6> supportedTypes =
2265 {
2266 DataType::BFloat16,
2267 DataType::Float32,
2268 DataType::Float16,
2269 DataType::QAsymmS8,
2270 DataType::QAsymmU8,
2271 DataType::QSymmS16
2272 };
2273
2274 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2275 "Reference poolind3d: input is not a supported type.");
2276
2277 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2278 "Reference poolind3d: output is not a supported type.");
2279
2280 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2281 "Reference poolind3d: input and output types are mismatched.");
2282
2283 return supported;
2284}
2285
2286
James Conroy4f1f8992020-04-29 20:01:10 +01002287bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2288 const TensorInfo& previousOutputIn,
2289 const TensorInfo& previousCellStateIn,
2290 const TensorInfo& outputStateOut,
2291 const TensorInfo& cellStateOut,
2292 const TensorInfo& output,
2293 const QLstmDescriptor& descriptor,
2294 const LstmInputParamsInfo& paramsInfo,
2295 Optional<std::string&> reasonIfUnsupported) const
2296{
2297 IgnoreUnused(input);
2298 IgnoreUnused(previousOutputIn);
2299 IgnoreUnused(previousCellStateIn);
2300 IgnoreUnused(outputStateOut);
2301 IgnoreUnused(cellStateOut);
2302 IgnoreUnused(output);
2303 IgnoreUnused(descriptor);
2304 IgnoreUnused(paramsInfo);
2305
2306 IgnoreUnused(reasonIfUnsupported);
2307
2308 return true;
2309}
2310
Derek Lamberti5f400d62019-03-25 15:41:58 +00002311bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2312 const TensorInfo& output,
2313 Optional<std::string&> reasonIfUnsupported) const
2314{
2315 bool supported = true;
2316
Finn Williamsfd271062019-12-04 14:27:27 +00002317 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002318 std::array<DataType,7> supportedInputTypes = {
2319 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00002320 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002321 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002322 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002323 DataType::QAsymmU8,
2324 DataType::QSymmS8,
2325 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002326 };
2327
2328 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2329 "Reference quantize: input type not supported.");
2330
2331 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002332 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002333 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002334 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002335 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002336 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002337 };
2338 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2339 "Reference quantize: output type not supported.");
2340
2341 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2342 "Reference quantize: input and output shapes have different num total elements.");
2343
2344 return supported;
2345}
2346
Finn Williams2605b232020-06-10 15:53:46 +01002347bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2348 const TensorInfo& output,
2349 Optional<std::string&> reasonIfUnsupported) const
2350{
2351 IgnoreUnused(input);
2352 // Define supported output types.
2353 std::array<DataType,1> supportedOutputTypes =
2354 {
2355 DataType::Signed32,
2356 };
2357
2358 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2359 "Reference rank: input type not supported.");
2360}
2361
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002362bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2363 const TensorInfo& output,
2364 const ReduceDescriptor& descriptor,
2365 Optional<std::string&> reasonIfUnsupported) const
2366{
2367 IgnoreUnused(descriptor);
2368 bool supported = true;
2369 std::array<DataType,7> supportedTypes =
2370 {
2371 DataType::BFloat16,
2372 DataType::Float32,
2373 DataType::Float16,
2374 DataType::QAsymmS8,
2375 DataType::QAsymmU8,
2376 DataType::QSymmS16,
2377 DataType::Signed32
2378 };
2379
2380 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2381 "Reference Reduce: input type not supported");
2382
2383 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2384 "Reference Reduce: output type not supported");
2385
2386 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2387 "Reference Reduce: input and output types not matching");
2388
2389 return supported;
2390}
2391
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002392bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002393 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002394 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002395 Optional<std::string&> reasonIfUnsupported) const
2396{
Jan Eilers8eb25602020-03-09 12:13:48 +00002397 IgnoreUnused(output);
2398 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002399 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002400 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002401 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002402 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002403 DataType::Float32,
2404 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002405 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002406 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002407 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002408 DataType::QSymmS16,
2409 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002410 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002411
Nina Drozd2f2778f2019-05-27 10:37:05 +01002412 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2413 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002414}
2415
Teresa Charlin970f43b2019-07-01 13:51:07 +01002416bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2417 const TensorInfo& output,
2418 const ResizeDescriptor& descriptor,
2419 Optional<std::string&> reasonIfUnsupported) const
2420{
Jan Eilers8eb25602020-03-09 12:13:48 +00002421 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002422 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002423 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002424 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002425 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002426 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002427 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002428 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002429 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002430 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002431 };
2432
2433 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2434 "Reference Resize: input type not supported");
2435
2436 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2437 "Reference Resize: output type not supported");
2438
2439 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2440 "Reference Resize: input and output types not matching");
2441
2442 return supported;
2443}
2444
Keith Davis3ae3f972021-05-21 16:33:48 +01002445bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2446 const TensorInfo& output,
2447 Optional<std::string&> reasonIfUnsupported) const
2448{
2449 IgnoreUnused(input);
2450 bool supported = true;
2451
2452 std::array<DataType, 1> supportedTypes =
2453 {
2454 DataType::Signed32
2455 };
2456
2457 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2458 "Reference Shape: output type not supported");
2459
2460 return supported;
2461}
2462
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002463bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2464 const TensorInfo& output,
2465 const SliceDescriptor& descriptor,
2466 Optional<std::string&> reasonIfUnsupported) const
2467{
Jan Eilers8eb25602020-03-09 12:13:48 +00002468 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002469 bool supported = true;
2470
Sadik Armagan303980c2020-04-17 12:45:14 +01002471 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002472 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002473 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002474 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002475 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002476 DataType::QAsymmU8,
2477 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002478 };
2479
2480 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2481 "Reference Slice: input type not supported");
2482
2483 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2484 "Reference Slice: output type not supported");
2485
2486 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2487 "Reference Slice: input and output types are mismatched");
2488
2489 return supported;
2490}
2491
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002492bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2493 const TensorInfo& output,
2494 const SoftmaxDescriptor& descriptor,
2495 Optional<std::string&> reasonIfUnsupported) const
2496{
Jan Eilers8eb25602020-03-09 12:13:48 +00002497 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002498 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002499 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002500 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002501 DataType::BFloat16,
2502 DataType::Float32,
2503 DataType::Float16,
2504 DataType::QSymmS8,
2505 DataType::QAsymmS8,
2506 DataType::QAsymmU8,
2507 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002508 };
2509
2510 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002511 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002512
2513 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002514 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002515
2516 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002517 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002518
2519 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002520}
2521
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002522bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2523 const TensorInfo& output,
2524 const SpaceToBatchNdDescriptor& descriptor,
2525 Optional<std::string&> reasonIfUnsupported) const
2526{
Jan Eilers8eb25602020-03-09 12:13:48 +00002527 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002528 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002529 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002530 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002531 DataType::BFloat16,
2532 DataType::Float32,
2533 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002534 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002535 DataType::QAsymmU8,
2536 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002537 };
2538
2539 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2540 "Reference SpaceToBatchNd: input type not supported");
2541
2542 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2543 "Reference SpaceToBatchNd: output type not supported");
2544
2545 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2546 "Reference SpaceToBatchNd: input and output types are mismatched");
2547
2548 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002549}
2550
Keith Davisa57eccb2019-06-14 17:33:22 +01002551bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002552 const TensorInfo& output,
2553 const SpaceToDepthDescriptor& descriptor,
2554 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002555{
2556
Jan Eilers8eb25602020-03-09 12:13:48 +00002557 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002558 bool supported = true;
2559
Sadik Armagan303980c2020-04-17 12:45:14 +01002560 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002561 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002562 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002563 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002564 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002565 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002566 DataType::QAsymmU8,
2567 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002568 };
2569
2570 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2571 "Reference SpaceToDepth: input type not supported");
2572
2573 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2574 "Reference SpaceToDepth: output type not supported");
2575
2576 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2577 "Reference SpaceToDepth: input and output types are mismatched");
2578
2579 return supported;
2580}
2581
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002582bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002583 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2584 const ViewsDescriptor& descriptor,
2585 Optional<std::string&> reasonIfUnsupported) const
2586{
Jan Eilers8eb25602020-03-09 12:13:48 +00002587 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002588 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002589 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002590 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002591 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002592 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002593 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002594 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002595 DataType::QAsymmU8,
2596 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002597 };
2598
2599 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2600 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002601 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002602 {
2603 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2604 "Reference splitter: input type not supported");
2605
2606 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2607 "Reference splitter: input and output types mismatched.");
2608 }
2609
2610 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002611}
2612
Matthew Jackson81e601c2019-07-11 12:07:09 +01002613bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2614 const TensorInfo& output,
2615 const StackDescriptor& descriptor,
2616 Optional<std::string&> reasonIfUnsupported) const
2617{
Jan Eilers8eb25602020-03-09 12:13:48 +00002618 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002619
2620 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002621 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002622 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002623 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002624 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002625 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002626 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002627 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002628 DataType::QSymmS16,
2629 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002630 };
2631
2632 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2633 "Reference stack: output type not supported");
2634 for (const TensorInfo* input : inputs)
2635 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002636 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002637 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2638 "Reference stack: input type not supported");
2639
2640 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2641 "Reference stack: input and output types mismatched.");
2642 }
2643
2644 return supported;
2645}
2646
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002647bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2648 const TensorInfo& output,
2649 const StridedSliceDescriptor& descriptor,
2650 Optional<std::string&> reasonIfUnsupported) const
2651{
Jan Eilers8eb25602020-03-09 12:13:48 +00002652 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002653 bool supported = true;
2654
Sadik Armagan303980c2020-04-17 12:45:14 +01002655 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002656 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002657 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002658 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002659 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002660 DataType::QAsymmU8,
2661 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002662 };
2663
2664 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2665 "Reference StridedSlice: input type not supported");
2666
2667 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2668 "Reference StridedSlice: output type not supported");
2669
2670 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2671 "Reference StridedSlice: input and output types are mismatched");
2672
2673 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002674}
2675
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002676bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2677 const TensorInfo& input1,
2678 const TensorInfo& output,
2679 Optional<std::string&> reasonIfUnsupported) const
2680{
Sadik Armagan2999a022019-04-09 14:20:12 +01002681 bool supported = true;
2682
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002683 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002684 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002685 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002686 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002687 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002688 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002689 DataType::QSymmS16,
2690 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002691 };
2692
2693 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2694 "Reference subtraction: input 0 is not a supported type.");
2695
2696 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2697 "Reference subtraction: input 1 is not a supported type.");
2698
2699 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2700 "Reference subtraction: output is not a supported type.");
2701
2702 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2703 "Reference subtraction: input 0 and Input 1 types are mismatched");
2704
2705 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2706 "Reference subtraction: input and output types are mismatched");
2707
2708 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2709 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2710
2711 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002712}
2713
Matteo Martincighab9e5252019-06-13 17:27:46 +01002714bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2715 const TensorInfo& alpha,
2716 const TensorInfo& output,
2717 Optional<std::string&> reasonIfUnsupported) const
2718{
2719 bool supported = true;
2720
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002721 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002722 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002723 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002724 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002725 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002726 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002727 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002728 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002729 };
2730
2731 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2732 "PReLU: input is not a supported type.");
2733
2734 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2735 "PReLU: alpha is not a supported type.");
2736
2737 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2738 "PReLU: output is not a supported type.");
2739
2740 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2741 "PReLU: input, alpha and output types are mismatched");
2742
2743 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2744 "PReLU: shapes are not suitable for implicit broadcast");
2745
2746 return supported;
2747}
2748
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002749bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2750 const TensorInfo& output,
2751 const TransposeConvolution2dDescriptor& descriptor,
2752 const TensorInfo& weights,
2753 const Optional<TensorInfo>& biases,
2754 Optional<std::string&> reasonIfUnsupported) const
2755{
Jan Eilers8eb25602020-03-09 12:13:48 +00002756 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002757 bool supported = true;
2758
Sadik Armagan303980c2020-04-17 12:45:14 +01002759 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002760 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002761 DataType::BFloat16,
2762 DataType::Float32,
2763 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002764 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002765 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002766 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002767 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002768 };
2769
2770 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2771 "Reference TransposeConvolution2d: input is not a supported type.");
2772
2773 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2774 "Reference TransposeConvolution2d: output is not a supported type.");
2775
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002776 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2777 "Reference TransposeConvolution2d: input and output types mismatched.");
2778
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002779
2780 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002781 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002782 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002783 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002784 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002785 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002786 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002787 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002788 };
2789
2790 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2791 "Reference TransposeConvolution2d: weights type not supported for "
2792 "quantized input.");
2793 }
2794 else
2795 {
2796 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2797 "Reference TransposeConvolution2d: weights is not a supported type.");
2798
2799 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2800 "Reference TransposeConvolution2d: input and weights types mismatched.");
2801 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002802
2803 if (biases.has_value())
2804 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002805 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002806 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002807 DataType::BFloat16,
2808 DataType::Float32,
2809 DataType::Float16,
2810 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002811 };
2812 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2813 "Reference TransposeConvolution2d: biases is not a supported type.");
2814 }
2815
2816 return supported;
2817}
2818
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002819bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2820 const TensorInfo& output,
2821 const TransposeDescriptor& descriptor,
2822 Optional<std::string&> reasonIfUnsupported) const
2823{
Jan Eilers8eb25602020-03-09 12:13:48 +00002824 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002825 bool supported = true;
2826
2827 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002828 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002829 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002830 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002831 DataType::Float32,
2832 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002833 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002834 DataType::QAsymmU8,
2835 DataType::QSymmS16
2836 };
2837
2838 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2839 "Reference transpose: input is not a supported type.");
2840
2841 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2842 "Reference transpose: output is not a supported type.");
2843
2844 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2845 "Reference transpose: input and output types are mismatched.");
2846
2847 return supported;
2848}
2849
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002850bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2851 const TensorInfo& input,
2852 const TensorInfo& outputStateIn,
2853 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002854 const TensorInfo& outputStateOut,
2855 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002856 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002857 const UnidirectionalSequenceLstmDescriptor& descriptor,
2858 const LstmInputParamsInfo& paramsInfo,
2859 Optional<std::string&> reasonIfUnsupported) const
2860{
2861 IgnoreUnused(descriptor);
2862 IgnoreUnused(paramsInfo);
2863 IgnoreUnused(outputStateIn);
2864 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002865 IgnoreUnused(outputStateOut);
2866 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002867 bool supported = true;
2868
Mike Kelly12994962022-04-21 11:57:09 +01002869 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002870 {
Mike Kelly12994962022-04-21 11:57:09 +01002871 DataType::Float32,
2872 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002873 };
2874
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002875 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002876 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002877 DataType::Float32,
2878 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002879 };
2880
Mike Kelly12994962022-04-21 11:57:09 +01002881 std::array<DataType, 3> supportedBiasTypes =
2882 {
2883 DataType::Float32,
2884 DataType::QAsymmS8,
2885 DataType::Signed32
2886 };
2887
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002888 // check inputs and outputs
2889 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2890 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002891 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2892 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002893
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002894 // check layer parameters
2895 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2896 reasonIfUnsupported,
2897 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2898 "is not a supported type.");
2899 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2900 reasonIfUnsupported,
2901 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2902 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2903 reasonIfUnsupported,
2904 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2905 "is not a supported type.");
2906 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2907 reasonIfUnsupported,
2908 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2909 "is not a supported type.");
2910 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2911 reasonIfUnsupported,
2912 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2913 "is not a supported type.");
2914 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2915 reasonIfUnsupported,
2916 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2917 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002918
2919 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2920 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2921 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2922 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2923 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2924 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002925 if (!descriptor.m_CifgEnabled)
2926 {
2927 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2928 reasonIfUnsupported,
2929 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2930 "is not a supported type.");
2931 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2932 reasonIfUnsupported,
2933 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2934 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002935 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2936 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002937 if (descriptor.m_PeepholeEnabled)
2938 {
2939 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2940 reasonIfUnsupported,
2941 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2942 "is not a supported type.");
2943 }
2944 }
2945 if (descriptor.m_PeepholeEnabled)
2946 {
2947 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2948 reasonIfUnsupported,
2949 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2950 "is not a supported type.");
2951 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2952 reasonIfUnsupported,
2953 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2954 "is not a supported type.");
2955 }
2956 if (descriptor.m_ProjectionEnabled)
2957 {
2958 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2959 reasonIfUnsupported,
2960 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2961 "is not a supported type.");
2962 if (paramsInfo.m_ProjectionBias != nullptr)
2963 {
2964 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2965 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2966 "are mismatched");
2967 }
2968 }
2969 if (descriptor.m_LayerNormEnabled)
2970 {
2971 if (!descriptor.m_CifgEnabled)
2972 {
2973 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2974 reasonIfUnsupported,
2975 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2976 "is not a supported type.");
2977 }
2978 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2979 reasonIfUnsupported,
2980 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2981 "is not a supported type.");
2982 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2983 reasonIfUnsupported,
2984 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2985 "is not a supported type.");
2986 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2987 reasonIfUnsupported,
2988 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2989 "is not a supported type.");
2990 }
2991
2992 return supported;
2993}
2994
arovir011c7c81b2018-10-08 11:34:28 +01002995} // namespace armnn