blob: f921383183521d9dade434407d0c433c1b50fcbe [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
namespace
{

/// Dispatches a reference-backend support query to the handler that matches
/// @p dataType.
///
/// Only two real handlers are wired in here (@p floatFuncPtr and
/// @p uint8FuncPtr); the remaining handler slots of
/// IsSupportedForDataTypeGeneric are filled with FalseFunc, so every other
/// data type is reported as unsupported.
///
/// NOTE(review): which data type each FalseFunc slot corresponds to is
/// determined by the parameter order of IsSupportedForDataTypeGeneric —
/// confirm against LayerSupportCommon.hpp before relying on it.
///
/// @param reasonIfUnsupported Optional string the selected handler may fill
///                            with the reason the layer is unsupported.
/// @param dataType            Data type of the tensor being checked.
/// @param floatFuncPtr        Handler invoked for the float32 case.
/// @param uint8FuncPtr        Handler invoked for the uint8 case.
/// @param params              Extra arguments forwarded to the handler.
/// @return The selected handler's verdict.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
44
namespace
{

/// Builds the standard error message reported when a tensor has the wrong
/// number of dimensions.
///
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name (e.g. "batchToSpaceNd").
/// @param tensorName Name of the offending tensor (e.g. "input").
/// @return Message of the form:
///         "Reference <layer>: Expected <e> dimensions but got <a> dimensions
///          instead, for the '<tensor>' tensor."
///
/// The strings are taken by const reference (they are read-only), which also
/// allows callers to pass temporaries or string literals.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
/// Single entry point used by the backend framework to query whether the
/// reference backend supports a given layer configuration. Dispatches on
/// @p type to the matching Is<Layer>Supported overload.
///
/// @param type       Layer type being queried.
/// @param infos      Tensor infos in a layer-specific order; the pattern used
///                   throughout is inputs first, output(s) last (for
///                   convolution-style layers: {input, output, weights,
///                   biases}, as the exception messages below state).
/// @param descriptor Base-class reference downcast per case to the concrete
///                   descriptor type for the layer.
/// @param lstmParamsInfo               Assumed to hold a value for the
///                                     Lstm/QLstm/UnidirectionalSequenceLstm
///                                     cases (value() is called unchecked).
/// @param quantizedLstmInputParamsInfo Assumed to hold a value for the
///                                     QuantizedLstm case.
/// @param reasonIfUnsupported Optional string receiving the rejection reason.
/// @return true if the reference backend supports the layer configuration.
bool RefLayerSupport::IsLayerSupported(const LayerType& type,
                                       const std::vector<TensorInfo>& infos,
                                       const BaseDescriptor& descriptor,
                                       const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                       const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    switch (type)
    {
        case LayerType::Activation:
            return IsActivationSupported(infos[0], infos[1],
                                         *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Addition:
            return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ArgMinMax:
            return IsArgMinMaxSupported(infos[0], infos[1],
                                        *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::BatchNormalization:
            return IsBatchNormalizationSupported(infos[0], infos[1], infos[2], infos[3], infos[4], infos[5],
                                                 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
                                                     (&descriptor)),
                                                 reasonIfUnsupported);
        case LayerType::BatchToSpaceNd:
            return IsBatchToSpaceNdSupported(infos[0], infos[1],
                                             *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Comparison:
            return IsComparisonSupported(infos[0], infos[1], infos[2],
                                         *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Concat:
        {
            // All but the last entry of infos are inputs; the last entry is
            // the concatenated output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < (infos.size() - 1); i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsConcatSupported(inputInfos, infos[infos.size() - 1],
                                     *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        }
        case LayerType::Constant:
            return IsConstantSupported(infos[0], reasonIfUnsupported);
        case LayerType::ConvertBf16ToFp32:
            return IsConvertBf16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp16ToFp32:
            return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToBf16:
            return IsConvertFp32ToBf16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToFp16:
            return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Convolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 signals "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution2dSupported(infos[0], infos[1], desc, infos[2],
                                                EmptyOptional(), reasonIfUnsupported);
            }
            else
            {
                return IsConvolution2dSupported(infos[0], infos[1], desc, infos[2],
                                                infos[3], reasonIfUnsupported);
            }
        }
        case LayerType::DepthToSpace:
            return IsDepthToSpaceSupported(infos[0], infos[1],
                                           *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::DepthwiseConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 signals "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsDepthwiseConvolutionSupported(infos[0], infos[1], desc, infos[2],
                                                       EmptyOptional(), reasonIfUnsupported);
            }
            else
            {
                return IsDepthwiseConvolutionSupported(infos[0], infos[1], desc, infos[2],
                                                       infos[3], reasonIfUnsupported);
            }
        }
        case LayerType::Dequantize:
            return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Division:
            return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ElementwiseUnary:
            return IsElementwiseUnarySupported(infos[0], infos[1],
                                               *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::Fill:
            return IsFillSupported(infos[0], infos[1],
                                   *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Floor:
            return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::FullyConnected:
            return IsFullyConnectedSupported(infos[0], infos[1], infos[2], infos[3],
                                             *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Gather:
            return IsGatherSupported(infos[0], infos[1], infos[2],
                                     *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::GatherNd:
            return IsGatherNdSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Input:
            return IsInputSupported(infos[0], reasonIfUnsupported);
        case LayerType::InstanceNormalization:
            return IsInstanceNormalizationSupported(infos[0], infos[1],
                                                    *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
                                                        (&descriptor)),
                                                    reasonIfUnsupported);
        case LayerType::L2Normalization:
            return IsL2NormalizationSupported(infos[0], infos[1],
                                              *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
                                              reasonIfUnsupported);
        case LayerType::LogicalBinary:
            return IsLogicalBinarySupported(infos[0], infos[1], infos[2],
                                            *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::LogSoftmax:
            return IsLogSoftmaxSupported(infos[0], infos[1],
                                         *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Lstm:
            return IsLstmSupported(infos[0], infos[1], infos[2], infos[3], infos[4], infos[5], infos[6],
                                   *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
                                   lstmParamsInfo.value(),
                                   reasonIfUnsupported);
        case LayerType::QLstm:
            return IsQLstmSupported(infos[0], infos[1], infos[2], infos[3], infos[4], infos[5],
                                    *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
                                    lstmParamsInfo.value(),
                                    reasonIfUnsupported);
        case LayerType::Maximum:
            return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Mean:
            return IsMeanSupported(infos[0], infos[1],
                                   *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Minimum:
            return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Multiplication:
            return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Normalization:
            return IsNormalizationSupported(infos[0], infos[1],
                                            *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::Output:
            return IsOutputSupported(infos[0], reasonIfUnsupported);
        case LayerType::Pad:
            return IsPadSupported(infos[0], infos[1],
                                  *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
                                  reasonIfUnsupported);
        case LayerType::Permute:
            return IsPermuteSupported(infos[0], infos[1],
                                      *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Pooling2d:
            return IsPooling2dSupported(infos[0], infos[1],
                                        *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Prelu:
            return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Quantize:
            return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Reshape:
            return IsReshapeSupported(infos[0], infos[1],
                                      *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Resize:
            return IsResizeSupported(infos[0], infos[1],
                                     *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Reduce:
            return IsReduceSupported(infos[0], infos[1],
                                     *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Slice:
            return IsSliceSupported(infos[0], infos[1],
                                    *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        case LayerType::Softmax:
            return IsSoftmaxSupported(infos[0], infos[1],
                                      *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::SpaceToBatchNd:
            return IsSpaceToBatchNdSupported(infos[0], infos[1],
                                             *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::SpaceToDepth:
            return IsSpaceToDepthSupported(infos[0], infos[1],
                                           *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Splitter:
        {
            // infos[0] is the input; the remaining entries are the split
            // outputs.
            std::vector<TensorInfo> outputInfos;
            for (uint32_t i = 1; i < infos.size(); i++)
            {
                outputInfos.push_back(infos[i]);
            }
            return IsSplitterSupported(infos[0],
                                       {outputInfos.begin(), outputInfos.end()},
                                       *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
                                       reasonIfUnsupported);
        }
        case LayerType::Stack:
        {
            // All but the last entry of infos are inputs; the last entry is
            // the stacked output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < infos.size() - 1; i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsStackSupported(inputInfos, infos[infos.size() - 1],
                                    *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        }
        case LayerType::StridedSlice:
            return IsStridedSliceSupported(infos[0], infos[1],
                                           *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Subtraction:
            return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Transpose:
            return IsTransposeSupported(infos[0], infos[1],
                                        *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::TransposeConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 signals "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsTransposeConvolution2dSupported(infos[0], infos[1], desc, infos[2],
                                                         EmptyOptional(), reasonIfUnsupported);
            }
            else
            {
                return IsTransposeConvolution2dSupported(infos[0], infos[1], desc, infos[2],
                                                         infos[3], reasonIfUnsupported);
            }
        }
        case LayerType::Cast:
            return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ChannelShuffle:
            return IsChannelShuffleSupported(infos[0], infos[1],
                                             *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Convolution3d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in slot 3 signals "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution3dSupported(infos[0], infos[1], desc, infos[2],
                                                EmptyOptional(), reasonIfUnsupported);
            }
            else
            {
                return IsConvolution3dSupported(infos[0], infos[1], desc, infos[2],
                                                infos[3], reasonIfUnsupported);
            }
        }
        case LayerType::Debug:
            return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::DetectionPostProcess:
            return IsDetectionPostProcessSupported(infos[0], infos[1], infos[2], infos[3], infos[4], infos[5],
                                                   infos[6],
                                                   *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
                                                       (&descriptor)),
                                                   reasonIfUnsupported);
        case LayerType::FakeQuantization:
            return IsFakeQuantizationSupported(infos[0],
                                               *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::MemCopy:
            return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Rank:
            return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Shape:
            return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::UnidirectionalSequenceLstm:
        {
            if (infos.size() != 6)
            {
                throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
                                               "should be of format: {input, outputStateIn, cellStateIn, "
                                               "hiddenStateOutputVal, cellStateOutputVal, output}");
            }
            auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
            return IsUnidirectionalSequenceLstmSupported(infos[0], infos[1], infos[2], infos[3], infos[4], infos[5],
                                                         desc,
                                                         lstmParamsInfo.value(),
                                                         reasonIfUnsupported);
        }
        case LayerType::Pooling3d:
            return IsPooling3dSupported(infos[0], infos[1],
                                        *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Map:
            return true;
        case LayerType::Unmap:
            return true;
        case LayerType::MemImport:
            return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Merge:
            return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::QuantizedLstm:
            return LayerSupportBase::IsQuantizedLstmSupported(infos[0], infos[1], infos[2], infos[3], infos[4],
                                                              quantizedLstmInputParamsInfo.value(),
                                                              reasonIfUnsupported);
        default:
            // Layers not supported in the reference backend by default
            // (the original comment said "neon" — a copy-paste slip):
            // precompiled, standin, switch
            return false;
    }
}
505
arovir011c7c81b2018-10-08 11:34:28 +0100506bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
507 const TensorInfo& output,
508 const ActivationDescriptor& descriptor,
509 Optional<std::string&> reasonIfUnsupported) const
510{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000511 bool supported = true;
512
513 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000514 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000515 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000516 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100517 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000518 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000519 DataType::QAsymmU8,
520 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000521 };
522
523 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
524 "Reference activation: input type not supported.");
525
526 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
527 "Reference activation: output type not supported.");
528
529 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
530 "Reference activation: input and output types mismatched.");
531
532 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
533 "Reference activation: input and output shapes are of different rank.");
534
535
536 struct ActivationFunctionSupported : public Rule
537 {
538 ActivationFunctionSupported(const ActivationDescriptor& desc)
539 {
540 switch(desc.m_Function)
541 {
542 case ActivationFunction::Abs:
543 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000544 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000545 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000546 case ActivationFunction::LeakyReLu:
547 case ActivationFunction::Linear:
548 case ActivationFunction::ReLu:
549 case ActivationFunction::Sigmoid:
550 case ActivationFunction::SoftReLu:
551 case ActivationFunction::Sqrt:
552 case ActivationFunction::Square:
553 case ActivationFunction::TanH:
554 {
555 m_Res = true;
556 break;
557 }
558 default:
559 {
560 m_Res = false;
561 break;
562 }
563 }
564 }
565 };
566
567 // Function is supported
568 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
569 "Reference activation: function not supported.");
570
571 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100572}
573
574bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
575 const TensorInfo& input1,
576 const TensorInfo& output,
577 Optional<std::string&> reasonIfUnsupported) const
578{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000579 bool supported = true;
580
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100581 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000582 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000583 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100584 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000585 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000586 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100587 DataType::QSymmS16,
588 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000589 };
590
591 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
592 "Reference addition: input 0 is not a supported type.");
593
594 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
595 "Reference addition: input 1 is not a supported type.");
596
597 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
598 "Reference addition: output is not a supported type.");
599
600 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
601 "Reference addition: input 0 and Input 1 types are mismatched");
602
603 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
604 "Reference addition: input and output types are mismatched");
605
606 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
607 "Reference addition: shapes are not suitable for implicit broadcast.");
608
609 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100610}
611
Nikhil Raj68c2c902019-09-19 11:21:11 +0100612bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
613 const armnn::ArgMinMaxDescriptor &descriptor,
614 armnn::Optional<std::string &> reasonIfUnsupported) const
615{
Jan Eilers8eb25602020-03-09 12:13:48 +0000616 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100617
Mike Kelly1f140f72021-04-06 12:25:55 +0100618 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100619 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000620 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100621 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100622 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100623 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000624 DataType::QAsymmU8,
625 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100626 DataType::Signed32,
627 DataType::Signed64
628 };
629
630 std::array<DataType,2> supportedOutputTypes = {
631 DataType::Signed32,
632 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100633 };
634
635 bool supported = true;
636
Mike Kelly1f140f72021-04-06 12:25:55 +0100637 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100638 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100639 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100640 "Reference ArgMinMax: output type not supported");
641
642 return supported;
643}
644
arovir011c7c81b2018-10-08 11:34:28 +0100645bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
646 const TensorInfo& output,
647 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100648 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100649 const TensorInfo& beta,
650 const TensorInfo& gamma,
651 const BatchNormalizationDescriptor& descriptor,
652 Optional<std::string&> reasonIfUnsupported) const
653{
Jan Eilers8eb25602020-03-09 12:13:48 +0000654 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100655
Sadik Armagan303980c2020-04-17 12:45:14 +0100656 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100657 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000658 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100659 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100660 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100661 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000662 DataType::QAsymmU8,
663 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100664 };
665
666 bool supported = true;
667
668 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
669 "Reference batch normalization: input is not a supported type.");
670
671 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
672 "Reference batch normalization: output is not a supported type.");
673
674 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
675 "Reference batch normalization: input and output types are mismatched");
676
677 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
678 "Reference batch normalization: mean is not a supported type.");
679
680 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
681 "Reference batch normalization: variance is not a supported type.");
682
683 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
684 "Reference batch normalization: beta is not a supported type.");
685
686 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
687 "Reference batch normalization: gamma is not a supported type.");
688
689 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100690}
691
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000692bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
693 const TensorInfo& output,
694 const BatchToSpaceNdDescriptor& descriptor,
695 Optional<std::string&> reasonIfUnsupported) const
696{
Jan Eilers8eb25602020-03-09 12:13:48 +0000697 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100698
699 bool supported = true;
700
701 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
702 std::string inputTensorStr = "input";
703 std::string outputTensorStr = "output";
704
705 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100706 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100707 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000708 DataType::BFloat16,
709 DataType::Float32,
710 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100711 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000712 DataType::QAsymmU8,
713 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100714 };
715
716 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
717 "Reference BatchToSpaceNd: input type not supported.");
718
719 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
720 "Reference BatchToSpaceNd: output type not supported.");
721
722 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
723 "Reference BatchToSpaceNd: input and output types mismatched.");
724
725 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
726 reasonIfUnsupported,
727 CreateIncorrectDimensionsErrorMsg(4,
728 output.GetNumDimensions(),
729 batchToSpaceNdLayerStr,
730 outputTensorStr).data());
731
732 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
733 reasonIfUnsupported,
734 CreateIncorrectDimensionsErrorMsg(4,
735 input.GetNumDimensions(),
736 batchToSpaceNdLayerStr,
737 inputTensorStr).data());
738
739 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000740}
741
mathad01b392e982021-04-07 12:07:30 +0100742bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
743 const TensorInfo& output,
744 Optional<std::string&> reasonIfUnsupported) const
745{
746 std::array<DataType, 9> supportedInputTypes =
747 {
748 DataType::BFloat16,
749 DataType::Float32,
750 DataType::Float16,
751 DataType::QSymmS8,
752 DataType::QAsymmS8,
753 DataType::QAsymmU8,
754 DataType::QSymmS16,
755 DataType::Signed32
756 };
757
758 bool supported = true;
759 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
760 "Reference cast: input is not a supported type");
761
762
763 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
764 "Reference cast: output is not a supported type");
765
766 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
767 "Reference cast: input and output shapes have different number of total elements");
768
769 return supported;
770}
771
Simon Obute51f67772021-09-03 15:50:13 +0100772bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
773 const TensorInfo& output,
774 const ChannelShuffleDescriptor& descriptor,
775 Optional<std::string&> reasonIfUnsupported) const
776{
777 IgnoreUnused(descriptor);
778 bool supported = true;
779
780 // Define supported output and inputs types.
781 std::array<DataType, 7> supportedTypes =
782 {
783 DataType::BFloat16,
784 DataType::Float32,
785 DataType::Float16,
786 DataType::QAsymmS8,
787 DataType::QAsymmU8,
788 DataType::QSymmS8,
789 DataType::QSymmS16
790 };
791
792 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
793 "Reference ChannelShuffle: input is not a supported type.");
794
795 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
796 "Reference ChannelShuffle: output is not a supported type.");
797
798 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
799 "Reference ChannelShuffle: input and output types are mismatched.");
800
801 return supported;
802}
803
804
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100805bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
806 const TensorInfo& input1,
807 const TensorInfo& output,
808 const ComparisonDescriptor& descriptor,
809 Optional<std::string&> reasonIfUnsupported) const
810{
Jan Eilers8eb25602020-03-09 12:13:48 +0000811 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100812 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100813 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000814 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000815 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100816 DataType::Float32,
817 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100818 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000819 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000820 DataType::QSymmS16,
821 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100822 };
823
824 bool supported = true;
825 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
826 "Reference comparison: input 0 is not a supported type");
827
828 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
829 "Reference comparison: input 0 and Input 1 types are mismatched");
830
831 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
832 "Reference comparison: output is not of type Boolean");
833
834 return supported;
835}
836
Jim Flynn906f9462019-05-10 13:55:21 +0100837bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
838 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000839 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100840 Optional<std::string&> reasonIfUnsupported) const
841{
Jan Eilers8eb25602020-03-09 12:13:48 +0000842 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100843
844 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000845 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100846 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000847 DataType::BFloat16,
848 DataType::Float32,
849 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000850 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100851 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000852 DataType::QSymmS16,
853 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100854 };
855
856 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
857 "Reference concatenation: output type not supported");
858 for (const TensorInfo* input : inputs)
859 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100860 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100861 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
862 "Reference concatenation: input type not supported");
863
864 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
865 "Reference concatenation: input and output types mismatched.");
866 }
867
868 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100869}
870
arovir011c7c81b2018-10-08 11:34:28 +0100871bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
872 Optional<std::string&> reasonIfUnsupported) const
873{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100874 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100875 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000876 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100877 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100878 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000879 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100880 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000881 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100882 DataType::QSymmS16,
883 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100884 };
885
886 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
887 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100888}
889
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000890bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
891 const TensorInfo& output,
892 Optional<std::string&> reasonIfUnsupported) const
893{
894 bool supported = true;
895
896 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
897 "Reference for ConvertBf16ToFp32 layer: input type not supported");
898
899 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
900 "Reference for ConvertBf16ToFp32 layer: output type not supported");
901
902 return supported;
903}
904
// Checks support for the Float16 -> Float32 conversion layer.
// Uses the legacy positional dispatch helper: one function pointer per data
// type in the order Float16, Float32, Uint8, Int32, Boolean (see
// IsSupportedForDataTypeRef at the top of this file for the same ordering).
// The first call requires the input to be Float16 (TrueFunc in the Float16
// slot); the second requires the output to be Float32. Every other type
// combination reports failure through the False* helpers.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,            // Float16: accepted
                                          &FalseInputFuncF32<>,   // Float32 input: rejected with message
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,  // Float16 output: rejected with message
                                          &TrueFunc<>,            // Float32: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
924
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000925bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
926 const TensorInfo& output,
927 Optional<std::string&> reasonIfUnsupported) const
928{
929 bool supported = true;
930
931 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
932 "Reference for ConvertFp32ToBf16 layer: input type not supported");
933
934 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
935 "Reference for ConvertFp32ToBf16 layer: output type not supported");
936
937 return supported;
938}
939
// Checks support for the Float32 -> Float16 conversion layer.
// Mirror image of IsConvertFp16ToFp32Supported: the same positional dispatch
// helper is used (slot order Float16, Float32, Uint8, Int32, Boolean), but
// here the input must be Float32 and the output must be Float16.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,   // Float16 input: rejected with message
                                          &TrueFunc<>,            // Float32: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,            // Float16: accepted
                                          &FalseOutputFuncF32<>,  // Float32 output: rejected with message
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
959
960bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
961 const TensorInfo& output,
962 const Convolution2dDescriptor& descriptor,
963 const TensorInfo& weights,
964 const Optional<TensorInfo>& biases,
965 Optional<std::string&> reasonIfUnsupported) const
966{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100967 bool supported = true;
968
969 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000970 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000971 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000972 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000973 DataType::Float32,
974 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000975 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100976 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000977 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000978 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100979 };
980
981 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000982 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100983
984 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000985 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100986
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000987 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
988 if (input.GetDataType() == DataType::BFloat16)
989 {
990 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
991 {
992 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
993 supported = false;
994 }
995 }
996 else
997 {
998 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000999 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001000 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001001
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001002 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001003 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001004 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001005 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001006 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001007 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001008 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001009 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001010 };
1011
1012 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001013 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001014 }
1015 else
1016 {
1017 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001018 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001019
1020 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001021 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001022 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001023
1024 if (biases.has_value())
1025 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001026 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001027 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001028 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001029 DataType::Float32,
1030 DataType::Float16,
1031 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001032 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001033
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001034 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001035 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001036 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001037 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001038
1039 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001040}
1041
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001042bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1043 const TensorInfo& output,
1044 const Convolution3dDescriptor& descriptor,
1045 const TensorInfo& weights,
1046 const Optional<TensorInfo>& biases,
1047 Optional<std::string&> reasonIfUnsupported) const
1048{
1049 bool supported = true;
1050
1051 // Define supported types.
1052 std::array<DataType,7> supportedTypes =
1053 {
1054 DataType::BFloat16,
1055 DataType::Float32,
1056 DataType::Float16,
1057 DataType::QAsymmS8,
1058 DataType::QAsymmU8,
1059 DataType::QSymmS8,
1060 DataType::QSymmS16
1061 };
1062
1063 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1064 "Reference Convolution3d: input is not a supported type.");
1065
1066 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1067 "Reference Convolution3d: output is not a supported type.");
1068
1069 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1070 "Reference Convolution3d: input and output types mismatched.");
1071
1072 const DataType inputType = input.GetDataType();
1073 if (IsQuantized8BitType(inputType))
1074 {
1075 std::array<DataType, 3> supportedWeightTypes =
1076 {
1077 DataType::QAsymmS8,
1078 DataType::QAsymmU8,
1079 DataType::QSymmS8
1080 };
1081
1082 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1083 "Reference Convolution3d: weights type not supported for quantized input.");
1084 }
1085 else
1086 {
1087 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1088 "Reference Convolution3d: weights is not a supported type.");
1089
1090 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1091 "Reference Convolution3d: input and weights types mismatched.");
1092 }
1093
1094 if (biases.has_value())
1095 {
1096 std::array<DataType,4> biasesSupportedTypes =
1097 {
1098 DataType::BFloat16,
1099 DataType::Float32,
1100 DataType::Float16,
1101 DataType::Signed32
1102 };
1103
1104 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1105 "Reference Convolution3d: biases is not a supported type.");
1106 }
1107 IgnoreUnused(descriptor);
1108
1109 return supported;
1110}
1111
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001112bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1113 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001114 Optional<std::string&> reasonIfUnsupported) const
1115{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001116 bool supported = true;
1117
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001118 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001119 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001120 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001121 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001122 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001123 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001124 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001125 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001126 DataType::QSymmS16,
1127 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001128 };
1129
1130 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001131 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001132
1133 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001134 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001135
1136 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001137 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001138
1139 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001140}
1141
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001142bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1143 const TensorInfo& output,
1144 const DepthToSpaceDescriptor& descriptor,
1145 Optional<std::string&> reasonIfUnsupported) const
1146{
Jan Eilers8eb25602020-03-09 12:13:48 +00001147 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001148 bool supported = true;
1149
Sadik Armagan303980c2020-04-17 12:45:14 +01001150 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001151 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001152 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001153 DataType::Float32,
1154 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001155 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001156 DataType::QAsymmU8,
1157 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001158 };
1159
1160 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1161 "Reference DepthToSpace: input type not supported");
1162
1163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1164 "Reference DepthToSpace: output type not supported");
1165
1166 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1167 "Reference DepthToSpace: input and output types are mismatched");
1168
1169 return supported;
1170}
1171
arovir011c7c81b2018-10-08 11:34:28 +01001172bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1173 const TensorInfo& output,
1174 const DepthwiseConvolution2dDescriptor& descriptor,
1175 const TensorInfo& weights,
1176 const Optional<TensorInfo>& biases,
1177 Optional<std::string&> reasonIfUnsupported) const
1178{
Sadik Armagan303980c2020-04-17 12:45:14 +01001179 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001180 bool supported = true;
1181
1182 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001183 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001184 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001185 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001186 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001187 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001188 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001189 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001190 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001191 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001192 };
1193
1194 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1195 "Reference DepthwiseConvolution2d: input is not a supported type.");
1196
1197 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1198 "Reference DepthwiseConvolution2d: output is not a supported type.");
1199
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001200 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1201 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1202
Teresa Charlind8df0262019-11-11 12:28:15 +00001203 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001204 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001205 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001206 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001207 {
1208 DataType::QAsymmS8,
1209 DataType::QAsymmU8,
1210 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001211 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001212
1213 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001214 "Reference DepthwiseConvolution2d: weights type not supported for "
1215 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001216 }
1217 else
1218 {
1219 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1220 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1221
1222 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1223 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1224 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001225
1226 if (biases.has_value())
1227 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001228 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001229 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001230 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001231 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001232 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001233 DataType::Signed32
1234 };
1235 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1236 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1237 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001238
1239 return supported;
1240
arovir011c7c81b2018-10-08 11:34:28 +01001241}
1242
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001243bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1244 const TensorInfo& output,
1245 Optional<std::string&> reasonIfUnsupported) const
1246{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001247 bool supported = true;
1248
Ryan OShea9add1202020-02-07 10:06:33 +00001249 std::array<DataType,4> supportedInputTypes = {
1250 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001251 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001252 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001253 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001254 };
1255
1256 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001257 "Reference for Dequantize layer: input type not supported.");
1258
Derek Lambertid466a542020-01-22 15:37:29 +00001259 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001260 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001261
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001262 std::array<DataType,3> supportedOutputTypes = {
1263 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +00001264 DataType::Float32,
1265 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001266 };
1267
1268 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001269 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001270
1271 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001272 "Reference for Dequantize layer: input/output shapes have different num total "
1273 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001274
1275 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001276}
1277
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001278bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1279 const TensorInfo& scores,
1280 const TensorInfo& anchors,
1281 const TensorInfo& detectionBoxes,
1282 const TensorInfo& detectionClasses,
1283 const TensorInfo& detectionScores,
1284 const TensorInfo& numDetections,
1285 const DetectionPostProcessDescriptor& descriptor,
1286 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001287{
Jan Eilers8eb25602020-03-09 12:13:48 +00001288 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001289
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001290 bool supported = true;
1291
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001292 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001293 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001294 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001295 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001296 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001297 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001298 DataType::QAsymmU8,
1299 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001300 };
1301
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001302 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001303 "Reference DetectionPostProcess: input 0 is not a supported type.");
1304
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001305 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001306 "Reference DetectionPostProcess: input 1 is not a supported type.");
1307
1308 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001309}
1310
// Dilated depthwise convolution has the same data-type constraints as the
// regular depthwise case, so this simply delegates to
// IsDepthwiseConvolutionSupported (dilation parameters live in the shared
// DepthwiseConvolution2dDescriptor).
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1320
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001321bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001322 const TensorInfo& input1,
1323 const TensorInfo& output,
1324 Optional<std::string&> reasonIfUnsupported) const
1325{
Sadik Armagan2999a022019-04-09 14:20:12 +01001326 bool supported = true;
1327
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001328 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001329 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001330 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001331 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001332 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001333 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001334 DataType::QSymmS16,
1335 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001336 };
1337
1338 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1339 "Reference division: input 0 is not a supported type.");
1340
1341 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1342 "Reference division: input 1 is not a supported type.");
1343
1344 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1345 "Reference division: output is not a supported type.");
1346
1347 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1348 "Reference division: input 0 and Input 1 types are mismatched");
1349
1350 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1351 "Reference division: input and output types are mismatched");
1352
1353 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1354 "Reference division: shapes are not suitable for implicit broadcast.");
1355
1356 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001357}
1358
josh minor4a3c6102020-01-06 16:40:46 -06001359bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1360 const TensorInfo& output,
1361 const ElementwiseUnaryDescriptor& descriptor,
1362 Optional<std::string&> reasonIfUnsupported) const
1363{
Jan Eilers8eb25602020-03-09 12:13:48 +00001364 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001365
Sadik Armagan303980c2020-04-17 12:45:14 +01001366 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001367 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001368 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06001369 DataType::Float32,
1370 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001371 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001372 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001373 DataType::QSymmS16,
1374 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001375 };
1376
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001377 std::array<DataType, 1> logicalSupportedTypes =
1378 {
1379 DataType::Boolean
1380 };
1381
josh minor4a3c6102020-01-06 16:40:46 -06001382 bool supported = true;
1383
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001384 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1385 {
1386 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1387 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001388
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001389 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1390 "Reference elementwise unary: output type not supported");
1391 }
1392 else
1393 {
1394 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1395 "Reference elementwise unary: input type not supported");
1396
1397 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1398 "Reference elementwise unary: output type not supported");
1399 }
josh minor4a3c6102020-01-06 16:40:46 -06001400
1401 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1402 "Reference elementwise unary: input and output types not matching");
1403
1404 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1405 "Reference elementwise unary: input and output shapes"
1406 "have different number of total elements");
1407
1408 return supported;
1409}
1410
arovir011c7c81b2018-10-08 11:34:28 +01001411bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1412 const FakeQuantizationDescriptor& descriptor,
1413 Optional<std::string&> reasonIfUnsupported) const
1414{
Jan Eilers8eb25602020-03-09 12:13:48 +00001415 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001416 bool supported = true;
1417
1418 std::array<DataType,1> supportedTypes =
1419 {
1420 DataType::Float32
1421 };
1422
1423 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1424 "Reference fake quantization: input type not supported.");
1425
1426 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001427}
1428
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001429bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1430 const TensorInfo& output,
1431 const FillDescriptor& descriptor,
1432 Optional<std::string&> reasonIfUnsupported) const
1433{
1434 IgnoreUnused(descriptor);
1435 IgnoreUnused(output);
1436
1437 bool supported = true;
1438
Sadik Armagana792a052020-06-23 16:22:23 +01001439 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001440 {
1441 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001442 DataType::Float16,
1443 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001444 };
1445
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001446 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001447 "Reference Fill: input type not supported.");
1448
Teresa Charlin44088502020-07-27 11:27:19 +01001449 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1450 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001451 return supported;
1452}
1453
arovir011c7c81b2018-10-08 11:34:28 +01001454bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1455 const TensorInfo& output,
1456 Optional<std::string&> reasonIfUnsupported) const
1457{
Jan Eilers8eb25602020-03-09 12:13:48 +00001458 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001459 bool supported = true;
1460
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001461 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001462 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001463 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +01001464 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001465 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001466 };
1467
1468 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1469 "Reference Floor: input type not supported.");
1470
1471 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1472 "Reference Floor: output type not supported.");
1473
1474 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001475}
1476
1477bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1478 const TensorInfo& output,
1479 const TensorInfo& weights,
1480 const TensorInfo& biases,
1481 const FullyConnectedDescriptor& descriptor,
1482 Optional<std::string&> reasonIfUnsupported) const
1483{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001484 bool supported = true;
1485
1486 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001487 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001488 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001489 DataType::BFloat16,
1490 DataType::Float32,
1491 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001492 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001493 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001494 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001495 };
1496
1497 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1498 "Reference Fully Connected: input type not supported.");
1499
1500 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1501 "Reference Fully Connected: output type not supported.");
1502
Francis Murtagh46c09d02019-05-28 08:15:28 +01001503 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1504 "Reference Fully Connected: weights type not supported.");
1505
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001506 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1507 if (input.GetDataType() == DataType::BFloat16)
1508 {
1509 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1510 {
1511 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1512 supported = false;
1513 }
1514 }
1515 else
1516 {
1517 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1518 "Reference Fully Connected: input and output types mismatched.");
1519 }
1520
Jan Eilers1f45dc32020-06-15 11:43:03 +01001521 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1522 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001523
Jan Eilers1f45dc32020-06-15 11:43:03 +01001524 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1525 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001526
1527 if (descriptor.m_BiasEnabled)
1528 {
1529 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001530 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001531 supportedBiasTypes =
1532 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001533 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001534 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001535 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001536 DataType::Signed32,
1537 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001538 };
1539
1540 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1541 "Reference Fully Connected: bias type not supported.");
1542
1543 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1544 "Reference Fully Connected: bias and weight types mismatch.");
1545
1546 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1547 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1548
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001549 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1550 "Reference Fully Connected: bias must have 1 dimension.");
1551
Francis Murtagh46c09d02019-05-28 08:15:28 +01001552 }
1553
1554 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001555}
1556
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001557bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1558 const armnn::TensorInfo& input1,
1559 const armnn::TensorInfo& output,
1560 armnn::Optional<std::string&> reasonIfUnsupported) const
1561{
1562 bool supported = true;
1563 std::array<DataType,7> supportedTypes =
1564 {
1565 DataType::BFloat16,
1566 DataType::Float32,
1567 DataType::Float16,
1568 DataType::QAsymmS8,
1569 DataType::QAsymmU8,
1570 DataType::QSymmS16,
1571 DataType::Signed32
1572 };
1573
1574 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1575 "Reference GatherNd: input type not supported");
1576
1577 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1578 "Reference GatherNd: output type not supported");
1579
1580 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1581 "Reference GatherNd: indices (input1) type not supported");
1582
1583 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1584 "Reference GatherNd: input and output types not matching");
1585
1586 return supported;
1587}
1588
narpra014951d842019-01-18 16:53:53 +00001589bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1590 const armnn::TensorInfo& input1,
1591 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001592 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001593 armnn::Optional<std::string&> reasonIfUnsupported) const
1594{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001595 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001596 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001597 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001598 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001599 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001600 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001601 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001602 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001603 DataType::QSymmS16,
1604 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001605 };
1606
Teresa Charlin52664732020-06-29 16:27:03 +01001607 if (descriptor.m_Axis != 0)
1608 {
1609 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1610 supported &= false;
1611 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001612 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1613 "Reference Gather: input type not supported");
1614
1615 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1616 "Reference Gather: output type not supported");
1617
1618 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1619 "Reference Gather: indices (input1) type not supported");
1620
1621 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1622 "Reference Gather: input and output types not matching");
1623
1624 return supported;
narpra014951d842019-01-18 16:53:53 +00001625}
1626
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    // Unconditionally supported: no constraints are placed on the input tensor.
    return true;
}
1632
Kevin May09ca49c2019-10-09 12:37:34 +01001633bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1634 const TensorInfo& output,
1635 const InstanceNormalizationDescriptor& descriptor,
1636 Optional<std::string&> reasonIfUnsupported) const
1637{
Jan Eilers8eb25602020-03-09 12:13:48 +00001638 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001639 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001640 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001641 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001642 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001643 DataType::Float32,
1644 DataType::Float16
1645 };
1646
1647 bool supported = true;
1648
1649 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1650 "Reference Instance Normalization: input type not supported.");
1651
1652 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1653 "Reference Instance Normalization: output type not supported.");
1654
1655 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1656 "Reference Instance Normalization: input and output types mismatched.");
1657
1658 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1659 "Reference Instance Normalization: input and output shapes have different "
1660 "num total elements.");
1661
1662 return supported;
1663}
1664
arovir011c7c81b2018-10-08 11:34:28 +01001665bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1666 const TensorInfo& output,
1667 const L2NormalizationDescriptor& descriptor,
1668 Optional<std::string&> reasonIfUnsupported) const
1669{
Jan Eilers8eb25602020-03-09 12:13:48 +00001670 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001671 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001672 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001673 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001674 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001675 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001676 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001677 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001678 DataType::QAsymmU8,
1679 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001680 };
1681
1682 bool supported = true;
1683
1684 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1685 "Reference L2normalization: input type not supported.");
1686
1687 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1688 "Reference L2normalization: output type not supported.");
1689
1690 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1691 "Reference L2normalization: input and output types mismatched.");
1692
1693 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1694 "Reference L2normalization: input and output shapes have different "
1695 "num total elements.");
1696
1697 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001698}
1699
James Conroyaba90cd2020-11-06 16:28:18 +00001700bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1701 const TensorInfo& input1,
1702 const TensorInfo& output,
1703 const LogicalBinaryDescriptor& descriptor,
1704 Optional<std::string&> reasonIfUnsupported) const
1705{
1706 IgnoreUnused(descriptor);
1707
1708 std::array<DataType, 1> supportedTypes =
1709 {
1710 DataType::Boolean
1711 };
1712
1713 bool supported = true;
1714 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1715 "Reference LogicalBinary: input 0 type not supported");
1716 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1717 "Reference LogicalBinary: input 1 type not supported");
1718
1719 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1720 "Reference LogicalBinary: input and output types do not match");
1721
1722 return supported;
1723}
1724
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001725bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1726 const TensorInfo& output,
1727 const LogSoftmaxDescriptor& descriptor,
1728 Optional<std::string&> reasonIfUnsupported) const
1729{
Jan Eilers8eb25602020-03-09 12:13:48 +00001730 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001731
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001732 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001733 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001734 DataType::BFloat16,
1735 DataType::Float32,
1736 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001737 };
1738
1739 bool supported = true;
1740 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1741 "Reference LogSoftmax: input type not supported");
1742
1743 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1744 "Reference LogSoftmax: output type not supported");
1745
1746 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1747 "Reference LogSoftmax: input and output types do not match");
1748
1749 return supported;
1750}
1751
arovir011c7c81b2018-10-08 11:34:28 +01001752bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1753 const TensorInfo& outputStateIn,
1754 const TensorInfo& cellStateIn,
1755 const TensorInfo& scratchBuffer,
1756 const TensorInfo& outputStateOut,
1757 const TensorInfo& cellStateOut,
1758 const TensorInfo& output,
1759 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001760 const LstmInputParamsInfo& paramsInfo,
1761 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001762{
Jan Eilers8eb25602020-03-09 12:13:48 +00001763 IgnoreUnused(descriptor);
1764 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001765
1766 bool supported = true;
1767
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001768 std::array<DataType,3> supportedTypes = {
1769 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001770 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001771 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001772 };
1773
Jan Eilersd01a83c2019-07-03 18:20:40 +01001774 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001775 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1776 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001777 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1778 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001779 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1780 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001781 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1782 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001783 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1784 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001785 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1786 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001787
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001788 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1789 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001790 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001791 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001792 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001793 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001794 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001795 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001796 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001797 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001798 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001799 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001800 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001801 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001802 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001803 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001804 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001805 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001806 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001808 "Reference Lstm: input and OutputGateBias types are mismatched");
1809 if (!descriptor.m_CifgEnabled)
1810 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001811 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001812 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001813 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001814 reasonIfUnsupported,
1815 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001816 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001817 "Reference Lstm: input and InputGateBias types are mismatched");
1818 if (descriptor.m_PeepholeEnabled)
1819 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001820 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001821 reasonIfUnsupported,
1822 "Reference Lstm: input and CellToInputWeights types are mismatched");
1823 }
1824 }
1825 if (descriptor.m_PeepholeEnabled)
1826 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001827 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001828 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001829 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001830 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1831 }
1832 if (descriptor.m_ProjectionEnabled)
1833 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001834 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001835 "Reference Lstm: input and mProjectionWeights types are mismatched");
1836 if (paramsInfo.m_ProjectionBias != nullptr)
1837 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001838 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001839 "Reference Lstm: input and ProjectionBias types are mismatched");
1840 }
1841 }
1842 if (descriptor.m_LayerNormEnabled)
1843 {
1844 if (!descriptor.m_CifgEnabled)
1845 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001846 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001847 reasonIfUnsupported,
1848 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1849 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001850 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001851 reasonIfUnsupported,
1852 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001853 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001854 reasonIfUnsupported,
1855 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001856 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001857 reasonIfUnsupported,
1858 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1859 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001860
1861 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001862}
1863
saoste012df12b32018-11-28 16:57:20 +00001864bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1865 const TensorInfo& input1,
1866 const TensorInfo& output,
1867 Optional<std::string&> reasonIfUnsupported) const
1868{
Sadik Armagan2999a022019-04-09 14:20:12 +01001869 bool supported = true;
1870
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001871 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001872 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001873 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001874 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001875 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001876 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001877 DataType::QSymmS16,
1878 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001879 };
1880
1881 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1882 "Reference maximum: input 0 is not a supported type.");
1883
1884 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1885 "Reference maximum: input 1 is not a supported type.");
1886
1887 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1888 "Reference maximum: output is not a supported type.");
1889
1890 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1891 "Reference maximum: input 0 and Input 1 types are mismatched");
1892
1893 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1894 "Reference maximum: input and output types are mismatched");
1895
1896 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1897 "Reference maximum: shapes are not suitable for implicit broadcast.");
1898
1899 return supported;
saoste012df12b32018-11-28 16:57:20 +00001900}
1901
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001902bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1903 const TensorInfo& output,
1904 const MeanDescriptor& descriptor,
1905 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001906{
James Conroy4d1ff582019-06-10 17:06:39 +01001907 bool supported = true;
1908 std::string meanLayerStr = "Mean";
1909 std::string outputTensorStr = "output";
1910
Sadik Armagan303980c2020-04-17 12:45:14 +01001911 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001912 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001913 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001914 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001915 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001916 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001917 DataType::QAsymmU8,
1918 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001919 };
1920
1921 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1922 "Reference Mean: input type not supported.");
1923
1924 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1925 "Reference Mean: input and output types are mismatched");
1926
1927 if (descriptor.m_KeepDims)
1928 {
1929 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1930 reasonIfUnsupported,
1931 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1932 output.GetNumDimensions(),
1933 meanLayerStr, outputTensorStr).data());
1934 }
1935 else if (descriptor.m_Axis.empty())
1936 {
1937 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1938 reasonIfUnsupported,
1939 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1940 meanLayerStr, outputTensorStr).data());
1941 }
1942 else
1943 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001944 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001945
1946 if (outputDim > 0)
1947 {
1948 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1949 reasonIfUnsupported,
1950 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1951 meanLayerStr, outputTensorStr).data());
1952 }
1953 else
1954 {
1955 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1956 reasonIfUnsupported,
1957 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1958 meanLayerStr, outputTensorStr).data());
1959 }
1960 }
1961
1962 return supported;
narpra0132b90462018-09-13 11:07:48 +01001963}
1964
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001965bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1966 const TensorInfo &output,
1967 Optional<std::string &> reasonIfUnsupported) const
1968{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001969 bool supported = true;
1970
Sadik Armagan303980c2020-04-17 12:45:14 +01001971 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001972 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001973 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001974 DataType::Float32,
1975 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001976 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001977 DataType::QAsymmU8,
1978 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001979 DataType::Boolean
1980 };
1981
1982 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1983 "Reference MemCopy: input type not supported");
1984
1985 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1986 "Reference MemCopy: output type not supported");
1987
1988 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1989 "Reference MemCopy: input and output types are mismatched");
1990
1991 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001992}
1993
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001994bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1995 const TensorInfo& input1,
1996 const TensorInfo& output,
1997 Optional<std::string&> reasonIfUnsupported) const
1998{
Sadik Armagan2999a022019-04-09 14:20:12 +01001999 bool supported = true;
2000
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002001 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002002 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002003 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002004 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002005 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002006 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002007 DataType::QSymmS16,
2008 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002009 };
2010
2011 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2012 "Reference minimum: input 0 is not a supported type.");
2013
2014 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2015 "Reference minimum: input 1 is not a supported type.");
2016
2017 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2018 "Reference minimum: output is not a supported type.");
2019
2020 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2021 "Reference minimum: input 0 and Input 1 types are mismatched");
2022
2023 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2024 "Reference minimum: input and output types are mismatched");
2025
2026 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2027 "Reference minimum: shapes are not suitable for implicit broadcast.");
2028
2029 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002030}
2031
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002032bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2033 const TensorInfo& input1,
2034 const TensorInfo& output,
2035 Optional<std::string&> reasonIfUnsupported) const
2036{
Sadik Armagan2999a022019-04-09 14:20:12 +01002037 bool supported = true;
2038
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002039 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002040 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002041 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002042 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002043 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002044 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002045 DataType::QSymmS16,
2046 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002047 };
2048
2049 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2050 "Reference multiplication: input 0 is not a supported type.");
2051
2052 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2053 "Reference multiplication: input 1 is not a supported type.");
2054
2055 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2056 "Reference multiplication: output is not a supported type.");
2057
2058 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2059 "Reference multiplication: input 0 and Input 1 types are mismatched");
2060
2061 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2062 "Reference multiplication: input and output types are mismatched");
2063
2064 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2065 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2066
2067 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002068}
2069
2070bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2071 const TensorInfo& output,
2072 const NormalizationDescriptor& descriptor,
2073 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002074{
Jan Eilers8eb25602020-03-09 12:13:48 +00002075 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002076
2077 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002078 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002079 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002080 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002081 DataType::Float16,
2082 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002083 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002084 DataType::QAsymmU8,
2085 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002086 };
2087
2088 bool supported = true;
2089
2090 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2091 "Reference normalization: input type not supported.");
2092
2093 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2094 "Reference normalization: output type not supported.");
2095
2096 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2097 "Reference normalization: input and output shapes have different "
2098 "num total elements.");
2099
2100 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002101}
2102
Derek Lamberti901ea112019-12-10 22:07:09 +00002103bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2104 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002105{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002106 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002107}
2108
2109bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2110 const TensorInfo& output,
2111 const PadDescriptor& descriptor,
2112 Optional<std::string&> reasonIfUnsupported) const
2113{
Jan Eilers8eb25602020-03-09 12:13:48 +00002114 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002115 bool supported = true;
2116
2117 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002118 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002119 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002120 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002121 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002122 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002123 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002124 DataType::QAsymmU8,
2125 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002126 };
2127
2128 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2129 "Reference pad: input is not a supported type.");
2130
2131 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2132 "Reference pad: output is not a supported type.");
2133
2134 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2135 "Reference pad: input and output types are mismatched.");
2136
2137 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002138}
2139
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002140bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2141 const TensorInfo& output,
2142 const PermuteDescriptor& descriptor,
2143 Optional<std::string&> reasonIfUnsupported) const
2144{
Jan Eilers8eb25602020-03-09 12:13:48 +00002145 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002146 bool supported = true;
2147
2148 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002149 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002150 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002151 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002152 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002153 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002154 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002155 DataType::QAsymmU8,
2156 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002157 };
2158
2159 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2160 "Reference permute: input is not a supported type.");
2161
2162 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2163 "Reference permute: output is not a supported type.");
2164
2165 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2166 "Reference permute: input and output types are mismatched.");
2167
2168 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002169}
2170
2171bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2172 const TensorInfo& output,
2173 const Pooling2dDescriptor& descriptor,
2174 Optional<std::string&> reasonIfUnsupported) const
2175{
Jan Eilers8eb25602020-03-09 12:13:48 +00002176 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002177 bool supported = true;
2178
2179 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002180 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002181 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002182 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01002183 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002184 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002185 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002186 DataType::QAsymmU8,
2187 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002188 };
2189
2190 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2191 "Reference poolind2d: input is not a supported type.");
2192
2193 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2194 "Reference poolind2d: output is not a supported type.");
2195
2196 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2197 "Reference poolind2d: input and output types are mismatched.");
2198
2199 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002200}
2201
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002202bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2203 const TensorInfo& output,
2204 const Pooling3dDescriptor& descriptor,
2205 Optional<std::string&> reasonIfUnsupported) const
2206{
2207 IgnoreUnused(descriptor);
2208 bool supported = true;
2209
2210 // Define supported output and inputs types.
2211 std::array<DataType,6> supportedTypes =
2212 {
2213 DataType::BFloat16,
2214 DataType::Float32,
2215 DataType::Float16,
2216 DataType::QAsymmS8,
2217 DataType::QAsymmU8,
2218 DataType::QSymmS16
2219 };
2220
2221 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2222 "Reference poolind3d: input is not a supported type.");
2223
2224 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2225 "Reference poolind3d: output is not a supported type.");
2226
2227 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2228 "Reference poolind3d: input and output types are mismatched.");
2229
2230 return supported;
2231}
2232
2233
James Conroy4f1f8992020-04-29 20:01:10 +01002234bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2235 const TensorInfo& previousOutputIn,
2236 const TensorInfo& previousCellStateIn,
2237 const TensorInfo& outputStateOut,
2238 const TensorInfo& cellStateOut,
2239 const TensorInfo& output,
2240 const QLstmDescriptor& descriptor,
2241 const LstmInputParamsInfo& paramsInfo,
2242 Optional<std::string&> reasonIfUnsupported) const
2243{
2244 IgnoreUnused(input);
2245 IgnoreUnused(previousOutputIn);
2246 IgnoreUnused(previousCellStateIn);
2247 IgnoreUnused(outputStateOut);
2248 IgnoreUnused(cellStateOut);
2249 IgnoreUnused(output);
2250 IgnoreUnused(descriptor);
2251 IgnoreUnused(paramsInfo);
2252
2253 IgnoreUnused(reasonIfUnsupported);
2254
2255 return true;
2256}
2257
Derek Lamberti5f400d62019-03-25 15:41:58 +00002258bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2259 const TensorInfo& output,
2260 Optional<std::string&> reasonIfUnsupported) const
2261{
2262 bool supported = true;
2263
Finn Williamsfd271062019-12-04 14:27:27 +00002264 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002265 std::array<DataType,7> supportedInputTypes = {
2266 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00002267 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002268 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002269 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002270 DataType::QAsymmU8,
2271 DataType::QSymmS8,
2272 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002273 };
2274
2275 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2276 "Reference quantize: input type not supported.");
2277
2278 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002279 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002280 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002281 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002282 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002283 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002284 };
2285 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2286 "Reference quantize: output type not supported.");
2287
2288 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2289 "Reference quantize: input and output shapes have different num total elements.");
2290
2291 return supported;
2292}
2293
Finn Williams2605b232020-06-10 15:53:46 +01002294bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2295 const TensorInfo& output,
2296 Optional<std::string&> reasonIfUnsupported) const
2297{
2298 IgnoreUnused(input);
2299 // Define supported output types.
2300 std::array<DataType,1> supportedOutputTypes =
2301 {
2302 DataType::Signed32,
2303 };
2304
2305 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2306 "Reference rank: input type not supported.");
2307}
2308
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002309bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2310 const TensorInfo& output,
2311 const ReduceDescriptor& descriptor,
2312 Optional<std::string&> reasonIfUnsupported) const
2313{
2314 IgnoreUnused(descriptor);
2315 bool supported = true;
2316 std::array<DataType,7> supportedTypes =
2317 {
2318 DataType::BFloat16,
2319 DataType::Float32,
2320 DataType::Float16,
2321 DataType::QAsymmS8,
2322 DataType::QAsymmU8,
2323 DataType::QSymmS16,
2324 DataType::Signed32
2325 };
2326
2327 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2328 "Reference Reduce: input type not supported");
2329
2330 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2331 "Reference Reduce: output type not supported");
2332
2333 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2334 "Reference Reduce: input and output types not matching");
2335
2336 return supported;
2337}
2338
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002339bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002340 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002341 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002342 Optional<std::string&> reasonIfUnsupported) const
2343{
Jan Eilers8eb25602020-03-09 12:13:48 +00002344 IgnoreUnused(output);
2345 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002346 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002347 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002348 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002349 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002350 DataType::Float32,
2351 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002352 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002353 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002354 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002355 DataType::QSymmS16,
2356 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002357 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002358
Nina Drozd2f2778f2019-05-27 10:37:05 +01002359 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2360 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002361}
2362
Teresa Charlin970f43b2019-07-01 13:51:07 +01002363bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2364 const TensorInfo& output,
2365 const ResizeDescriptor& descriptor,
2366 Optional<std::string&> reasonIfUnsupported) const
2367{
Jan Eilers8eb25602020-03-09 12:13:48 +00002368 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002369 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002370 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002371 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002372 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002373 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002374 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002375 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002376 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002377 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002378 };
2379
2380 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2381 "Reference Resize: input type not supported");
2382
2383 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2384 "Reference Resize: output type not supported");
2385
2386 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2387 "Reference Resize: input and output types not matching");
2388
2389 return supported;
2390}
2391
Keith Davis3ae3f972021-05-21 16:33:48 +01002392bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2393 const TensorInfo& output,
2394 Optional<std::string&> reasonIfUnsupported) const
2395{
2396 IgnoreUnused(input);
2397 bool supported = true;
2398
2399 std::array<DataType, 1> supportedTypes =
2400 {
2401 DataType::Signed32
2402 };
2403
2404 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2405 "Reference Shape: output type not supported");
2406
2407 return supported;
2408}
2409
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002410bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2411 const TensorInfo& output,
2412 const SliceDescriptor& descriptor,
2413 Optional<std::string&> reasonIfUnsupported) const
2414{
Jan Eilers8eb25602020-03-09 12:13:48 +00002415 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002416 bool supported = true;
2417
Sadik Armagan303980c2020-04-17 12:45:14 +01002418 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002419 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002420 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002421 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002422 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002423 DataType::QAsymmU8,
2424 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002425 };
2426
2427 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2428 "Reference Slice: input type not supported");
2429
2430 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2431 "Reference Slice: output type not supported");
2432
2433 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2434 "Reference Slice: input and output types are mismatched");
2435
2436 return supported;
2437}
2438
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002439bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2440 const TensorInfo& output,
2441 const SoftmaxDescriptor& descriptor,
2442 Optional<std::string&> reasonIfUnsupported) const
2443{
Jan Eilers8eb25602020-03-09 12:13:48 +00002444 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002445 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002446 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002447 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002448 DataType::BFloat16,
2449 DataType::Float32,
2450 DataType::Float16,
2451 DataType::QSymmS8,
2452 DataType::QAsymmS8,
2453 DataType::QAsymmU8,
2454 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002455 };
2456
2457 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002458 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002459
2460 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002461 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002462
2463 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002464 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002465
2466 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002467}
2468
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002469bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2470 const TensorInfo& output,
2471 const SpaceToBatchNdDescriptor& descriptor,
2472 Optional<std::string&> reasonIfUnsupported) const
2473{
Jan Eilers8eb25602020-03-09 12:13:48 +00002474 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002475 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002476 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002477 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002478 DataType::BFloat16,
2479 DataType::Float32,
2480 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002481 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002482 DataType::QAsymmU8,
2483 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002484 };
2485
2486 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2487 "Reference SpaceToBatchNd: input type not supported");
2488
2489 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2490 "Reference SpaceToBatchNd: output type not supported");
2491
2492 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2493 "Reference SpaceToBatchNd: input and output types are mismatched");
2494
2495 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002496}
2497
Keith Davisa57eccb2019-06-14 17:33:22 +01002498bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002499 const TensorInfo& output,
2500 const SpaceToDepthDescriptor& descriptor,
2501 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002502{
2503
Jan Eilers8eb25602020-03-09 12:13:48 +00002504 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002505 bool supported = true;
2506
Sadik Armagan303980c2020-04-17 12:45:14 +01002507 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002508 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002509 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002510 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002511 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002512 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002513 DataType::QAsymmU8,
2514 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002515 };
2516
2517 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2518 "Reference SpaceToDepth: input type not supported");
2519
2520 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2521 "Reference SpaceToDepth: output type not supported");
2522
2523 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2524 "Reference SpaceToDepth: input and output types are mismatched");
2525
2526 return supported;
2527}
2528
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002529bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002530 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2531 const ViewsDescriptor& descriptor,
2532 Optional<std::string&> reasonIfUnsupported) const
2533{
Jan Eilers8eb25602020-03-09 12:13:48 +00002534 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002535 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002536 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002537 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002538 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002539 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002540 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002541 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002542 DataType::QAsymmU8,
2543 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002544 };
2545
2546 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2547 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002548 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002549 {
2550 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2551 "Reference splitter: input type not supported");
2552
2553 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2554 "Reference splitter: input and output types mismatched.");
2555 }
2556
2557 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002558}
2559
Matthew Jackson81e601c2019-07-11 12:07:09 +01002560bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2561 const TensorInfo& output,
2562 const StackDescriptor& descriptor,
2563 Optional<std::string&> reasonIfUnsupported) const
2564{
Jan Eilers8eb25602020-03-09 12:13:48 +00002565 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002566
2567 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002568 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002569 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002570 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002571 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002572 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002573 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002574 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002575 DataType::QSymmS16,
2576 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002577 };
2578
2579 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2580 "Reference stack: output type not supported");
2581 for (const TensorInfo* input : inputs)
2582 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002583 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002584 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2585 "Reference stack: input type not supported");
2586
2587 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2588 "Reference stack: input and output types mismatched.");
2589 }
2590
2591 return supported;
2592}
2593
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002594bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2595 const TensorInfo& output,
2596 const StridedSliceDescriptor& descriptor,
2597 Optional<std::string&> reasonIfUnsupported) const
2598{
Jan Eilers8eb25602020-03-09 12:13:48 +00002599 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002600 bool supported = true;
2601
Sadik Armagan303980c2020-04-17 12:45:14 +01002602 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002603 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002604 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002605 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002606 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002607 DataType::QAsymmU8,
2608 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002609 };
2610
2611 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2612 "Reference StridedSlice: input type not supported");
2613
2614 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2615 "Reference StridedSlice: output type not supported");
2616
2617 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2618 "Reference StridedSlice: input and output types are mismatched");
2619
2620 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002621}
2622
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002623bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2624 const TensorInfo& input1,
2625 const TensorInfo& output,
2626 Optional<std::string&> reasonIfUnsupported) const
2627{
Sadik Armagan2999a022019-04-09 14:20:12 +01002628 bool supported = true;
2629
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002630 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002631 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002632 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002633 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002634 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002635 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002636 DataType::QSymmS16,
2637 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002638 };
2639
2640 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2641 "Reference subtraction: input 0 is not a supported type.");
2642
2643 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2644 "Reference subtraction: input 1 is not a supported type.");
2645
2646 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2647 "Reference subtraction: output is not a supported type.");
2648
2649 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2650 "Reference subtraction: input 0 and Input 1 types are mismatched");
2651
2652 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2653 "Reference subtraction: input and output types are mismatched");
2654
2655 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2656 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2657
2658 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002659}
2660
Matteo Martincighab9e5252019-06-13 17:27:46 +01002661bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2662 const TensorInfo& alpha,
2663 const TensorInfo& output,
2664 Optional<std::string&> reasonIfUnsupported) const
2665{
2666 bool supported = true;
2667
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002668 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002669 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002670 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002671 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002672 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002673 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002674 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002675 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002676 };
2677
2678 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2679 "PReLU: input is not a supported type.");
2680
2681 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2682 "PReLU: alpha is not a supported type.");
2683
2684 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2685 "PReLU: output is not a supported type.");
2686
2687 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2688 "PReLU: input, alpha and output types are mismatched");
2689
2690 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2691 "PReLU: shapes are not suitable for implicit broadcast");
2692
2693 return supported;
2694}
2695
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002696bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2697 const TensorInfo& output,
2698 const TransposeConvolution2dDescriptor& descriptor,
2699 const TensorInfo& weights,
2700 const Optional<TensorInfo>& biases,
2701 Optional<std::string&> reasonIfUnsupported) const
2702{
Jan Eilers8eb25602020-03-09 12:13:48 +00002703 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002704 bool supported = true;
2705
Sadik Armagan303980c2020-04-17 12:45:14 +01002706 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002707 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002708 DataType::BFloat16,
2709 DataType::Float32,
2710 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002711 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002712 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002713 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002714 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002715 };
2716
2717 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2718 "Reference TransposeConvolution2d: input is not a supported type.");
2719
2720 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2721 "Reference TransposeConvolution2d: output is not a supported type.");
2722
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002723 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2724 "Reference TransposeConvolution2d: input and output types mismatched.");
2725
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002726
2727 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002728 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002729 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002730 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002731 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002732 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002733 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002734 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002735 };
2736
2737 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2738 "Reference TransposeConvolution2d: weights type not supported for "
2739 "quantized input.");
2740 }
2741 else
2742 {
2743 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2744 "Reference TransposeConvolution2d: weights is not a supported type.");
2745
2746 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2747 "Reference TransposeConvolution2d: input and weights types mismatched.");
2748 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002749
2750 if (biases.has_value())
2751 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002752 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002753 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002754 DataType::BFloat16,
2755 DataType::Float32,
2756 DataType::Float16,
2757 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002758 };
2759 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2760 "Reference TransposeConvolution2d: biases is not a supported type.");
2761 }
2762
2763 return supported;
2764}
2765
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002766bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2767 const TensorInfo& output,
2768 const TransposeDescriptor& descriptor,
2769 Optional<std::string&> reasonIfUnsupported) const
2770{
Jan Eilers8eb25602020-03-09 12:13:48 +00002771 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002772 bool supported = true;
2773
2774 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002775 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002776 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002777 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002778 DataType::Float32,
2779 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002780 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002781 DataType::QAsymmU8,
2782 DataType::QSymmS16
2783 };
2784
2785 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2786 "Reference transpose: input is not a supported type.");
2787
2788 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2789 "Reference transpose: output is not a supported type.");
2790
2791 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2792 "Reference transpose: input and output types are mismatched.");
2793
2794 return supported;
2795}
2796
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002797bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2798 const TensorInfo& input,
2799 const TensorInfo& outputStateIn,
2800 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002801 const TensorInfo& outputStateOut,
2802 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002803 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002804 const UnidirectionalSequenceLstmDescriptor& descriptor,
2805 const LstmInputParamsInfo& paramsInfo,
2806 Optional<std::string&> reasonIfUnsupported) const
2807{
2808 IgnoreUnused(descriptor);
2809 IgnoreUnused(paramsInfo);
2810 IgnoreUnused(outputStateIn);
2811 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002812 IgnoreUnused(outputStateOut);
2813 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002814 bool supported = true;
2815
Mike Kelly12994962022-04-21 11:57:09 +01002816 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002817 {
Mike Kelly12994962022-04-21 11:57:09 +01002818 DataType::Float32,
2819 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002820 };
2821
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002822 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002823 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002824 DataType::Float32,
2825 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002826 };
2827
Mike Kelly12994962022-04-21 11:57:09 +01002828 std::array<DataType, 3> supportedBiasTypes =
2829 {
2830 DataType::Float32,
2831 DataType::QAsymmS8,
2832 DataType::Signed32
2833 };
2834
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002835 // check inputs and outputs
2836 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2837 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002838 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2839 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002840
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002841 // check layer parameters
2842 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2843 reasonIfUnsupported,
2844 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2845 "is not a supported type.");
2846 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2847 reasonIfUnsupported,
2848 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2849 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2850 reasonIfUnsupported,
2851 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2852 "is not a supported type.");
2853 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2854 reasonIfUnsupported,
2855 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2856 "is not a supported type.");
2857 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2858 reasonIfUnsupported,
2859 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2860 "is not a supported type.");
2861 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2862 reasonIfUnsupported,
2863 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2864 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002865
2866 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2867 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2868 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2869 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2870 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2871 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002872 if (!descriptor.m_CifgEnabled)
2873 {
2874 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2875 reasonIfUnsupported,
2876 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2877 "is not a supported type.");
2878 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2879 reasonIfUnsupported,
2880 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2881 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002882 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2883 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002884 if (descriptor.m_PeepholeEnabled)
2885 {
2886 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2887 reasonIfUnsupported,
2888 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2889 "is not a supported type.");
2890 }
2891 }
2892 if (descriptor.m_PeepholeEnabled)
2893 {
2894 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2895 reasonIfUnsupported,
2896 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2897 "is not a supported type.");
2898 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2899 reasonIfUnsupported,
2900 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2901 "is not a supported type.");
2902 }
2903 if (descriptor.m_ProjectionEnabled)
2904 {
2905 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2906 reasonIfUnsupported,
2907 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2908 "is not a supported type.");
2909 if (paramsInfo.m_ProjectionBias != nullptr)
2910 {
2911 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2912 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2913 "are mismatched");
2914 }
2915 }
2916 if (descriptor.m_LayerNormEnabled)
2917 {
2918 if (!descriptor.m_CifgEnabled)
2919 {
2920 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2921 reasonIfUnsupported,
2922 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2923 "is not a supported type.");
2924 }
2925 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2926 reasonIfUnsupported,
2927 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2928 "is not a supported type.");
2929 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2930 reasonIfUnsupported,
2931 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2932 "is not a supported type.");
2933 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2934 reasonIfUnsupported,
2935 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2936 "is not a supported type.");
2937 }
2938
2939 return supported;
2940}
2941
arovir011c7c81b2018-10-08 11:34:28 +01002942} // namespace armnn