blob: 919c6db6ffe321a563b9c0872d4c7d0a9b55a935 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
namespace
{

// Convenience wrapper for the reference backend around IsSupportedForDataTypeGeneric:
// only the Float32 and QAsymmU8 slots are caller-supplied; every other data-type slot
// (Float16 and the two later-added slots) is pinned to FalseFunc, i.e. "unsupported".
//
// reasonIfUnsupported - optional out-string populated by the selected func on failure.
// dataType            - the tensor data type being queried.
// floatFuncPtr        - predicate invoked when dataType is Float32.
// uint8FuncPtr        - predicate invoked when dataType is QAsymmU8.
// params              - extra arguments perfectly forwarded to the selected predicate.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: never supported via this path
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
44
namespace
{

// Builds the standard "wrong number of dimensions" diagnostic used by the per-layer
// support checks, e.g.
//   "Reference batchToSpaceNd: Expected 4 dimensions but got 2 dimensions instead,
//    for the 'input' tensor."
//
// expected   - the dimension count the layer requires.
// actual     - the dimension count the tensor actually has.
// layerStr   - human-readable layer name inserted after "Reference ".
// tensorName - name of the offending tensor, quoted in the message.
//
// Note: parameters are now taken by const reference — they are only read, and the
// previous non-const std::string& signature rejected temporaries/literals for no reason.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
// Unified entry point for querying reference-backend support for any layer type.
//
// Dispatches on 'type' to the corresponding per-layer Is*Supported check.
// Conventions for 'infos' (established by the callers of this API):
//  - For most layers: inputs first, then outputs (e.g. {input, output} or
//    {input0, input1, output}).
//  - Concat/Stack: all but the last entry are inputs; the last is the output.
//  - Splitter: infos[0] is the input; the rest are the outputs.
//  - Convolution2d/3d, DepthwiseConvolution2d, TransposeConvolution2d:
//    exactly {input, output, weights, biases}; a default-constructed TensorInfo
//    in the biases slot signals "no bias".
// 'descriptor' is downcast to the layer-specific descriptor type; passing the
// wrong descriptor type for 'type' is undefined behaviour (PolymorphicDowncast).
// For the LSTM variants, lstmParamsInfo.value() / quantizedLstmInputParamsInfo.value()
// are called unconditionally — presumably an empty Optional raises here; callers
// must supply the params (TODO confirm armnn::Optional::value() semantics).
bool RefLayerSupport::IsLayerSupported(const LayerType& type,
                                       const std::vector<TensorInfo>& infos,
                                       const BaseDescriptor& descriptor,
                                       const Optional<LstmInputParamsInfo>& lstmParamsInfo,
                                       const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    switch (type)
    {
        case LayerType::Activation:
            return IsActivationSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Addition:
            return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ArgMinMax:
            return IsArgMinMaxSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::BatchNormalization:
            // infos: {input, output, mean, variance, beta, gamma}
            return IsBatchNormalizationSupported(infos[0],
                                                 infos[1],
                                                 infos[2],
                                                 infos[3],
                                                 infos[4],
                                                 infos[5],
                                                 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
                                                     (&descriptor)),
                                                 reasonIfUnsupported);
        case LayerType::BatchToSpaceNd:
            return IsBatchToSpaceNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Comparison:
            return IsComparisonSupported(infos[0],
                                         infos[1],
                                         infos[2],
                                         *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Concat:
        {
            // All entries except the last are inputs; the last entry is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < (infos.size() - 1); i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsConcatSupported(inputInfos,
                                     infos[infos.size() - 1],
                                     *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        }
        case LayerType::Constant:
            return IsConstantSupported(infos[0], reasonIfUnsupported);
        case LayerType::ConvertBf16ToFp32:
            return IsConvertBf16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp16ToFp32:
            return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToBf16:
            return IsConvertFp32ToBf16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ConvertFp32ToFp16:
            return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Convolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the biases slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution2dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::DepthToSpace:
            return IsDepthToSpaceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::DepthwiseConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the biases slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       EmptyOptional(),
                                                       reasonIfUnsupported);
            }
            else
            {
                return IsDepthwiseConvolutionSupported(infos[0],
                                                       infos[1],
                                                       desc,
                                                       infos[2],
                                                       infos[3],
                                                       reasonIfUnsupported);
            }
        }
        case LayerType::Dequantize:
            return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Division:
            return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::ElementwiseUnary:
            return IsElementwiseUnarySupported(infos[0],
                                               infos[1],
                                               *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::Fill:
            return IsFillSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Floor:
            return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::FullyConnected:
            // infos: {input, output, weights, biases}
            return IsFullyConnectedSupported(infos[0],
                                             infos[1],
                                             infos[2],
                                             infos[3],
                                             *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Gather:
            return IsGatherSupported(infos[0],
                                     infos[1],
                                     infos[2],
                                     *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::GatherNd:
            return IsGatherNdSupported(infos[0],
                                       infos[1],
                                       infos[2],
                                       reasonIfUnsupported);
        case LayerType::Input:
            return IsInputSupported(infos[0], reasonIfUnsupported);
        case LayerType::InstanceNormalization:
            return IsInstanceNormalizationSupported(infos[0],
                                                    infos[1],
                                                    *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
                                                        (&descriptor)),
                                                    reasonIfUnsupported);
        case LayerType::L2Normalization:
            return IsL2NormalizationSupported(infos[0],
                                              infos[1],
                                              *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
                                              reasonIfUnsupported);
        case LayerType::LogicalBinary:
            return IsLogicalBinarySupported(infos[0],
                                            infos[1],
                                            infos[2],
                                            *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::LogSoftmax:
            return IsLogSoftmaxSupported(infos[0],
                                         infos[1],
                                         *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
                                         reasonIfUnsupported);
        case LayerType::Lstm:
            // Requires lstmParamsInfo to be set.
            return IsLstmSupported(infos[0],
                                   infos[1],
                                   infos[2],
                                   infos[3],
                                   infos[4],
                                   infos[5],
                                   infos[6],
                                   *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
                                   lstmParamsInfo.value(),
                                   reasonIfUnsupported);
        case LayerType::QLstm:
            // Requires lstmParamsInfo to be set.
            return IsQLstmSupported(infos[0],
                                    infos[1],
                                    infos[2],
                                    infos[3],
                                    infos[4],
                                    infos[5],
                                    *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
                                    lstmParamsInfo.value(),
                                    reasonIfUnsupported);
        case LayerType::Maximum:
            return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Mean:
            return IsMeanSupported(infos[0],
                                   infos[1],
                                   *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
                                   reasonIfUnsupported);
        case LayerType::Minimum:
            return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Multiplication:
            return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Normalization:
            return IsNormalizationSupported(infos[0],
                                            infos[1],
                                            *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
                                            reasonIfUnsupported);
        case LayerType::Output:
            return IsOutputSupported(infos[0], reasonIfUnsupported);
        case LayerType::Pad:
            return IsPadSupported(infos[0],
                                  infos[1],
                                  *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
                                  reasonIfUnsupported);
        case LayerType::Permute:
            return IsPermuteSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Pooling2d:
            return IsPooling2dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Prelu:
            return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Quantize:
            return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Reshape:
            return IsReshapeSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::Resize:
            return IsResizeSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Reduce:
            return IsReduceSupported(infos[0],
                                     infos[1],
                                     *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
                                     reasonIfUnsupported);
        case LayerType::Slice:
            return IsSliceSupported(infos[0],
                                    infos[1],
                                    *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        case LayerType::Softmax:
            return IsSoftmaxSupported(infos[0],
                                      infos[1],
                                      *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
                                      reasonIfUnsupported);
        case LayerType::SpaceToBatchNd:
            return IsSpaceToBatchNdSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::SpaceToDepth:
            return IsSpaceToDepthSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Splitter:
        {
            // infos[0] is the input; all remaining entries are outputs.
            std::vector<TensorInfo> outputInfos;
            for (uint32_t i = 1; i < infos.size(); i++)
            {
                outputInfos.push_back(infos[i]);
            }
            return IsSplitterSupported(infos[0],
                                       {outputInfos.begin(), outputInfos.end()},
                                       *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
                                       reasonIfUnsupported);
        }
        case LayerType::Stack:
        {
            // All entries except the last are inputs; the last entry is the output.
            std::vector<const TensorInfo*> inputInfos;
            for (uint32_t i = 0; i < infos.size() - 1; i++)
            {
                inputInfos.push_back(&infos[i]);
            }
            return IsStackSupported(inputInfos,
                                    infos[infos.size() - 1],
                                    *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
                                    reasonIfUnsupported);
        }
        case LayerType::StridedSlice:
            return IsStridedSliceSupported(infos[0],
                                           infos[1],
                                           *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
                                           reasonIfUnsupported);
        case LayerType::Subtraction:
            return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::Transpose:
            return IsTransposeSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::TransposeConvolution2d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the biases slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         EmptyOptional(),
                                                         reasonIfUnsupported);
            }
            else
            {
                return IsTransposeConvolution2dSupported(infos[0],
                                                         infos[1],
                                                         desc,
                                                         infos[2],
                                                         infos[3],
                                                         reasonIfUnsupported);
            }
        }
        case LayerType::Cast:
            return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::ChannelShuffle:
            return IsChannelShuffleSupported(infos[0],
                                             infos[1],
                                             *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
                                             reasonIfUnsupported);
        case LayerType::Convolution3d:
        {
            if (infos.size() != 4)
            {
                throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
                                               "TensorInfos should be of format: {input, output, weights, biases}.");
            }

            auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
            // A default-constructed TensorInfo in the biases slot means "no bias".
            if (infos[3] == TensorInfo())
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                EmptyOptional(),
                                                reasonIfUnsupported);
            }
            else
            {
                return IsConvolution3dSupported(infos[0],
                                                infos[1],
                                                desc,
                                                infos[2],
                                                infos[3],
                                                reasonIfUnsupported);
            }
        }
        case LayerType::Debug:
            return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::DetectionPostProcess:
            return IsDetectionPostProcessSupported(infos[0],
                                                   infos[1],
                                                   infos[2],
                                                   infos[3],
                                                   infos[4],
                                                   infos[5],
                                                   infos[6],
                                                   *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
                                                       (&descriptor)),
                                                   reasonIfUnsupported);
        case LayerType::FakeQuantization:
            return IsFakeQuantizationSupported(infos[0],
                                               *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
                                               reasonIfUnsupported);
        case LayerType::MemCopy:
            return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Rank:
            return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Shape:
            return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::UnidirectionalSequenceLstm:
        {
            if (infos.size() != 6)
            {
                throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
                                               "should be of format: {input, outputStateIn, cellStateIn, "
                                               "hiddenStateOutputVal, cellStateOutputVal, output}");
            }
            auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
            // Requires lstmParamsInfo to be set.
            return IsUnidirectionalSequenceLstmSupported(infos[0],
                                                         infos[1],
                                                         infos[2],
                                                         infos[3],
                                                         infos[4],
                                                         infos[5],
                                                         desc,
                                                         lstmParamsInfo.value(),
                                                         reasonIfUnsupported);
        }
        case LayerType::Pooling3d:
            return IsPooling3dSupported(infos[0],
                                        infos[1],
                                        *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
                                        reasonIfUnsupported);
        case LayerType::Map:
            return true;
        case LayerType::Unmap:
            return true;
        case LayerType::MemImport:
            return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
        case LayerType::Merge:
            return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
        case LayerType::QuantizedLstm:
            // Requires quantizedLstmInputParamsInfo to be set.
            return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
                                                              infos[1],
                                                              infos[2],
                                                              infos[3],
                                                              infos[4],
                                                              quantizedLstmInputParamsInfo.value(),
                                                              reasonIfUnsupported);
        default:
            // layers not supported in the reference backend by default:
            // precompiled, standin, switch
            return false;
    }
}
505
// Checks reference-backend support for an Activation layer: input/output must share
// one of the listed data types, have equal types, matching rank, and the requested
// activation function must be one the reference implementation handles.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Ad-hoc Rule whose result is whether the descriptor's activation function is one
    // of the functions the reference backend implements (m_Res comes from the Rule base).
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
573
574bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
575 const TensorInfo& input1,
576 const TensorInfo& output,
577 Optional<std::string&> reasonIfUnsupported) const
578{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000579 bool supported = true;
580
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100581 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000582 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000583 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100584 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000585 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000586 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100587 DataType::QSymmS16,
588 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000589 };
590
591 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
592 "Reference addition: input 0 is not a supported type.");
593
594 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
595 "Reference addition: input 1 is not a supported type.");
596
597 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
598 "Reference addition: output is not a supported type.");
599
600 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
601 "Reference addition: input 0 and Input 1 types are mismatched");
602
603 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
604 "Reference addition: input and output types are mismatched");
605
606 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
607 "Reference addition: shapes are not suitable for implicit broadcast.");
608
609 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100610}
611
// Checks reference-backend support for an ArgMinMax layer. The input may be any of the
// listed numeric types; the output (the computed indices) must be Signed32 or Signed64.
// The descriptor (function/axis) does not affect type support and is ignored here.
bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32,
        DataType::Signed64
    };

    // Output carries indices, so only integer index types are valid.
    std::array<DataType,2> supportedOutputTypes = {
        DataType::Signed32,
        DataType::Signed64
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}
644
// Checks reference-backend support for a BatchNormalization layer: input, output and all
// four parameter tensors (mean, variance, beta, gamma) must each use one of the accepted
// types, and input/output types must match. The descriptor (epsilon/layout) does not
// affect type support and is ignored here.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
691
// Checks reference-backend support for a BatchToSpaceNd layer: input/output must share one
// of the accepted types, match each other, and both be 4-dimensional. The descriptor
// (block shape/crops) does not affect support and is ignored here.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;

    // Names used when building the dimension-mismatch diagnostics below.
    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    // The reference implementation only handles rank-4 tensors.
    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}
741
mathad01b392e982021-04-07 12:07:30 +0100742bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
743 const TensorInfo& output,
744 Optional<std::string&> reasonIfUnsupported) const
745{
746 std::array<DataType, 9> supportedInputTypes =
747 {
748 DataType::BFloat16,
749 DataType::Float32,
750 DataType::Float16,
751 DataType::QSymmS8,
752 DataType::QAsymmS8,
753 DataType::QAsymmU8,
754 DataType::QSymmS16,
755 DataType::Signed32
756 };
757
758 bool supported = true;
759 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
760 "Reference cast: input is not a supported type");
761
762
763 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
764 "Reference cast: output is not a supported type");
765
766 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
767 "Reference cast: input and output shapes have different number of total elements");
768
769 return supported;
770}
771
Simon Obute51f67772021-09-03 15:50:13 +0100772bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
773 const TensorInfo& output,
774 const ChannelShuffleDescriptor& descriptor,
775 Optional<std::string&> reasonIfUnsupported) const
776{
777 IgnoreUnused(descriptor);
778 bool supported = true;
779
780 // Define supported output and inputs types.
781 std::array<DataType, 7> supportedTypes =
782 {
783 DataType::BFloat16,
784 DataType::Float32,
785 DataType::Float16,
786 DataType::QAsymmS8,
787 DataType::QAsymmU8,
788 DataType::QSymmS8,
789 DataType::QSymmS16
790 };
791
792 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
793 "Reference ChannelShuffle: input is not a supported type.");
794
795 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
796 "Reference ChannelShuffle: output is not a supported type.");
797
798 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
799 "Reference ChannelShuffle: input and output types are mismatched.");
800
801 return supported;
802}
803
804
// Checks reference-backend support for a Comparison layer (Equal/Greater/etc.):
// input 0 must be one of the accepted types, both inputs must share a type, and the
// output must be Boolean. The descriptor (comparison operation) does not affect
// type support and is ignored here.
bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::Boolean,
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    // input1 is only validated transitively: it must equal input0's type.
    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    // Comparison results are always boolean.
    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}
836
// Checks reference-backend support for a Concat layer: the output and every input must
// use one of the accepted types, and each input's type must match the output's. The
// descriptor (merge origins) does not affect type support and is ignored here.
// Input pointers must be non-null (asserted).
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        ARMNN_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}
870
arovir011c7c81b2018-10-08 11:34:28 +0100871bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
872 Optional<std::string&> reasonIfUnsupported) const
873{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100874 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100875 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000876 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100877 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100878 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000879 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100880 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000881 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100882 DataType::QSymmS16,
883 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100884 };
885
886 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
887 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100888}
889
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000890bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
891 const TensorInfo& output,
892 Optional<std::string&> reasonIfUnsupported) const
893{
894 bool supported = true;
895
896 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
897 "Reference for ConvertBf16ToFp32 layer: input type not supported");
898
899 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
900 "Reference for ConvertBf16ToFp32 layer: output type not supported");
901
902 return supported;
903}
904
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only as a Float16 -> Float32 conversion.
    // The five function pointers correspond, in order, to the Float16 / Float32 /
    // Uint8 / Int32 / Boolean branches of IsSupportedForDataTypeGeneric (compare
    // the argument order used by IsSupportedForDataTypeRef at the top of this file):
    //  - input:  Float16 accepted (TrueFunc); Float32 rejected with an explanatory
    //            reason (FalseInputFuncF32); every other type rejected.
    //  - output: Float16 rejected with an explanatory reason (FalseOutputFuncF16);
    //            Float32 accepted (TrueFunc); every other type rejected.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
924
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000925bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
926 const TensorInfo& output,
927 Optional<std::string&> reasonIfUnsupported) const
928{
929 bool supported = true;
930
931 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
932 "Reference for ConvertFp32ToBf16 layer: input type not supported");
933
934 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
935 "Reference for ConvertFp32ToBf16 layer: output type not supported");
936
937 return supported;
938}
939
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only as a Float32 -> Float16 conversion (mirror image of
    // IsConvertFp16ToFp32Supported). The five function pointers correspond, in
    // order, to the Float16 / Float32 / Uint8 / Int32 / Boolean branches of
    // IsSupportedForDataTypeGeneric:
    //  - input:  Float16 rejected with an explanatory reason (FalseInputFuncF16);
    //            Float32 accepted (TrueFunc); every other type rejected.
    //  - output: Float16 accepted (TrueFunc); Float32 rejected with an explanatory
    //            reason (FalseOutputFuncF32); every other type rejected.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
959
960bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
961 const TensorInfo& output,
962 const Convolution2dDescriptor& descriptor,
963 const TensorInfo& weights,
964 const Optional<TensorInfo>& biases,
965 Optional<std::string&> reasonIfUnsupported) const
966{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100967 bool supported = true;
968
969 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000970 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000971 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000972 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000973 DataType::Float32,
974 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000975 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100976 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000977 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000978 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100979 };
980
981 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000982 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100983
984 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000985 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100986
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000987 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
988 if (input.GetDataType() == DataType::BFloat16)
989 {
990 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
991 {
992 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
993 supported = false;
994 }
995 }
996 else
997 {
998 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000999 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001000 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001001
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001002 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001003 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001004 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001005 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001006 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001007 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001008 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001009 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001010 };
1011
1012 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001013 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001014 }
1015 else
1016 {
1017 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001018 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001019
1020 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001021 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001022 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001023
1024 if (biases.has_value())
1025 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001026 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001027 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001028 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001029 DataType::Float32,
1030 DataType::Float16,
1031 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001032 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001033
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001034 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001035 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001036 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001037 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001038
1039 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001040}
1041
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001042bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1043 const TensorInfo& output,
1044 const Convolution3dDescriptor& descriptor,
1045 const TensorInfo& weights,
1046 const Optional<TensorInfo>& biases,
1047 Optional<std::string&> reasonIfUnsupported) const
1048{
1049 bool supported = true;
1050
1051 // Define supported types.
1052 std::array<DataType,7> supportedTypes =
1053 {
1054 DataType::BFloat16,
1055 DataType::Float32,
1056 DataType::Float16,
1057 DataType::QAsymmS8,
1058 DataType::QAsymmU8,
1059 DataType::QSymmS8,
1060 DataType::QSymmS16
1061 };
1062
1063 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1064 "Reference Convolution3d: input is not a supported type.");
1065
1066 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1067 "Reference Convolution3d: output is not a supported type.");
1068
1069 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1070 "Reference Convolution3d: input and output types mismatched.");
1071
1072 const DataType inputType = input.GetDataType();
1073 if (IsQuantized8BitType(inputType))
1074 {
1075 std::array<DataType, 3> supportedWeightTypes =
1076 {
1077 DataType::QAsymmS8,
1078 DataType::QAsymmU8,
1079 DataType::QSymmS8
1080 };
1081
1082 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1083 "Reference Convolution3d: weights type not supported for quantized input.");
1084 }
1085 else
1086 {
1087 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1088 "Reference Convolution3d: weights is not a supported type.");
1089
1090 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1091 "Reference Convolution3d: input and weights types mismatched.");
1092 }
1093
1094 if (biases.has_value())
1095 {
1096 std::array<DataType,4> biasesSupportedTypes =
1097 {
1098 DataType::BFloat16,
1099 DataType::Float32,
1100 DataType::Float16,
1101 DataType::Signed32
1102 };
1103
1104 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1105 "Reference Convolution3d: biases is not a supported type.");
1106 }
1107 IgnoreUnused(descriptor);
1108
1109 return supported;
1110}
1111
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001112bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1113 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001114 Optional<std::string&> reasonIfUnsupported) const
1115{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001116 bool supported = true;
1117
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001118 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001119 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001120 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001121 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001122 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001123 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001124 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001125 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001126 DataType::QSymmS16,
1127 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001128 };
1129
1130 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001131 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001132
1133 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001134 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001135
1136 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001137 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001138
1139 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001140}
1141
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001142bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1143 const TensorInfo& output,
1144 const DepthToSpaceDescriptor& descriptor,
1145 Optional<std::string&> reasonIfUnsupported) const
1146{
Jan Eilers8eb25602020-03-09 12:13:48 +00001147 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001148 bool supported = true;
1149
Sadik Armagan303980c2020-04-17 12:45:14 +01001150 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001151 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001152 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001153 DataType::Float32,
1154 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001155 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001156 DataType::QAsymmU8,
1157 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001158 };
1159
1160 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1161 "Reference DepthToSpace: input type not supported");
1162
1163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1164 "Reference DepthToSpace: output type not supported");
1165
1166 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1167 "Reference DepthToSpace: input and output types are mismatched");
1168
1169 return supported;
1170}
1171
// Checks reference-backend type support for DepthwiseConvolution2d.
// Validates input/output/weights/bias data types against the sets the reference
// workload implements; the descriptor itself carries no type constraints.
// Returns true when all rules pass; failing rules append a reason to
// reasonIfUnsupported (when provided).
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    // Stride/padding/dilation do not affect which data types are supported.
    IgnoreUnused(descriptor);
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        // Quantized 8-bit inputs require 8-bit weights.
        std::array<DataType, 3> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        // Non-quantized path: weights must be a supported type and match the input.
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        // Biases use a wider accumulator type than the 8/16-bit inputs.
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }

    return supported;

}
1242
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001243bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1244 const TensorInfo& output,
1245 Optional<std::string&> reasonIfUnsupported) const
1246{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001247 bool supported = true;
1248
Ryan OShea9add1202020-02-07 10:06:33 +00001249 std::array<DataType,4> supportedInputTypes = {
1250 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001251 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001252 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001253 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001254 };
1255
1256 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001257 "Reference for Dequantize layer: input type not supported.");
1258
Derek Lambertid466a542020-01-22 15:37:29 +00001259 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001260 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001261
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001262 std::array<DataType,3> supportedOutputTypes = {
1263 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +00001264 DataType::Float32,
1265 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001266 };
1267
1268 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001269 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001270
1271 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001272 "Reference for Dequantize layer: input/output shapes have different num total "
1273 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001274
1275 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001276}
1277
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001278bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1279 const TensorInfo& scores,
1280 const TensorInfo& anchors,
1281 const TensorInfo& detectionBoxes,
1282 const TensorInfo& detectionClasses,
1283 const TensorInfo& detectionScores,
1284 const TensorInfo& numDetections,
1285 const DetectionPostProcessDescriptor& descriptor,
1286 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001287{
Jan Eilers8eb25602020-03-09 12:13:48 +00001288 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001289
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001290 bool supported = true;
1291
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001292 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001293 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001294 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001295 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001296 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001297 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001298 DataType::QAsymmU8,
1299 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001300 };
1301
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001302 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001303 "Reference DetectionPostProcess: input 0 is not a supported type.");
1304
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001305 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001306 "Reference DetectionPostProcess: input 1 is not a supported type.");
1307
1308 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001309}
1310
// Dilated depthwise convolution shares the type-support rules of the
// non-dilated case in the reference backend, so this simply delegates to
// IsDepthwiseConvolutionSupported with the same arguments.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1320
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001321bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001322 const TensorInfo& input1,
1323 const TensorInfo& output,
1324 Optional<std::string&> reasonIfUnsupported) const
1325{
Sadik Armagan2999a022019-04-09 14:20:12 +01001326 bool supported = true;
1327
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001328 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001329 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001330 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001331 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001332 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001333 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001334 DataType::QSymmS16,
1335 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001336 };
1337
1338 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1339 "Reference division: input 0 is not a supported type.");
1340
1341 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1342 "Reference division: input 1 is not a supported type.");
1343
1344 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1345 "Reference division: output is not a supported type.");
1346
1347 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1348 "Reference division: input 0 and Input 1 types are mismatched");
1349
1350 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1351 "Reference division: input and output types are mismatched");
1352
1353 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1354 "Reference division: shapes are not suitable for implicit broadcast.");
1355
1356 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001357}
1358
josh minor4a3c6102020-01-06 16:40:46 -06001359bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1360 const TensorInfo& output,
1361 const ElementwiseUnaryDescriptor& descriptor,
1362 Optional<std::string&> reasonIfUnsupported) const
1363{
Jan Eilers8eb25602020-03-09 12:13:48 +00001364 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001365
Sadik Armagan303980c2020-04-17 12:45:14 +01001366 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001367 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001368 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -06001369 DataType::Float32,
1370 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001371 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001372 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001373 DataType::QSymmS16,
1374 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001375 };
1376
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001377 std::array<DataType, 1> logicalSupportedTypes =
1378 {
1379 DataType::Boolean
1380 };
1381
josh minor4a3c6102020-01-06 16:40:46 -06001382 bool supported = true;
1383
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001384 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1385 {
1386 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1387 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001388
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001389 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1390 "Reference elementwise unary: output type not supported");
1391 }
1392 else
1393 {
1394 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1395 "Reference elementwise unary: input type not supported");
1396
1397 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1398 "Reference elementwise unary: output type not supported");
1399 }
josh minor4a3c6102020-01-06 16:40:46 -06001400
1401 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1402 "Reference elementwise unary: input and output types not matching");
1403
1404 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1405 "Reference elementwise unary: input and output shapes"
1406 "have different number of total elements");
1407
1408 return supported;
1409}
1410
arovir011c7c81b2018-10-08 11:34:28 +01001411bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1412 const FakeQuantizationDescriptor& descriptor,
1413 Optional<std::string&> reasonIfUnsupported) const
1414{
Jan Eilers8eb25602020-03-09 12:13:48 +00001415 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001416 bool supported = true;
1417
1418 std::array<DataType,1> supportedTypes =
1419 {
1420 DataType::Float32
1421 };
1422
1423 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1424 "Reference fake quantization: input type not supported.");
1425
1426 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001427}
1428
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001429bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1430 const TensorInfo& output,
1431 const FillDescriptor& descriptor,
1432 Optional<std::string&> reasonIfUnsupported) const
1433{
1434 IgnoreUnused(descriptor);
1435 IgnoreUnused(output);
1436
1437 bool supported = true;
1438
Sadik Armagana792a052020-06-23 16:22:23 +01001439 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001440 {
1441 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001442 DataType::Float16,
1443 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001444 };
1445
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001446 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001447 "Reference Fill: input type not supported.");
1448
Teresa Charlin44088502020-07-27 11:27:19 +01001449 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1450 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001451 return supported;
1452}
1453
arovir011c7c81b2018-10-08 11:34:28 +01001454bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1455 const TensorInfo& output,
1456 Optional<std::string&> reasonIfUnsupported) const
1457{
Jan Eilers8eb25602020-03-09 12:13:48 +00001458 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001459 bool supported = true;
1460
Teresa Charlin38b72e82022-05-04 17:54:19 +01001461 std::array<DataType,4> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001462 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001463 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +01001464 DataType::Float32,
Teresa Charlin38b72e82022-05-04 17:54:19 +01001465 DataType::Float16,
1466 DataType::Signed32
James Conroy83735b12019-05-30 16:36:59 +01001467 };
1468
1469 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1470 "Reference Floor: input type not supported.");
1471
1472 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1473 "Reference Floor: output type not supported.");
1474
1475 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001476}
1477
1478bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1479 const TensorInfo& output,
1480 const TensorInfo& weights,
1481 const TensorInfo& biases,
1482 const FullyConnectedDescriptor& descriptor,
1483 Optional<std::string&> reasonIfUnsupported) const
1484{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001485 bool supported = true;
1486
1487 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001488 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001489 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001490 DataType::BFloat16,
1491 DataType::Float32,
1492 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001493 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001494 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001495 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001496 };
1497
1498 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1499 "Reference Fully Connected: input type not supported.");
1500
1501 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1502 "Reference Fully Connected: output type not supported.");
1503
Francis Murtagh46c09d02019-05-28 08:15:28 +01001504 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1505 "Reference Fully Connected: weights type not supported.");
1506
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001507 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1508 if (input.GetDataType() == DataType::BFloat16)
1509 {
1510 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
1511 {
1512 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
1513 supported = false;
1514 }
1515 }
1516 else
1517 {
1518 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1519 "Reference Fully Connected: input and output types mismatched.");
1520 }
1521
Jan Eilers1f45dc32020-06-15 11:43:03 +01001522 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1523 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001524
Jan Eilers1f45dc32020-06-15 11:43:03 +01001525 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1526 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001527
1528 if (descriptor.m_BiasEnabled)
1529 {
1530 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001531 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001532 supportedBiasTypes =
1533 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001534 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001535 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001536 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001537 DataType::Signed32,
1538 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001539 };
1540
1541 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1542 "Reference Fully Connected: bias type not supported.");
1543
1544 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1545 "Reference Fully Connected: bias and weight types mismatch.");
1546
1547 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1548 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1549
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001550 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1551 "Reference Fully Connected: bias must have 1 dimension.");
1552
Francis Murtagh46c09d02019-05-28 08:15:28 +01001553 }
1554
1555 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001556}
1557
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001558bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1559 const armnn::TensorInfo& input1,
1560 const armnn::TensorInfo& output,
1561 armnn::Optional<std::string&> reasonIfUnsupported) const
1562{
1563 bool supported = true;
1564 std::array<DataType,7> supportedTypes =
1565 {
1566 DataType::BFloat16,
1567 DataType::Float32,
1568 DataType::Float16,
1569 DataType::QAsymmS8,
1570 DataType::QAsymmU8,
1571 DataType::QSymmS16,
1572 DataType::Signed32
1573 };
1574
1575 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1576 "Reference GatherNd: input type not supported");
1577
1578 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1579 "Reference GatherNd: output type not supported");
1580
1581 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1582 "Reference GatherNd: indices (input1) type not supported");
1583
1584 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1585 "Reference GatherNd: input and output types not matching");
1586
1587 return supported;
1588}
1589
narpra014951d842019-01-18 16:53:53 +00001590bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1591 const armnn::TensorInfo& input1,
1592 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001593 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001594 armnn::Optional<std::string&> reasonIfUnsupported) const
1595{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001596 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001597 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001598 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001599 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001600 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001601 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001602 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001603 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001604 DataType::QSymmS16,
1605 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001606 };
1607
Teresa Charlin52664732020-06-29 16:27:03 +01001608 if (descriptor.m_Axis != 0)
1609 {
1610 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1611 supported &= false;
1612 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001613 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1614 "Reference Gather: input type not supported");
1615
1616 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1617 "Reference Gather: output type not supported");
1618
1619 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1620 "Reference Gather: indices (input1) type not supported");
1621
1622 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1623 "Reference Gather: input and output types not matching");
1624
1625 return supported;
narpra014951d842019-01-18 16:53:53 +00001626}
1627
Derek Lamberti901ea112019-12-10 22:07:09 +00001628bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1629 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001630{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001631 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001632}
1633
Kevin May09ca49c2019-10-09 12:37:34 +01001634bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1635 const TensorInfo& output,
1636 const InstanceNormalizationDescriptor& descriptor,
1637 Optional<std::string&> reasonIfUnsupported) const
1638{
Jan Eilers8eb25602020-03-09 12:13:48 +00001639 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001640 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001641 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001642 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001643 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001644 DataType::Float32,
1645 DataType::Float16
1646 };
1647
1648 bool supported = true;
1649
1650 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1651 "Reference Instance Normalization: input type not supported.");
1652
1653 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1654 "Reference Instance Normalization: output type not supported.");
1655
1656 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1657 "Reference Instance Normalization: input and output types mismatched.");
1658
1659 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1660 "Reference Instance Normalization: input and output shapes have different "
1661 "num total elements.");
1662
1663 return supported;
1664}
1665
arovir011c7c81b2018-10-08 11:34:28 +01001666bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1667 const TensorInfo& output,
1668 const L2NormalizationDescriptor& descriptor,
1669 Optional<std::string&> reasonIfUnsupported) const
1670{
Jan Eilers8eb25602020-03-09 12:13:48 +00001671 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001672 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001673 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001674 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001675 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001676 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001677 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001678 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001679 DataType::QAsymmU8,
1680 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001681 };
1682
1683 bool supported = true;
1684
1685 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1686 "Reference L2normalization: input type not supported.");
1687
1688 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1689 "Reference L2normalization: output type not supported.");
1690
1691 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1692 "Reference L2normalization: input and output types mismatched.");
1693
1694 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1695 "Reference L2normalization: input and output shapes have different "
1696 "num total elements.");
1697
1698 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001699}
1700
James Conroyaba90cd2020-11-06 16:28:18 +00001701bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1702 const TensorInfo& input1,
1703 const TensorInfo& output,
1704 const LogicalBinaryDescriptor& descriptor,
1705 Optional<std::string&> reasonIfUnsupported) const
1706{
1707 IgnoreUnused(descriptor);
1708
1709 std::array<DataType, 1> supportedTypes =
1710 {
1711 DataType::Boolean
1712 };
1713
1714 bool supported = true;
1715 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1716 "Reference LogicalBinary: input 0 type not supported");
1717 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1718 "Reference LogicalBinary: input 1 type not supported");
1719
1720 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1721 "Reference LogicalBinary: input and output types do not match");
1722
1723 return supported;
1724}
1725
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001726bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1727 const TensorInfo& output,
1728 const LogSoftmaxDescriptor& descriptor,
1729 Optional<std::string&> reasonIfUnsupported) const
1730{
Jan Eilers8eb25602020-03-09 12:13:48 +00001731 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001732
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001733 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001734 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001735 DataType::BFloat16,
1736 DataType::Float32,
1737 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001738 };
1739
1740 bool supported = true;
1741 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1742 "Reference LogSoftmax: input type not supported");
1743
1744 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1745 "Reference LogSoftmax: output type not supported");
1746
1747 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1748 "Reference LogSoftmax: input and output types do not match");
1749
1750 return supported;
1751}
1752
// Validates reference-backend support for an LSTM layer: every tensor (states,
// scratch buffer, outputs) and every weight/bias in paramsInfo must share the
// input's data type, with optional parameter groups checked only when the
// corresponding descriptor flag (CIFG/peephole/projection/layer-norm) enables them.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // NOTE(review): descriptor and paramsInfo are both read below, so these
    // IgnoreUnused calls are redundant (harmless no-ops).
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);

    bool supported = true;

    // Only float32, bfloat16 and the 16-bit symmetric quantized type are handled.
    std::array<DataType,3> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    // Mandatory weights/biases: always present regardless of descriptor flags.
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // Input gate parameters only exist when CIFG (coupled input-forget gate) is off.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        // CellToInput peephole weights exist only with both peephole on and CIFG off.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    // Peephole connections from the cell state to the forget/output gates.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    // Optional projection layer; the projection bias itself is optional within it.
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    // Layer normalisation weights; the input-gate set again depends on CIFG.
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
1864
saoste012df12b32018-11-28 16:57:20 +00001865bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1866 const TensorInfo& input1,
1867 const TensorInfo& output,
1868 Optional<std::string&> reasonIfUnsupported) const
1869{
Sadik Armagan2999a022019-04-09 14:20:12 +01001870 bool supported = true;
1871
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001872 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001873 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001874 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001875 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001876 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001877 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001878 DataType::QSymmS16,
1879 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001880 };
1881
1882 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1883 "Reference maximum: input 0 is not a supported type.");
1884
1885 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1886 "Reference maximum: input 1 is not a supported type.");
1887
1888 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1889 "Reference maximum: output is not a supported type.");
1890
1891 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1892 "Reference maximum: input 0 and Input 1 types are mismatched");
1893
1894 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1895 "Reference maximum: input and output types are mismatched");
1896
1897 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1898 "Reference maximum: shapes are not suitable for implicit broadcast.");
1899
1900 return supported;
saoste012df12b32018-11-28 16:57:20 +00001901}
1902
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001903bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1904 const TensorInfo& output,
1905 const MeanDescriptor& descriptor,
1906 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001907{
James Conroy4d1ff582019-06-10 17:06:39 +01001908 bool supported = true;
1909 std::string meanLayerStr = "Mean";
1910 std::string outputTensorStr = "output";
1911
Sadik Armagan303980c2020-04-17 12:45:14 +01001912 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001913 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001914 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001915 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001916 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001917 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001918 DataType::QAsymmU8,
1919 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001920 };
1921
1922 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1923 "Reference Mean: input type not supported.");
1924
1925 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1926 "Reference Mean: input and output types are mismatched");
1927
1928 if (descriptor.m_KeepDims)
1929 {
1930 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1931 reasonIfUnsupported,
1932 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1933 output.GetNumDimensions(),
1934 meanLayerStr, outputTensorStr).data());
1935 }
1936 else if (descriptor.m_Axis.empty())
1937 {
1938 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1939 reasonIfUnsupported,
1940 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1941 meanLayerStr, outputTensorStr).data());
1942 }
1943 else
1944 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001945 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001946
1947 if (outputDim > 0)
1948 {
1949 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1950 reasonIfUnsupported,
1951 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1952 meanLayerStr, outputTensorStr).data());
1953 }
1954 else
1955 {
1956 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1957 reasonIfUnsupported,
1958 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1959 meanLayerStr, outputTensorStr).data());
1960 }
1961 }
1962
1963 return supported;
narpra0132b90462018-09-13 11:07:48 +01001964}
1965
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001966bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1967 const TensorInfo &output,
1968 Optional<std::string &> reasonIfUnsupported) const
1969{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001970 bool supported = true;
1971
Sadik Armagan303980c2020-04-17 12:45:14 +01001972 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001973 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001974 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001975 DataType::Float32,
1976 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001977 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001978 DataType::QAsymmU8,
1979 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001980 DataType::Boolean
1981 };
1982
1983 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1984 "Reference MemCopy: input type not supported");
1985
1986 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1987 "Reference MemCopy: output type not supported");
1988
1989 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1990 "Reference MemCopy: input and output types are mismatched");
1991
1992 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001993}
1994
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001995bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1996 const TensorInfo& input1,
1997 const TensorInfo& output,
1998 Optional<std::string&> reasonIfUnsupported) const
1999{
Sadik Armagan2999a022019-04-09 14:20:12 +01002000 bool supported = true;
2001
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002002 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002003 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002004 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002005 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002006 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002007 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002008 DataType::QSymmS16,
2009 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002010 };
2011
2012 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2013 "Reference minimum: input 0 is not a supported type.");
2014
2015 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2016 "Reference minimum: input 1 is not a supported type.");
2017
2018 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2019 "Reference minimum: output is not a supported type.");
2020
2021 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2022 "Reference minimum: input 0 and Input 1 types are mismatched");
2023
2024 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2025 "Reference minimum: input and output types are mismatched");
2026
2027 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2028 "Reference minimum: shapes are not suitable for implicit broadcast.");
2029
2030 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002031}
2032
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002033bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2034 const TensorInfo& input1,
2035 const TensorInfo& output,
2036 Optional<std::string&> reasonIfUnsupported) const
2037{
Sadik Armagan2999a022019-04-09 14:20:12 +01002038 bool supported = true;
2039
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002040 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002041 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002042 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002043 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002044 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002045 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002046 DataType::QSymmS16,
2047 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002048 };
2049
2050 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2051 "Reference multiplication: input 0 is not a supported type.");
2052
2053 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2054 "Reference multiplication: input 1 is not a supported type.");
2055
2056 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2057 "Reference multiplication: output is not a supported type.");
2058
2059 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2060 "Reference multiplication: input 0 and Input 1 types are mismatched");
2061
2062 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2063 "Reference multiplication: input and output types are mismatched");
2064
2065 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2066 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2067
2068 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002069}
2070
2071bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2072 const TensorInfo& output,
2073 const NormalizationDescriptor& descriptor,
2074 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002075{
Jan Eilers8eb25602020-03-09 12:13:48 +00002076 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002077
2078 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002079 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002080 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002081 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002082 DataType::Float16,
2083 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002084 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002085 DataType::QAsymmU8,
2086 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002087 };
2088
2089 bool supported = true;
2090
2091 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2092 "Reference normalization: input type not supported.");
2093
2094 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2095 "Reference normalization: output type not supported.");
2096
2097 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2098 "Reference normalization: input and output shapes have different "
2099 "num total elements.");
2100
2101 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002102}
2103
Derek Lamberti901ea112019-12-10 22:07:09 +00002104bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2105 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002106{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002107 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002108}
2109
2110bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2111 const TensorInfo& output,
2112 const PadDescriptor& descriptor,
2113 Optional<std::string&> reasonIfUnsupported) const
2114{
Jan Eilers8eb25602020-03-09 12:13:48 +00002115 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002116 bool supported = true;
2117
2118 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002119 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002120 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002121 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002122 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002123 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002124 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002125 DataType::QAsymmU8,
2126 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002127 };
2128
2129 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2130 "Reference pad: input is not a supported type.");
2131
2132 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2133 "Reference pad: output is not a supported type.");
2134
2135 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2136 "Reference pad: input and output types are mismatched.");
2137
2138 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002139}
2140
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002141bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2142 const TensorInfo& output,
2143 const PermuteDescriptor& descriptor,
2144 Optional<std::string&> reasonIfUnsupported) const
2145{
Jan Eilers8eb25602020-03-09 12:13:48 +00002146 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002147 bool supported = true;
2148
2149 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002150 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002151 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002152 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002153 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002154 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002155 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002156 DataType::QAsymmU8,
2157 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002158 };
2159
2160 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2161 "Reference permute: input is not a supported type.");
2162
2163 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2164 "Reference permute: output is not a supported type.");
2165
2166 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2167 "Reference permute: input and output types are mismatched.");
2168
2169 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002170}
2171
2172bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2173 const TensorInfo& output,
2174 const Pooling2dDescriptor& descriptor,
2175 Optional<std::string&> reasonIfUnsupported) const
2176{
Jan Eilers8eb25602020-03-09 12:13:48 +00002177 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002178 bool supported = true;
2179
2180 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002181 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002182 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002183 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01002184 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002185 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002186 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002187 DataType::QAsymmU8,
2188 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002189 };
2190
2191 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2192 "Reference poolind2d: input is not a supported type.");
2193
2194 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2195 "Reference poolind2d: output is not a supported type.");
2196
2197 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2198 "Reference poolind2d: input and output types are mismatched.");
2199
2200 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002201}
2202
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002203bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2204 const TensorInfo& output,
2205 const Pooling3dDescriptor& descriptor,
2206 Optional<std::string&> reasonIfUnsupported) const
2207{
2208 IgnoreUnused(descriptor);
2209 bool supported = true;
2210
2211 // Define supported output and inputs types.
2212 std::array<DataType,6> supportedTypes =
2213 {
2214 DataType::BFloat16,
2215 DataType::Float32,
2216 DataType::Float16,
2217 DataType::QAsymmS8,
2218 DataType::QAsymmU8,
2219 DataType::QSymmS16
2220 };
2221
2222 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2223 "Reference poolind3d: input is not a supported type.");
2224
2225 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2226 "Reference poolind3d: output is not a supported type.");
2227
2228 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2229 "Reference poolind3d: input and output types are mismatched.");
2230
2231 return supported;
2232}
2233
2234
James Conroy4f1f8992020-04-29 20:01:10 +01002235bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2236 const TensorInfo& previousOutputIn,
2237 const TensorInfo& previousCellStateIn,
2238 const TensorInfo& outputStateOut,
2239 const TensorInfo& cellStateOut,
2240 const TensorInfo& output,
2241 const QLstmDescriptor& descriptor,
2242 const LstmInputParamsInfo& paramsInfo,
2243 Optional<std::string&> reasonIfUnsupported) const
2244{
2245 IgnoreUnused(input);
2246 IgnoreUnused(previousOutputIn);
2247 IgnoreUnused(previousCellStateIn);
2248 IgnoreUnused(outputStateOut);
2249 IgnoreUnused(cellStateOut);
2250 IgnoreUnused(output);
2251 IgnoreUnused(descriptor);
2252 IgnoreUnused(paramsInfo);
2253
2254 IgnoreUnused(reasonIfUnsupported);
2255
2256 return true;
2257}
2258
Derek Lamberti5f400d62019-03-25 15:41:58 +00002259bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2260 const TensorInfo& output,
2261 Optional<std::string&> reasonIfUnsupported) const
2262{
2263 bool supported = true;
2264
Finn Williamsfd271062019-12-04 14:27:27 +00002265 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002266 std::array<DataType,7> supportedInputTypes = {
2267 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00002268 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002269 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002270 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002271 DataType::QAsymmU8,
2272 DataType::QSymmS8,
2273 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002274 };
2275
2276 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2277 "Reference quantize: input type not supported.");
2278
2279 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002280 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002281 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002282 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002283 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002284 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002285 };
2286 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2287 "Reference quantize: output type not supported.");
2288
2289 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2290 "Reference quantize: input and output shapes have different num total elements.");
2291
2292 return supported;
2293}
2294
Finn Williams2605b232020-06-10 15:53:46 +01002295bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2296 const TensorInfo& output,
2297 Optional<std::string&> reasonIfUnsupported) const
2298{
2299 IgnoreUnused(input);
2300 // Define supported output types.
2301 std::array<DataType,1> supportedOutputTypes =
2302 {
2303 DataType::Signed32,
2304 };
2305
2306 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2307 "Reference rank: input type not supported.");
2308}
2309
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002310bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2311 const TensorInfo& output,
2312 const ReduceDescriptor& descriptor,
2313 Optional<std::string&> reasonIfUnsupported) const
2314{
2315 IgnoreUnused(descriptor);
2316 bool supported = true;
2317 std::array<DataType,7> supportedTypes =
2318 {
2319 DataType::BFloat16,
2320 DataType::Float32,
2321 DataType::Float16,
2322 DataType::QAsymmS8,
2323 DataType::QAsymmU8,
2324 DataType::QSymmS16,
2325 DataType::Signed32
2326 };
2327
2328 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2329 "Reference Reduce: input type not supported");
2330
2331 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2332 "Reference Reduce: output type not supported");
2333
2334 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2335 "Reference Reduce: input and output types not matching");
2336
2337 return supported;
2338}
2339
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002340bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002341 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002342 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002343 Optional<std::string&> reasonIfUnsupported) const
2344{
Jan Eilers8eb25602020-03-09 12:13:48 +00002345 IgnoreUnused(output);
2346 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002347 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002348 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002349 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002350 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002351 DataType::Float32,
2352 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002353 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002354 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002355 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002356 DataType::QSymmS16,
2357 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002358 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002359
Nina Drozd2f2778f2019-05-27 10:37:05 +01002360 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2361 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002362}
2363
Teresa Charlin970f43b2019-07-01 13:51:07 +01002364bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2365 const TensorInfo& output,
2366 const ResizeDescriptor& descriptor,
2367 Optional<std::string&> reasonIfUnsupported) const
2368{
Jan Eilers8eb25602020-03-09 12:13:48 +00002369 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002370 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002371 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002372 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002373 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002374 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002375 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002376 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002377 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002378 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002379 };
2380
2381 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2382 "Reference Resize: input type not supported");
2383
2384 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2385 "Reference Resize: output type not supported");
2386
2387 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2388 "Reference Resize: input and output types not matching");
2389
2390 return supported;
2391}
2392
Keith Davis3ae3f972021-05-21 16:33:48 +01002393bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2394 const TensorInfo& output,
2395 Optional<std::string&> reasonIfUnsupported) const
2396{
2397 IgnoreUnused(input);
2398 bool supported = true;
2399
2400 std::array<DataType, 1> supportedTypes =
2401 {
2402 DataType::Signed32
2403 };
2404
2405 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2406 "Reference Shape: output type not supported");
2407
2408 return supported;
2409}
2410
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002411bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2412 const TensorInfo& output,
2413 const SliceDescriptor& descriptor,
2414 Optional<std::string&> reasonIfUnsupported) const
2415{
Jan Eilers8eb25602020-03-09 12:13:48 +00002416 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002417 bool supported = true;
2418
Sadik Armagan303980c2020-04-17 12:45:14 +01002419 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002420 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002421 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002422 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002423 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002424 DataType::QAsymmU8,
2425 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002426 };
2427
2428 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2429 "Reference Slice: input type not supported");
2430
2431 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2432 "Reference Slice: output type not supported");
2433
2434 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2435 "Reference Slice: input and output types are mismatched");
2436
2437 return supported;
2438}
2439
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002440bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2441 const TensorInfo& output,
2442 const SoftmaxDescriptor& descriptor,
2443 Optional<std::string&> reasonIfUnsupported) const
2444{
Jan Eilers8eb25602020-03-09 12:13:48 +00002445 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002446 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002447 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002448 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002449 DataType::BFloat16,
2450 DataType::Float32,
2451 DataType::Float16,
2452 DataType::QSymmS8,
2453 DataType::QAsymmS8,
2454 DataType::QAsymmU8,
2455 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002456 };
2457
2458 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002459 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002460
2461 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002462 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002463
2464 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002465 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002466
2467 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002468}
2469
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002470bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2471 const TensorInfo& output,
2472 const SpaceToBatchNdDescriptor& descriptor,
2473 Optional<std::string&> reasonIfUnsupported) const
2474{
Jan Eilers8eb25602020-03-09 12:13:48 +00002475 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002476 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002477 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002478 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002479 DataType::BFloat16,
2480 DataType::Float32,
2481 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002482 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002483 DataType::QAsymmU8,
2484 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002485 };
2486
2487 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2488 "Reference SpaceToBatchNd: input type not supported");
2489
2490 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2491 "Reference SpaceToBatchNd: output type not supported");
2492
2493 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2494 "Reference SpaceToBatchNd: input and output types are mismatched");
2495
2496 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002497}
2498
Keith Davisa57eccb2019-06-14 17:33:22 +01002499bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002500 const TensorInfo& output,
2501 const SpaceToDepthDescriptor& descriptor,
2502 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002503{
2504
Jan Eilers8eb25602020-03-09 12:13:48 +00002505 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002506 bool supported = true;
2507
Sadik Armagan303980c2020-04-17 12:45:14 +01002508 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002509 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002510 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002511 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002512 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002513 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002514 DataType::QAsymmU8,
2515 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002516 };
2517
2518 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2519 "Reference SpaceToDepth: input type not supported");
2520
2521 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2522 "Reference SpaceToDepth: output type not supported");
2523
2524 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2525 "Reference SpaceToDepth: input and output types are mismatched");
2526
2527 return supported;
2528}
2529
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002530bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002531 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2532 const ViewsDescriptor& descriptor,
2533 Optional<std::string&> reasonIfUnsupported) const
2534{
Jan Eilers8eb25602020-03-09 12:13:48 +00002535 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002536 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002537 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002538 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002539 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002540 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002541 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002542 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002543 DataType::QAsymmU8,
2544 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002545 };
2546
2547 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2548 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002549 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002550 {
2551 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2552 "Reference splitter: input type not supported");
2553
2554 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2555 "Reference splitter: input and output types mismatched.");
2556 }
2557
2558 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002559}
2560
Matthew Jackson81e601c2019-07-11 12:07:09 +01002561bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2562 const TensorInfo& output,
2563 const StackDescriptor& descriptor,
2564 Optional<std::string&> reasonIfUnsupported) const
2565{
Jan Eilers8eb25602020-03-09 12:13:48 +00002566 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002567
2568 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002569 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002570 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002571 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002572 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002573 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002574 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002575 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002576 DataType::QSymmS16,
2577 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002578 };
2579
2580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2581 "Reference stack: output type not supported");
2582 for (const TensorInfo* input : inputs)
2583 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002584 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002585 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2586 "Reference stack: input type not supported");
2587
2588 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2589 "Reference stack: input and output types mismatched.");
2590 }
2591
2592 return supported;
2593}
2594
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002595bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2596 const TensorInfo& output,
2597 const StridedSliceDescriptor& descriptor,
2598 Optional<std::string&> reasonIfUnsupported) const
2599{
Jan Eilers8eb25602020-03-09 12:13:48 +00002600 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002601 bool supported = true;
2602
Sadik Armagan303980c2020-04-17 12:45:14 +01002603 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002604 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002605 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002606 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002607 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002608 DataType::QAsymmU8,
2609 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002610 };
2611
2612 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2613 "Reference StridedSlice: input type not supported");
2614
2615 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2616 "Reference StridedSlice: output type not supported");
2617
2618 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2619 "Reference StridedSlice: input and output types are mismatched");
2620
2621 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002622}
2623
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002624bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2625 const TensorInfo& input1,
2626 const TensorInfo& output,
2627 Optional<std::string&> reasonIfUnsupported) const
2628{
Sadik Armagan2999a022019-04-09 14:20:12 +01002629 bool supported = true;
2630
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002631 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002632 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002633 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002634 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002635 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002636 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002637 DataType::QSymmS16,
2638 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002639 };
2640
2641 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2642 "Reference subtraction: input 0 is not a supported type.");
2643
2644 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2645 "Reference subtraction: input 1 is not a supported type.");
2646
2647 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2648 "Reference subtraction: output is not a supported type.");
2649
2650 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2651 "Reference subtraction: input 0 and Input 1 types are mismatched");
2652
2653 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2654 "Reference subtraction: input and output types are mismatched");
2655
2656 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2657 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2658
2659 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002660}
2661
Matteo Martincighab9e5252019-06-13 17:27:46 +01002662bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2663 const TensorInfo& alpha,
2664 const TensorInfo& output,
2665 Optional<std::string&> reasonIfUnsupported) const
2666{
2667 bool supported = true;
2668
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002669 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002670 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002671 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002672 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002673 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002674 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002675 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002676 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002677 };
2678
2679 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2680 "PReLU: input is not a supported type.");
2681
2682 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2683 "PReLU: alpha is not a supported type.");
2684
2685 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2686 "PReLU: output is not a supported type.");
2687
2688 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2689 "PReLU: input, alpha and output types are mismatched");
2690
2691 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2692 "PReLU: shapes are not suitable for implicit broadcast");
2693
2694 return supported;
2695}
2696
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002697bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2698 const TensorInfo& output,
2699 const TransposeConvolution2dDescriptor& descriptor,
2700 const TensorInfo& weights,
2701 const Optional<TensorInfo>& biases,
2702 Optional<std::string&> reasonIfUnsupported) const
2703{
Jan Eilers8eb25602020-03-09 12:13:48 +00002704 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002705 bool supported = true;
2706
Sadik Armagan303980c2020-04-17 12:45:14 +01002707 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002708 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002709 DataType::BFloat16,
2710 DataType::Float32,
2711 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002712 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002713 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002714 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002715 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002716 };
2717
2718 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2719 "Reference TransposeConvolution2d: input is not a supported type.");
2720
2721 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2722 "Reference TransposeConvolution2d: output is not a supported type.");
2723
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002724 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2725 "Reference TransposeConvolution2d: input and output types mismatched.");
2726
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002727
2728 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002729 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002730 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002731 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002732 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002733 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002734 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002735 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002736 };
2737
2738 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2739 "Reference TransposeConvolution2d: weights type not supported for "
2740 "quantized input.");
2741 }
2742 else
2743 {
2744 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2745 "Reference TransposeConvolution2d: weights is not a supported type.");
2746
2747 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2748 "Reference TransposeConvolution2d: input and weights types mismatched.");
2749 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002750
2751 if (biases.has_value())
2752 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002753 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002754 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002755 DataType::BFloat16,
2756 DataType::Float32,
2757 DataType::Float16,
2758 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002759 };
2760 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2761 "Reference TransposeConvolution2d: biases is not a supported type.");
2762 }
2763
2764 return supported;
2765}
2766
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002767bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2768 const TensorInfo& output,
2769 const TransposeDescriptor& descriptor,
2770 Optional<std::string&> reasonIfUnsupported) const
2771{
Jan Eilers8eb25602020-03-09 12:13:48 +00002772 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002773 bool supported = true;
2774
2775 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002776 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002777 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002778 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002779 DataType::Float32,
2780 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002781 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002782 DataType::QAsymmU8,
2783 DataType::QSymmS16
2784 };
2785
2786 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2787 "Reference transpose: input is not a supported type.");
2788
2789 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2790 "Reference transpose: output is not a supported type.");
2791
2792 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2793 "Reference transpose: input and output types are mismatched.");
2794
2795 return supported;
2796}
2797
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002798bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2799 const TensorInfo& input,
2800 const TensorInfo& outputStateIn,
2801 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002802 const TensorInfo& outputStateOut,
2803 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002804 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002805 const UnidirectionalSequenceLstmDescriptor& descriptor,
2806 const LstmInputParamsInfo& paramsInfo,
2807 Optional<std::string&> reasonIfUnsupported) const
2808{
2809 IgnoreUnused(descriptor);
2810 IgnoreUnused(paramsInfo);
2811 IgnoreUnused(outputStateIn);
2812 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002813 IgnoreUnused(outputStateOut);
2814 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002815 bool supported = true;
2816
Mike Kelly12994962022-04-21 11:57:09 +01002817 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002818 {
Mike Kelly12994962022-04-21 11:57:09 +01002819 DataType::Float32,
2820 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002821 };
2822
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002823 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002824 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002825 DataType::Float32,
2826 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002827 };
2828
Mike Kelly12994962022-04-21 11:57:09 +01002829 std::array<DataType, 3> supportedBiasTypes =
2830 {
2831 DataType::Float32,
2832 DataType::QAsymmS8,
2833 DataType::Signed32
2834 };
2835
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002836 // check inputs and outputs
2837 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2838 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002839 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2840 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002841
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002842 // check layer parameters
2843 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2844 reasonIfUnsupported,
2845 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2846 "is not a supported type.");
2847 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2848 reasonIfUnsupported,
2849 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2850 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2851 reasonIfUnsupported,
2852 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2853 "is not a supported type.");
2854 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2855 reasonIfUnsupported,
2856 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2857 "is not a supported type.");
2858 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2859 reasonIfUnsupported,
2860 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2861 "is not a supported type.");
2862 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2863 reasonIfUnsupported,
2864 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2865 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002866
2867 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2868 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2869 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2870 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2871 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2872 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002873 if (!descriptor.m_CifgEnabled)
2874 {
2875 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2876 reasonIfUnsupported,
2877 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2878 "is not a supported type.");
2879 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2880 reasonIfUnsupported,
2881 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2882 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002883 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2884 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002885 if (descriptor.m_PeepholeEnabled)
2886 {
2887 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2888 reasonIfUnsupported,
2889 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2890 "is not a supported type.");
2891 }
2892 }
2893 if (descriptor.m_PeepholeEnabled)
2894 {
2895 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2896 reasonIfUnsupported,
2897 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2898 "is not a supported type.");
2899 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2900 reasonIfUnsupported,
2901 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2902 "is not a supported type.");
2903 }
2904 if (descriptor.m_ProjectionEnabled)
2905 {
2906 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2907 reasonIfUnsupported,
2908 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2909 "is not a supported type.");
2910 if (paramsInfo.m_ProjectionBias != nullptr)
2911 {
2912 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2913 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2914 "are mismatched");
2915 }
2916 }
2917 if (descriptor.m_LayerNormEnabled)
2918 {
2919 if (!descriptor.m_CifgEnabled)
2920 {
2921 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2922 reasonIfUnsupported,
2923 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2924 "is not a supported type.");
2925 }
2926 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2927 reasonIfUnsupported,
2928 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2929 "is not a supported type.");
2930 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2931 reasonIfUnsupported,
2932 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2933 "is not a supported type.");
2934 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2935 reasonIfUnsupported,
2936 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2937 "is not a supported type.");
2938 }
2939
2940 return supported;
2941}
2942
arovir011c7c81b2018-10-08 11:34:28 +01002943} // namespace armnn