blob: 0b1b9c7824649992326406887fcc8a0427e13f85 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
namespace
{

// Builds the standard "wrong number of dimensions" diagnostic used by the
// reference-backend support checks.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name inserted into the message.
// @param tensorName Name of the offending tensor.
// @return The formatted error message.
//
// The string parameters are taken by const reference: the function only reads
// them, and const& additionally allows temporaries to be passed (backward
// compatible with all existing callers).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
103 case LayerType::Comparison:
104 return IsComparisonSupported(infos[0],
105 infos[1],
106 infos[2],
107 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
108 reasonIfUnsupported);
109 case LayerType::Concat:
110 {
111 std::vector<const TensorInfo*> inputInfos;
112 for (uint32_t i = 0; i < (infos.size() - 1); i++)
113 {
114 inputInfos.push_back(&infos[i]);
115 }
116 return IsConcatSupported(inputInfos,
117 infos[infos.size() - 1],
118 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
119 reasonIfUnsupported);
120 }
121 case LayerType::Constant:
122 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000123 case LayerType::ConvertFp16ToFp32:
124 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000125 case LayerType::ConvertFp32ToFp16:
126 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
127 case LayerType::Convolution2d:
128 {
129 if (infos.size() != 4)
130 {
131 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
132 "TensorInfos should be of format: {input, output, weights, biases}.");
133 }
134
135 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
136 if (infos[3] == TensorInfo())
137 {
138 return IsConvolution2dSupported(infos[0],
139 infos[1],
140 desc,
141 infos[2],
142 EmptyOptional(),
143 reasonIfUnsupported);
144 }
145 else
146 {
147 return IsConvolution2dSupported(infos[0],
148 infos[1],
149 desc,
150 infos[2],
151 infos[3],
152 reasonIfUnsupported);
153 }
154 }
155 case LayerType::DepthToSpace:
156 return IsDepthToSpaceSupported(infos[0],
157 infos[1],
158 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
159 reasonIfUnsupported);
160 case LayerType::DepthwiseConvolution2d:
161 {
162 if (infos.size() != 4)
163 {
164 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
165 "TensorInfos should be of format: {input, output, weights, biases}.");
166 }
167
168 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
169 if (infos[3] == TensorInfo())
170 {
171 return IsDepthwiseConvolutionSupported(infos[0],
172 infos[1],
173 desc,
174 infos[2],
175 EmptyOptional(),
176 reasonIfUnsupported);
177 }
178 else
179 {
180 return IsDepthwiseConvolutionSupported(infos[0],
181 infos[1],
182 desc,
183 infos[2],
184 infos[3],
185 reasonIfUnsupported);
186 }
187 }
188 case LayerType::Dequantize:
189 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
190 case LayerType::Division:
191 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000192 case LayerType::ElementwiseBinary:
193 {
194 std::array<DataType, 7> supportedTypes =
195 {
196 DataType::Float32,
197 DataType::Float16,
198 DataType::QAsymmS8,
199 DataType::QAsymmU8,
200 DataType::QSymmS16,
201 DataType::Signed32
202 };
203
204 bool supported = true;
205 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
206 "Reference elementwise unary: input type not supported");
207
208 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
209 "Reference elementwise unary: input type not supported");
210
211 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
212 "Reference elementwise unary: output type not supported");
213
214 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
215 "Reference elementwise unary: input types not matching");
216
217 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
218 "Reference elementwise unary: input and output types not matching");
219
220 return supported;
221 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000222 case LayerType::ElementwiseUnary:
223 return IsElementwiseUnarySupported(infos[0],
224 infos[1],
225 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
226 reasonIfUnsupported);
227 case LayerType::Fill:
228 return IsFillSupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Floor:
233 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
234 case LayerType::FullyConnected:
235 return IsFullyConnectedSupported(infos[0],
236 infos[1],
237 infos[2],
238 infos[3],
239 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
240 reasonIfUnsupported);
241 case LayerType::Gather:
242 return IsGatherSupported(infos[0],
243 infos[1],
244 infos[2],
245 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
246 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100247 case LayerType::GatherNd:
248 return IsGatherNdSupported(infos[0],
249 infos[1],
250 infos[2],
251 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000252 case LayerType::Input:
253 return IsInputSupported(infos[0], reasonIfUnsupported);
254 case LayerType::InstanceNormalization:
255 return IsInstanceNormalizationSupported(infos[0],
256 infos[1],
257 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
258 (&descriptor)),
259 reasonIfUnsupported);
260 case LayerType::L2Normalization:
261 return IsL2NormalizationSupported(infos[0],
262 infos[1],
263 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::LogicalBinary:
266 return IsLogicalBinarySupported(infos[0],
267 infos[1],
268 infos[2],
269 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
270 reasonIfUnsupported);
271 case LayerType::LogSoftmax:
272 return IsLogSoftmaxSupported(infos[0],
273 infos[1],
274 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::Lstm:
277 return IsLstmSupported(infos[0],
278 infos[1],
279 infos[2],
280 infos[3],
281 infos[4],
282 infos[5],
283 infos[6],
284 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
285 lstmParamsInfo.value(),
286 reasonIfUnsupported);
287 case LayerType::QLstm:
288 return IsQLstmSupported(infos[0],
289 infos[1],
290 infos[2],
291 infos[3],
292 infos[4],
293 infos[5],
294 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
295 lstmParamsInfo.value(),
296 reasonIfUnsupported);
297 case LayerType::Maximum:
298 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
299 case LayerType::Mean:
300 return IsMeanSupported(infos[0],
301 infos[1],
302 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
303 reasonIfUnsupported);
304 case LayerType::Minimum:
305 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
306 case LayerType::Multiplication:
307 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
308 case LayerType::Normalization:
309 return IsNormalizationSupported(infos[0],
310 infos[1],
311 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
312 reasonIfUnsupported);
313 case LayerType::Output:
314 return IsOutputSupported(infos[0], reasonIfUnsupported);
315 case LayerType::Pad:
316 return IsPadSupported(infos[0],
317 infos[1],
318 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
319 reasonIfUnsupported);
320 case LayerType::Permute:
321 return IsPermuteSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Pooling2d:
326 return IsPooling2dSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Prelu:
331 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
332 case LayerType::Quantize:
333 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334 case LayerType::Reshape:
335 return IsReshapeSupported(infos[0],
336 infos[1],
337 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
338 reasonIfUnsupported);
339 case LayerType::Resize:
340 return IsResizeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100344 case LayerType::ReverseV2:
345 return IsReverseV2Supported(infos[0],
346 infos[1],
Tracy Narinebb8d7592023-07-13 16:50:54 +0100347 infos[2],
Tianle Cheng988354d2023-06-28 13:20:47 +0100348 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000349 case LayerType::Reduce:
350 return IsReduceSupported(infos[0],
351 infos[1],
352 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
353 reasonIfUnsupported);
354 case LayerType::Slice:
355 return IsSliceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::Softmax:
360 return IsSoftmaxSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::SpaceToBatchNd:
365 return IsSpaceToBatchNdSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::SpaceToDepth:
370 return IsSpaceToDepthSupported(infos[0],
371 infos[1],
372 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
373 reasonIfUnsupported);
374 case LayerType::Splitter:
375 {
376 std::vector<TensorInfo> outputInfos;
377 for (uint32_t i = 1; i < infos.size(); i++)
378 {
379 outputInfos.push_back(infos[i]);
380 }
381 return IsSplitterSupported(infos[0],
382 {outputInfos.begin(), outputInfos.end()},
383 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
384 reasonIfUnsupported);
385 }
386 case LayerType::Stack:
387 {
388 std::vector<const TensorInfo*> inputInfos;
389 for (uint32_t i = 0; i < infos.size() - 1; i++)
390 {
391 inputInfos.push_back(&infos[i]);
392 }
393 return IsStackSupported(inputInfos,
394 infos[infos.size() - 1],
395 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
396 reasonIfUnsupported);
397 }
398 case LayerType::StridedSlice:
399 return IsStridedSliceSupported(infos[0],
400 infos[1],
401 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
402 reasonIfUnsupported);
403 case LayerType::Subtraction:
404 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Teresa Charlin79a06a52023-07-13 17:16:45 +0100405 case LayerType::Tile:
406 return IsTileSupported(infos[0],
407 infos[1],
408 *(PolymorphicDowncast<const TileDescriptor*>(&descriptor)),
409 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000410 case LayerType::Transpose:
411 return IsTransposeSupported(infos[0],
412 infos[1],
413 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
414 reasonIfUnsupported);
415 case LayerType::TransposeConvolution2d:
416 {
417 if (infos.size() != 4)
418 {
419 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
420 "TensorInfos should be of format: {input, output, weights, biases}.");
421 }
422
423 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
424 if (infos[3] == TensorInfo())
425 {
426 return IsTransposeConvolution2dSupported(infos[0],
427 infos[1],
428 desc,
429 infos[2],
430 EmptyOptional(),
431 reasonIfUnsupported);
432 }
433 else
434 {
435 return IsTransposeConvolution2dSupported(infos[0],
436 infos[1],
437 desc,
438 infos[2],
439 infos[3],
440 reasonIfUnsupported);
441 }
442 }
443 case LayerType::Cast:
444 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
445 case LayerType::ChannelShuffle:
446 return IsChannelShuffleSupported(infos[0],
447 infos[1],
448 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
449 reasonIfUnsupported);
450 case LayerType::Convolution3d:
451 {
452 if (infos.size() != 4)
453 {
454 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
455 "TensorInfos should be of format: {input, output, weights, biases}.");
456 }
457
458 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
459 if (infos[3] == TensorInfo())
460 {
461 return IsConvolution3dSupported(infos[0],
462 infos[1],
463 desc,
464 infos[2],
465 EmptyOptional(),
466 reasonIfUnsupported);
467 }
468 else
469 {
470 return IsConvolution3dSupported(infos[0],
471 infos[1],
472 desc,
473 infos[2],
474 infos[3],
475 reasonIfUnsupported);
476 }
477 }
478 case LayerType::Debug:
479 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
480 case LayerType::DetectionPostProcess:
481 return IsDetectionPostProcessSupported(infos[0],
482 infos[1],
483 infos[2],
484 infos[3],
485 infos[4],
486 infos[5],
487 infos[6],
488 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
489 (&descriptor)),
490 reasonIfUnsupported);
491 case LayerType::FakeQuantization:
492 return IsFakeQuantizationSupported(infos[0],
493 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
494 reasonIfUnsupported);
495 case LayerType::MemCopy:
496 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
497 case LayerType::Rank:
498 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
499 case LayerType::Shape:
500 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
501 case LayerType::UnidirectionalSequenceLstm:
502 {
503 if (infos.size() != 6)
504 {
505 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
506 "should be of format: {input, outputStateIn, cellStateIn, "
507 "hiddenStateOutputVal, cellStateOutputVal, output}");
508 }
509 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100510 return IsUnidirectionalSequenceLstmSupported(infos[0],
511 infos[1],
512 infos[2],
513 infos[3],
514 infos[4],
515 infos[5],
516 desc,
517 lstmParamsInfo.value(),
518 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000519 }
520 case LayerType::Pooling3d:
521 return IsPooling3dSupported(infos[0],
522 infos[1],
523 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
524 reasonIfUnsupported);
525 case LayerType::Map:
526 return true;
527 case LayerType::Unmap:
528 return true;
529 case LayerType::MemImport:
530 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
531 case LayerType::Merge:
532 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
533 case LayerType::QuantizedLstm:
534 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
535 infos[1],
536 infos[2],
537 infos[3],
538 infos[4],
539 quantizedLstmInputParamsInfo.value(),
540 reasonIfUnsupported);
541 default:
Teresa Charlin9145e382023-08-17 18:44:58 +0100542 // layers not supported in reference by default:
543 // precompiled, standin, switch, fused
Cathal Corbett34b429c2021-12-24 12:24:40 +0000544 return false;
545 }
546}
547
arovir011c7c81b2018-10-08 11:34:28 +0100548bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
549 const TensorInfo& output,
550 const ActivationDescriptor& descriptor,
551 Optional<std::string&> reasonIfUnsupported) const
552{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000553 bool supported = true;
554
555 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000556 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000557 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100558 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000559 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000560 DataType::QAsymmU8,
561 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000562 };
563
564 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
565 "Reference activation: input type not supported.");
566
567 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
568 "Reference activation: output type not supported.");
569
570 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
571 "Reference activation: input and output types mismatched.");
572
573 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
574 "Reference activation: input and output shapes are of different rank.");
575
576
577 struct ActivationFunctionSupported : public Rule
578 {
579 ActivationFunctionSupported(const ActivationDescriptor& desc)
580 {
581 switch(desc.m_Function)
582 {
583 case ActivationFunction::Abs:
584 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000585 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000586 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000587 case ActivationFunction::LeakyReLu:
588 case ActivationFunction::Linear:
589 case ActivationFunction::ReLu:
590 case ActivationFunction::Sigmoid:
591 case ActivationFunction::SoftReLu:
592 case ActivationFunction::Sqrt:
593 case ActivationFunction::Square:
594 case ActivationFunction::TanH:
595 {
596 m_Res = true;
597 break;
598 }
599 default:
600 {
601 m_Res = false;
602 break;
603 }
604 }
605 }
606 };
607
608 // Function is supported
609 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
610 "Reference activation: function not supported.");
611
612 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100613}
614
615bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
616 const TensorInfo& input1,
617 const TensorInfo& output,
618 Optional<std::string&> reasonIfUnsupported) const
619{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000620 bool supported = true;
621
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100622 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000623 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100624 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000625 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000626 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100627 DataType::QSymmS16,
628 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000629 };
630
631 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
632 "Reference addition: input 0 is not a supported type.");
633
634 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
635 "Reference addition: input 1 is not a supported type.");
636
637 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
638 "Reference addition: output is not a supported type.");
639
640 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
641 "Reference addition: input 0 and Input 1 types are mismatched");
642
643 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
644 "Reference addition: input and output types are mismatched");
645
646 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
647 "Reference addition: shapes are not suitable for implicit broadcast.");
648
649 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100650}
651
Nikhil Raj68c2c902019-09-19 11:21:11 +0100652bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
653 const armnn::ArgMinMaxDescriptor &descriptor,
654 armnn::Optional<std::string &> reasonIfUnsupported) const
655{
Jan Eilers8eb25602020-03-09 12:13:48 +0000656 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100657
Mike Kelly1f140f72021-04-06 12:25:55 +0100658 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100659 {
Teresa Charline300b362020-05-25 10:01:03 +0100660 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100661 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100662 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000663 DataType::QAsymmU8,
664 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100665 DataType::Signed32,
666 DataType::Signed64
667 };
668
669 std::array<DataType,2> supportedOutputTypes = {
670 DataType::Signed32,
671 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100672 };
673
674 bool supported = true;
675
Mike Kelly1f140f72021-04-06 12:25:55 +0100676 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100677 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100678 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100679 "Reference ArgMinMax: output type not supported");
680
681 return supported;
682}
683
Samuel Yap6b478092022-07-06 15:36:03 +0100684bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
685 const TensorInfo& inputY,
686 const TensorInfo& output,
687 const BatchMatMulDescriptor& descriptor,
688 Optional<std::string &> reasonIfUnsupported) const
689{
690 IgnoreUnused(descriptor);
691
692 std::array<DataType, 6> supportedTypes =
693 {
Samuel Yap6b478092022-07-06 15:36:03 +0100694 DataType::Float16,
695 DataType::Float32,
696 DataType::QAsymmS8,
697 DataType::QAsymmU8,
698 DataType::QSymmS16
699 };
700
701 bool supported = true;
702
703 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
704 "Reference batch matrix multiplication: input X is not a supported type");
705
706 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
707 "Reference batch matrix multiplication: input Y is not a supported type");
708
709 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
710 "Reference batch matrix multiplication: output is not a supported type");
711
712 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
713 "Reference batch matrix multiplication: input X and input Y types are mismatched");
714
715 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
716 "Reference batch matrix multiplication: inputs and output types are mismatched");
717
718 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
719 reasonIfUnsupported,
720 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
721
722 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
723 reasonIfUnsupported,
724 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
725
726 return supported;
727}
728
arovir011c7c81b2018-10-08 11:34:28 +0100729bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
730 const TensorInfo& output,
731 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100732 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100733 const TensorInfo& beta,
734 const TensorInfo& gamma,
735 const BatchNormalizationDescriptor& descriptor,
736 Optional<std::string&> reasonIfUnsupported) const
737{
Jan Eilers8eb25602020-03-09 12:13:48 +0000738 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100739
Sadik Armagan303980c2020-04-17 12:45:14 +0100740 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100741 {
742 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100743 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100744 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000745 DataType::QAsymmU8,
746 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100747 };
748
749 bool supported = true;
750
751 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
752 "Reference batch normalization: input is not a supported type.");
753
754 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
755 "Reference batch normalization: output is not a supported type.");
756
757 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
758 "Reference batch normalization: input and output types are mismatched");
759
760 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
761 "Reference batch normalization: mean is not a supported type.");
762
763 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
764 "Reference batch normalization: variance is not a supported type.");
765
766 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
767 "Reference batch normalization: beta is not a supported type.");
768
769 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
770 "Reference batch normalization: gamma is not a supported type.");
771
772 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100773}
774
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000775bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
776 const TensorInfo& output,
777 const BatchToSpaceNdDescriptor& descriptor,
778 Optional<std::string&> reasonIfUnsupported) const
779{
Jan Eilers8eb25602020-03-09 12:13:48 +0000780 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100781
782 bool supported = true;
783
784 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
785 std::string inputTensorStr = "input";
786 std::string outputTensorStr = "output";
787
788 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100789 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100790 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000791 DataType::Float32,
792 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100793 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000794 DataType::QAsymmU8,
795 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100796 };
797
798 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
799 "Reference BatchToSpaceNd: input type not supported.");
800
801 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
802 "Reference BatchToSpaceNd: output type not supported.");
803
804 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
805 "Reference BatchToSpaceNd: input and output types mismatched.");
806
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100807 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000808}
809
mathad01b392e982021-04-07 12:07:30 +0100810bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
811 const TensorInfo& output,
812 Optional<std::string&> reasonIfUnsupported) const
813{
814 std::array<DataType, 9> supportedInputTypes =
815 {
mathad01b392e982021-04-07 12:07:30 +0100816 DataType::Float32,
817 DataType::Float16,
818 DataType::QSymmS8,
819 DataType::QAsymmS8,
820 DataType::QAsymmU8,
821 DataType::QSymmS16,
822 DataType::Signed32
823 };
824
825 bool supported = true;
826 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
827 "Reference cast: input is not a supported type");
828
829
830 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
831 "Reference cast: output is not a supported type");
832
833 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
834 "Reference cast: input and output shapes have different number of total elements");
835
836 return supported;
837}
838
Simon Obute51f67772021-09-03 15:50:13 +0100839bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
840 const TensorInfo& output,
841 const ChannelShuffleDescriptor& descriptor,
842 Optional<std::string&> reasonIfUnsupported) const
843{
844 IgnoreUnused(descriptor);
845 bool supported = true;
846
847 // Define supported output and inputs types.
848 std::array<DataType, 7> supportedTypes =
849 {
Simon Obute51f67772021-09-03 15:50:13 +0100850 DataType::Float32,
851 DataType::Float16,
852 DataType::QAsymmS8,
853 DataType::QAsymmU8,
854 DataType::QSymmS8,
855 DataType::QSymmS16
856 };
857
858 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
859 "Reference ChannelShuffle: input is not a supported type.");
860
861 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
862 "Reference ChannelShuffle: output is not a supported type.");
863
864 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
865 "Reference ChannelShuffle: input and output types are mismatched.");
866
867 return supported;
868}
869
870
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100871bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
872 const TensorInfo& input1,
873 const TensorInfo& output,
874 const ComparisonDescriptor& descriptor,
875 Optional<std::string&> reasonIfUnsupported) const
876{
Jan Eilers8eb25602020-03-09 12:13:48 +0000877 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100878 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100879 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000880 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100881 DataType::Float32,
882 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100883 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000884 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000885 DataType::QSymmS16,
886 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100887 };
888
889 bool supported = true;
890 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
891 "Reference comparison: input 0 is not a supported type");
892
893 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
894 "Reference comparison: input 0 and Input 1 types are mismatched");
895
896 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
897 "Reference comparison: output is not of type Boolean");
898
899 return supported;
900}
901
Jim Flynn906f9462019-05-10 13:55:21 +0100902bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
903 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000904 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100905 Optional<std::string&> reasonIfUnsupported) const
906{
Jan Eilers8eb25602020-03-09 12:13:48 +0000907 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100908
909 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000910 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100911 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000912 DataType::Float32,
913 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000914 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100915 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000916 DataType::QSymmS16,
917 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100918 };
919
920 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
921 "Reference concatenation: output type not supported");
922 for (const TensorInfo* input : inputs)
923 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100924 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100925 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
926 "Reference concatenation: input type not supported");
927
928 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
929 "Reference concatenation: input and output types mismatched.");
930 }
931
932 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100933}
934
arovir011c7c81b2018-10-08 11:34:28 +0100935bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
936 Optional<std::string&> reasonIfUnsupported) const
937{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100938 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100939 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100940 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100941 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000942 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100943 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000944 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100945 DataType::QSymmS16,
946 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100947 };
948
949 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
950 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100951}
952
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Conversion is only legal in one direction: input must be Float16 and output
    // must be Float32. Each IsSupportedForDataTypeGeneric call dispatches on the
    // data type; judging by the helper used at the top of this file, the slots
    // appear to be (float16, float32, uint8, int32, boolean) — confirm against
    // LayerSupportCommon.hpp before changing the argument order.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,           // Float16 input: accepted
                                          &FalseInputFuncF32<>,  // Float32 input: rejected with message
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output: rejected with message
                                          &TrueFunc<>,           // Float32 output: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
972
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror of IsConvertFp16ToFp32Supported: input must be Float32 and output
    // must be Float16. The function-pointer slots dispatch on the data type;
    // judging by the helper used at the top of this file they appear to be
    // (float16, float32, uint8, int32, boolean) — confirm against
    // LayerSupportCommon.hpp before changing the argument order.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,  // Float16 input: rejected with message
                                          &TrueFunc<>,           // Float32 input: accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output: accepted
                                          &FalseOutputFuncF32<>, // Float32 output: rejected with message
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
992
993bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
994 const TensorInfo& output,
995 const Convolution2dDescriptor& descriptor,
996 const TensorInfo& weights,
997 const Optional<TensorInfo>& biases,
998 Optional<std::string&> reasonIfUnsupported) const
999{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001000 bool supported = true;
1001
1002 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001003 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001004 {
1005 DataType::Float32,
1006 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001007 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001008 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001009 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001010 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001011 };
1012
1013 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001014 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001015
1016 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001017 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001018
Ryan OShea31441592022-11-07 16:20:48 +00001019 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1020 "Reference Convolution2d: input and output types mismatched.");
1021
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001022
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001023 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001024 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001025 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001026 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001027 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001028 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001029 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001030 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001031 };
1032
1033 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001034 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001035 }
1036 else
1037 {
1038 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001039 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001040
1041 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001042 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001043 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001044
1045 if (biases.has_value())
1046 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001047 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001048 {
1049 DataType::Float32,
1050 DataType::Float16,
1051 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001052 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001053
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001054 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001055 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001056 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001057 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001058
1059 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001060}
1061
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001062bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1063 const TensorInfo& output,
1064 const Convolution3dDescriptor& descriptor,
1065 const TensorInfo& weights,
1066 const Optional<TensorInfo>& biases,
1067 Optional<std::string&> reasonIfUnsupported) const
1068{
1069 bool supported = true;
1070
1071 // Define supported types.
1072 std::array<DataType,7> supportedTypes =
1073 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001074 DataType::Float32,
1075 DataType::Float16,
1076 DataType::QAsymmS8,
1077 DataType::QAsymmU8,
1078 DataType::QSymmS8,
1079 DataType::QSymmS16
1080 };
1081
1082 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1083 "Reference Convolution3d: input is not a supported type.");
1084
1085 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1086 "Reference Convolution3d: output is not a supported type.");
1087
1088 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1089 "Reference Convolution3d: input and output types mismatched.");
1090
1091 const DataType inputType = input.GetDataType();
1092 if (IsQuantized8BitType(inputType))
1093 {
1094 std::array<DataType, 3> supportedWeightTypes =
1095 {
1096 DataType::QAsymmS8,
1097 DataType::QAsymmU8,
1098 DataType::QSymmS8
1099 };
1100
1101 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1102 "Reference Convolution3d: weights type not supported for quantized input.");
1103 }
1104 else
1105 {
1106 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1107 "Reference Convolution3d: weights is not a supported type.");
1108
1109 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1110 "Reference Convolution3d: input and weights types mismatched.");
1111 }
1112
1113 if (biases.has_value())
1114 {
1115 std::array<DataType,4> biasesSupportedTypes =
1116 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001117 DataType::Float32,
1118 DataType::Float16,
1119 DataType::Signed32
1120 };
1121
1122 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1123 "Reference Convolution3d: biases is not a supported type.");
1124 }
1125 IgnoreUnused(descriptor);
1126
1127 return supported;
1128}
1129
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001130bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1131 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001132 Optional<std::string&> reasonIfUnsupported) const
1133{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001134 bool supported = true;
1135
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001136 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001137 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001138 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001139 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001140 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001141 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001142 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001143 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001144 DataType::QSymmS16,
1145 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001146 };
1147
1148 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001149 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001150
1151 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001152 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001153
1154 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001155 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001156
1157 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001158}
1159
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001160bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1161 const TensorInfo& output,
1162 const DepthToSpaceDescriptor& descriptor,
1163 Optional<std::string&> reasonIfUnsupported) const
1164{
Jan Eilers8eb25602020-03-09 12:13:48 +00001165 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001166 bool supported = true;
1167
Sadik Armagan303980c2020-04-17 12:45:14 +01001168 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001169 {
1170 DataType::Float32,
1171 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001172 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001173 DataType::QAsymmU8,
1174 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001175 };
1176
1177 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1178 "Reference DepthToSpace: input type not supported");
1179
1180 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1181 "Reference DepthToSpace: output type not supported");
1182
1183 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1184 "Reference DepthToSpace: input and output types are mismatched");
1185
1186 return supported;
1187}
1188
arovir011c7c81b2018-10-08 11:34:28 +01001189bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1190 const TensorInfo& output,
1191 const DepthwiseConvolution2dDescriptor& descriptor,
1192 const TensorInfo& weights,
1193 const Optional<TensorInfo>& biases,
1194 Optional<std::string&> reasonIfUnsupported) const
1195{
Sadik Armagan303980c2020-04-17 12:45:14 +01001196 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001197 bool supported = true;
1198
1199 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001200 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001201 {
1202 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001203 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001204 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001205 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001206 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001207 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001208 };
1209
1210 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1211 "Reference DepthwiseConvolution2d: input is not a supported type.");
1212
1213 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1214 "Reference DepthwiseConvolution2d: output is not a supported type.");
1215
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001216 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1217 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1218
Teresa Charlind8df0262019-11-11 12:28:15 +00001219 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001220 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001221 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001222 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001223 {
1224 DataType::QAsymmS8,
1225 DataType::QAsymmU8,
1226 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001227 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001228
1229 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001230 "Reference DepthwiseConvolution2d: weights type not supported for "
1231 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001232 }
1233 else
1234 {
1235 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1236 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1237
1238 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1239 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1240 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001241
1242 if (biases.has_value())
1243 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001244 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001245 {
1246 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001247 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001248 DataType::Signed32
1249 };
1250 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1251 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1252 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001253
1254 return supported;
1255
arovir011c7c81b2018-10-08 11:34:28 +01001256}
1257
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001258bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1259 const TensorInfo& output,
1260 Optional<std::string&> reasonIfUnsupported) const
1261{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001262 bool supported = true;
1263
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001264 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001265 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001266 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001267 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001268 DataType::QSymmS16,
1269 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001270 };
1271
1272 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001273 "Reference for Dequantize layer: input type not supported.");
1274
Derek Lambertid466a542020-01-22 15:37:29 +00001275 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001276 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001277
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001278 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001279 DataType::Float32,
1280 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001281 };
1282
1283 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001284 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001285
1286 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001287 "Reference for Dequantize layer: input/output shapes have different num total "
1288 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001289
1290 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001291}
1292
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001293bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1294 const TensorInfo& scores,
1295 const TensorInfo& anchors,
1296 const TensorInfo& detectionBoxes,
1297 const TensorInfo& detectionClasses,
1298 const TensorInfo& detectionScores,
1299 const TensorInfo& numDetections,
1300 const DetectionPostProcessDescriptor& descriptor,
1301 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001302{
Jan Eilers8eb25602020-03-09 12:13:48 +00001303 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001304
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001305 bool supported = true;
1306
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001307 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001308 {
1309 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001310 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001311 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001312 DataType::QAsymmU8,
1313 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001314 };
1315
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001316 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001317 "Reference DetectionPostProcess: input 0 is not a supported type.");
1318
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001319 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001320 "Reference DetectionPostProcess: input 1 is not a supported type.");
1321
1322 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001323}
1324
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // The dilated variant imposes no extra constraints in the reference backend:
    // support is decided entirely by the regular depthwise check.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1334
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001335bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001336 const TensorInfo& input1,
1337 const TensorInfo& output,
1338 Optional<std::string&> reasonIfUnsupported) const
1339{
Sadik Armagan2999a022019-04-09 14:20:12 +01001340 bool supported = true;
1341
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001342 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001343 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001344 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001345 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001346 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001347 DataType::QSymmS16,
1348 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001349 };
1350
1351 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1352 "Reference division: input 0 is not a supported type.");
1353
1354 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1355 "Reference division: input 1 is not a supported type.");
1356
1357 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1358 "Reference division: output is not a supported type.");
1359
1360 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1361 "Reference division: input 0 and Input 1 types are mismatched");
1362
1363 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1364 "Reference division: input and output types are mismatched");
1365
1366 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1367 "Reference division: shapes are not suitable for implicit broadcast.");
1368
1369 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001370}
1371
josh minor4a3c6102020-01-06 16:40:46 -06001372bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1373 const TensorInfo& output,
1374 const ElementwiseUnaryDescriptor& descriptor,
1375 Optional<std::string&> reasonIfUnsupported) const
1376{
Jan Eilers8eb25602020-03-09 12:13:48 +00001377 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001378
Sadik Armagan303980c2020-04-17 12:45:14 +01001379 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001380 {
1381 DataType::Float32,
1382 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001383 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001384 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001385 DataType::QSymmS16,
1386 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001387 };
1388
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001389 std::array<DataType, 1> logicalSupportedTypes =
1390 {
1391 DataType::Boolean
1392 };
1393
josh minor4a3c6102020-01-06 16:40:46 -06001394 bool supported = true;
1395
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001396 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1397 {
1398 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1399 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001400
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001401 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1402 "Reference elementwise unary: output type not supported");
1403 }
1404 else
1405 {
1406 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1407 "Reference elementwise unary: input type not supported");
1408
1409 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1410 "Reference elementwise unary: output type not supported");
1411 }
josh minor4a3c6102020-01-06 16:40:46 -06001412
1413 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1414 "Reference elementwise unary: input and output types not matching");
1415
1416 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1417 "Reference elementwise unary: input and output shapes"
1418 "have different number of total elements");
1419
1420 return supported;
1421}
1422
arovir011c7c81b2018-10-08 11:34:28 +01001423bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1424 const FakeQuantizationDescriptor& descriptor,
1425 Optional<std::string&> reasonIfUnsupported) const
1426{
Jan Eilers8eb25602020-03-09 12:13:48 +00001427 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001428 bool supported = true;
1429
1430 std::array<DataType,1> supportedTypes =
1431 {
1432 DataType::Float32
1433 };
1434
1435 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1436 "Reference fake quantization: input type not supported.");
1437
1438 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001439}
1440
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001441bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1442 const TensorInfo& output,
1443 const FillDescriptor& descriptor,
1444 Optional<std::string&> reasonIfUnsupported) const
1445{
1446 IgnoreUnused(descriptor);
1447 IgnoreUnused(output);
1448
1449 bool supported = true;
1450
Sadik Armagana792a052020-06-23 16:22:23 +01001451 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001452 {
1453 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001454 DataType::Float16,
1455 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001456 };
1457
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001458 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001459 "Reference Fill: input type not supported.");
1460
Teresa Charlin44088502020-07-27 11:27:19 +01001461 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1462 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001463 return supported;
1464}
1465
arovir011c7c81b2018-10-08 11:34:28 +01001466bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1467 const TensorInfo& output,
1468 Optional<std::string&> reasonIfUnsupported) const
1469{
Jan Eilers8eb25602020-03-09 12:13:48 +00001470 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001471 bool supported = true;
1472
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001473 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001474 {
James Conroyb40d7102019-06-04 12:32:09 +01001475 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001476 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001477 };
1478
1479 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1480 "Reference Floor: input type not supported.");
1481
1482 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1483 "Reference Floor: output type not supported.");
1484
1485 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001486}
1487
1488bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1489 const TensorInfo& output,
1490 const TensorInfo& weights,
1491 const TensorInfo& biases,
1492 const FullyConnectedDescriptor& descriptor,
1493 Optional<std::string&> reasonIfUnsupported) const
1494{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001495 bool supported = true;
1496
1497 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001498 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001499 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001500 DataType::Float32,
1501 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001502 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001503 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001504 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001505 };
1506
1507 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1508 "Reference Fully Connected: input type not supported.");
1509
1510 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1511 "Reference Fully Connected: output type not supported.");
1512
Francis Murtagh46c09d02019-05-28 08:15:28 +01001513 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1514 "Reference Fully Connected: weights type not supported.");
1515
Ryan OShea31441592022-11-07 16:20:48 +00001516 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1517 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001518
Jan Eilers1f45dc32020-06-15 11:43:03 +01001519 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1520 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001521
Jan Eilers1f45dc32020-06-15 11:43:03 +01001522 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1523 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001524
1525 if (descriptor.m_BiasEnabled)
1526 {
1527 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001528 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001529 supportedBiasTypes =
1530 {
1531 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001532 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001533 DataType::Signed32,
1534 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001535 };
1536
1537 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1538 "Reference Fully Connected: bias type not supported.");
1539
1540 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1541 "Reference Fully Connected: bias and weight types mismatch.");
1542
1543 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1544 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1545
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001546 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1547 "Reference Fully Connected: bias must have 1 dimension.");
1548
Francis Murtagh46c09d02019-05-28 08:15:28 +01001549 }
1550
1551 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001552}
1553
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001554bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1555 const armnn::TensorInfo& input1,
1556 const armnn::TensorInfo& output,
1557 armnn::Optional<std::string&> reasonIfUnsupported) const
1558{
1559 bool supported = true;
1560 std::array<DataType,7> supportedTypes =
1561 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001562 DataType::Float32,
1563 DataType::Float16,
1564 DataType::QAsymmS8,
1565 DataType::QAsymmU8,
1566 DataType::QSymmS16,
1567 DataType::Signed32
1568 };
1569
1570 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1571 "Reference GatherNd: input type not supported");
1572
1573 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1574 "Reference GatherNd: output type not supported");
1575
1576 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1577 "Reference GatherNd: indices (input1) type not supported");
1578
1579 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1580 "Reference GatherNd: input and output types not matching");
1581
1582 return supported;
1583}
1584
narpra014951d842019-01-18 16:53:53 +00001585bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1586 const armnn::TensorInfo& input1,
1587 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001588 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001589 armnn::Optional<std::string&> reasonIfUnsupported) const
1590{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001591 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001592 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001593 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001594 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001595 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001596 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001597 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001598 DataType::QSymmS16,
1599 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001600 };
1601
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001602 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001603 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1604 "Reference Gather: input type not supported");
1605
1606 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1607 "Reference Gather: output type not supported");
1608
1609 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1610 "Reference Gather: indices (input1) type not supported");
1611
1612 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1613 "Reference Gather: input and output types not matching");
1614
1615 return supported;
narpra014951d842019-01-18 16:53:53 +00001616}
1617
Derek Lamberti901ea112019-12-10 22:07:09 +00001618bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1619 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001620{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001621 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001622}
1623
Kevin May09ca49c2019-10-09 12:37:34 +01001624bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1625 const TensorInfo& output,
1626 const InstanceNormalizationDescriptor& descriptor,
1627 Optional<std::string&> reasonIfUnsupported) const
1628{
Jan Eilers8eb25602020-03-09 12:13:48 +00001629 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001630 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001631 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001632 {
1633 DataType::Float32,
1634 DataType::Float16
1635 };
1636
1637 bool supported = true;
1638
1639 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1640 "Reference Instance Normalization: input type not supported.");
1641
1642 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1643 "Reference Instance Normalization: output type not supported.");
1644
1645 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1646 "Reference Instance Normalization: input and output types mismatched.");
1647
1648 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1649 "Reference Instance Normalization: input and output shapes have different "
1650 "num total elements.");
1651
1652 return supported;
1653}
1654
arovir011c7c81b2018-10-08 11:34:28 +01001655bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1656 const TensorInfo& output,
1657 const L2NormalizationDescriptor& descriptor,
1658 Optional<std::string&> reasonIfUnsupported) const
1659{
Jan Eilers8eb25602020-03-09 12:13:48 +00001660 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001661 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001662 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001663 {
1664 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001665 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001666 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001667 DataType::QAsymmU8,
1668 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001669 };
1670
1671 bool supported = true;
1672
1673 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1674 "Reference L2normalization: input type not supported.");
1675
1676 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1677 "Reference L2normalization: output type not supported.");
1678
1679 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1680 "Reference L2normalization: input and output types mismatched.");
1681
1682 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1683 "Reference L2normalization: input and output shapes have different "
1684 "num total elements.");
1685
1686 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001687}
1688
James Conroyaba90cd2020-11-06 16:28:18 +00001689bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1690 const TensorInfo& input1,
1691 const TensorInfo& output,
1692 const LogicalBinaryDescriptor& descriptor,
1693 Optional<std::string&> reasonIfUnsupported) const
1694{
1695 IgnoreUnused(descriptor);
1696
1697 std::array<DataType, 1> supportedTypes =
1698 {
1699 DataType::Boolean
1700 };
1701
1702 bool supported = true;
1703 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1704 "Reference LogicalBinary: input 0 type not supported");
1705 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1706 "Reference LogicalBinary: input 1 type not supported");
1707
1708 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1709 "Reference LogicalBinary: input and output types do not match");
1710
1711 return supported;
1712}
1713
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001714bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1715 const TensorInfo& output,
1716 const LogSoftmaxDescriptor& descriptor,
1717 Optional<std::string&> reasonIfUnsupported) const
1718{
Jan Eilers8eb25602020-03-09 12:13:48 +00001719 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001720
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001721 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001722 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001723 DataType::Float32,
1724 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001725 };
1726
1727 bool supported = true;
1728 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1729 "Reference LogSoftmax: input type not supported");
1730
1731 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1732 "Reference LogSoftmax: output type not supported");
1733
1734 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1735 "Reference LogSoftmax: input and output types do not match");
1736
1737 return supported;
1738}
1739
arovir011c7c81b2018-10-08 11:34:28 +01001740bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1741 const TensorInfo& outputStateIn,
1742 const TensorInfo& cellStateIn,
1743 const TensorInfo& scratchBuffer,
1744 const TensorInfo& outputStateOut,
1745 const TensorInfo& cellStateOut,
1746 const TensorInfo& output,
1747 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001748 const LstmInputParamsInfo& paramsInfo,
1749 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001750{
Jan Eilers8eb25602020-03-09 12:13:48 +00001751 IgnoreUnused(descriptor);
1752 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001753
1754 bool supported = true;
1755
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001756 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001757 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001758 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001759 };
1760
Jan Eilersd01a83c2019-07-03 18:20:40 +01001761 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001762 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1763 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001764 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1765 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001766 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1767 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001768 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1769 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001770 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1771 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001772 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1773 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001774
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001775 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1776 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001777 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001778 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001779 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001780 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001781 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001782 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001783 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001784 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001785 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001786 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001787 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001788 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001789 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001790 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001791 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001792 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001793 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001794 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001795 "Reference Lstm: input and OutputGateBias types are mismatched");
1796 if (!descriptor.m_CifgEnabled)
1797 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001798 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001799 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001800 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001801 reasonIfUnsupported,
1802 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001803 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001804 "Reference Lstm: input and InputGateBias types are mismatched");
1805 if (descriptor.m_PeepholeEnabled)
1806 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001808 reasonIfUnsupported,
1809 "Reference Lstm: input and CellToInputWeights types are mismatched");
1810 }
1811 }
1812 if (descriptor.m_PeepholeEnabled)
1813 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001814 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001815 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001816 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001817 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1818 }
1819 if (descriptor.m_ProjectionEnabled)
1820 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001821 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001822 "Reference Lstm: input and mProjectionWeights types are mismatched");
1823 if (paramsInfo.m_ProjectionBias != nullptr)
1824 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001825 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001826 "Reference Lstm: input and ProjectionBias types are mismatched");
1827 }
1828 }
1829 if (descriptor.m_LayerNormEnabled)
1830 {
1831 if (!descriptor.m_CifgEnabled)
1832 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001833 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001834 reasonIfUnsupported,
1835 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1836 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001837 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001838 reasonIfUnsupported,
1839 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001840 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001841 reasonIfUnsupported,
1842 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001843 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001844 reasonIfUnsupported,
1845 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1846 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001847
1848 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001849}
1850
saoste012df12b32018-11-28 16:57:20 +00001851bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1852 const TensorInfo& input1,
1853 const TensorInfo& output,
1854 Optional<std::string&> reasonIfUnsupported) const
1855{
Sadik Armagan2999a022019-04-09 14:20:12 +01001856 bool supported = true;
1857
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001858 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001859 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001860 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001861 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001862 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001863 DataType::QSymmS16,
1864 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001865 };
1866
1867 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1868 "Reference maximum: input 0 is not a supported type.");
1869
1870 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1871 "Reference maximum: input 1 is not a supported type.");
1872
1873 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1874 "Reference maximum: output is not a supported type.");
1875
1876 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1877 "Reference maximum: input 0 and Input 1 types are mismatched");
1878
1879 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1880 "Reference maximum: input and output types are mismatched");
1881
1882 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1883 "Reference maximum: shapes are not suitable for implicit broadcast.");
1884
1885 return supported;
saoste012df12b32018-11-28 16:57:20 +00001886}
1887
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001888bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1889 const TensorInfo& output,
1890 const MeanDescriptor& descriptor,
1891 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001892{
James Conroy4d1ff582019-06-10 17:06:39 +01001893 bool supported = true;
1894 std::string meanLayerStr = "Mean";
1895 std::string outputTensorStr = "output";
1896
Sadik Armagan303980c2020-04-17 12:45:14 +01001897 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001898 {
1899 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001900 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001901 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001902 DataType::QAsymmU8,
1903 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001904 };
1905
1906 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1907 "Reference Mean: input type not supported.");
1908
1909 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1910 "Reference Mean: input and output types are mismatched");
1911
1912 if (descriptor.m_KeepDims)
1913 {
1914 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1915 reasonIfUnsupported,
1916 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1917 output.GetNumDimensions(),
1918 meanLayerStr, outputTensorStr).data());
1919 }
1920 else if (descriptor.m_Axis.empty())
1921 {
1922 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1923 reasonIfUnsupported,
1924 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1925 meanLayerStr, outputTensorStr).data());
1926 }
1927 else
1928 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001929 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001930
1931 if (outputDim > 0)
1932 {
1933 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1934 reasonIfUnsupported,
1935 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1936 meanLayerStr, outputTensorStr).data());
1937 }
1938 else
1939 {
1940 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1941 reasonIfUnsupported,
1942 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1943 meanLayerStr, outputTensorStr).data());
1944 }
1945 }
1946
1947 return supported;
narpra0132b90462018-09-13 11:07:48 +01001948}
1949
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001950bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1951 const TensorInfo &output,
1952 Optional<std::string &> reasonIfUnsupported) const
1953{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001954 bool supported = true;
1955
Sadik Armagan303980c2020-04-17 12:45:14 +01001956 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001957 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001958 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001959 DataType::Float32,
1960 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001961 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001962 DataType::QAsymmU8,
1963 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001964 DataType::Boolean
1965 };
1966
1967 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1968 "Reference MemCopy: input type not supported");
1969
1970 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1971 "Reference MemCopy: output type not supported");
1972
1973 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1974 "Reference MemCopy: input and output types are mismatched");
1975
1976 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001977}
1978
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001979bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1980 const TensorInfo& input1,
1981 const TensorInfo& output,
1982 Optional<std::string&> reasonIfUnsupported) const
1983{
Sadik Armagan2999a022019-04-09 14:20:12 +01001984 bool supported = true;
1985
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001986 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001987 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001988 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001989 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001990 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001991 DataType::QSymmS16,
1992 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001993 };
1994
1995 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1996 "Reference minimum: input 0 is not a supported type.");
1997
1998 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1999 "Reference minimum: input 1 is not a supported type.");
2000
2001 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2002 "Reference minimum: output is not a supported type.");
2003
2004 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2005 "Reference minimum: input 0 and Input 1 types are mismatched");
2006
2007 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2008 "Reference minimum: input and output types are mismatched");
2009
2010 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2011 "Reference minimum: shapes are not suitable for implicit broadcast.");
2012
2013 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002014}
2015
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002016bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2017 const TensorInfo& input1,
2018 const TensorInfo& output,
2019 Optional<std::string&> reasonIfUnsupported) const
2020{
Sadik Armagan2999a022019-04-09 14:20:12 +01002021 bool supported = true;
2022
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002023 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002024 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002025 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002026 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002027 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002028 DataType::QSymmS16,
2029 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002030 };
2031
2032 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2033 "Reference multiplication: input 0 is not a supported type.");
2034
2035 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2036 "Reference multiplication: input 1 is not a supported type.");
2037
2038 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2039 "Reference multiplication: output is not a supported type.");
2040
2041 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2042 "Reference multiplication: input 0 and Input 1 types are mismatched");
2043
2044 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2045 "Reference multiplication: input and output types are mismatched");
2046
2047 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2048 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2049
2050 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002051}
2052
2053bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2054 const TensorInfo& output,
2055 const NormalizationDescriptor& descriptor,
2056 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002057{
Jan Eilers8eb25602020-03-09 12:13:48 +00002058 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002059
2060 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002061 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002062 {
2063 DataType::Float16,
2064 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002065 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002066 DataType::QAsymmU8,
2067 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002068 };
2069
2070 bool supported = true;
2071
2072 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2073 "Reference normalization: input type not supported.");
2074
2075 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2076 "Reference normalization: output type not supported.");
2077
2078 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2079 "Reference normalization: input and output shapes have different "
2080 "num total elements.");
2081
2082 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002083}
2084
Derek Lamberti901ea112019-12-10 22:07:09 +00002085bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2086 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002087{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002088 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002089}
2090
2091bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2092 const TensorInfo& output,
2093 const PadDescriptor& descriptor,
2094 Optional<std::string&> reasonIfUnsupported) const
2095{
Jan Eilers8eb25602020-03-09 12:13:48 +00002096 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002097 bool supported = true;
2098
2099 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002100 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002101 {
2102 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002103 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002104 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002105 DataType::QAsymmU8,
2106 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002107 };
2108
2109 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2110 "Reference pad: input is not a supported type.");
2111
2112 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2113 "Reference pad: output is not a supported type.");
2114
2115 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2116 "Reference pad: input and output types are mismatched.");
2117
2118 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002119}
2120
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002121bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2122 const TensorInfo& output,
2123 const PermuteDescriptor& descriptor,
2124 Optional<std::string&> reasonIfUnsupported) const
2125{
Jan Eilers8eb25602020-03-09 12:13:48 +00002126 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002127 bool supported = true;
2128
2129 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002130 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002131 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002132 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002133 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002134 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002135 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002136 DataType::QAsymmU8,
2137 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002138 };
2139
2140 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2141 "Reference permute: input is not a supported type.");
2142
2143 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2144 "Reference permute: output is not a supported type.");
2145
2146 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2147 "Reference permute: input and output types are mismatched.");
2148
2149 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002150}
2151
2152bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2153 const TensorInfo& output,
2154 const Pooling2dDescriptor& descriptor,
2155 Optional<std::string&> reasonIfUnsupported) const
2156{
Jan Eilers8eb25602020-03-09 12:13:48 +00002157 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002158 bool supported = true;
2159
2160 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002161 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002162 {
2163 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002164 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002165 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002166 DataType::QAsymmU8,
2167 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002168 };
2169
2170 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2171 "Reference poolind2d: input is not a supported type.");
2172
2173 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2174 "Reference poolind2d: output is not a supported type.");
2175
2176 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2177 "Reference poolind2d: input and output types are mismatched.");
2178
2179 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002180}
2181
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002182bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2183 const TensorInfo& output,
2184 const Pooling3dDescriptor& descriptor,
2185 Optional<std::string&> reasonIfUnsupported) const
2186{
2187 IgnoreUnused(descriptor);
2188 bool supported = true;
2189
2190 // Define supported output and inputs types.
2191 std::array<DataType,6> supportedTypes =
2192 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002193 DataType::Float32,
2194 DataType::Float16,
2195 DataType::QAsymmS8,
2196 DataType::QAsymmU8,
2197 DataType::QSymmS16
2198 };
2199
2200 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2201 "Reference poolind3d: input is not a supported type.");
2202
2203 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2204 "Reference poolind3d: output is not a supported type.");
2205
2206 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2207 "Reference poolind3d: input and output types are mismatched.");
2208
2209 return supported;
2210}
2211
2212
James Conroy4f1f8992020-04-29 20:01:10 +01002213bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2214 const TensorInfo& previousOutputIn,
2215 const TensorInfo& previousCellStateIn,
2216 const TensorInfo& outputStateOut,
2217 const TensorInfo& cellStateOut,
2218 const TensorInfo& output,
2219 const QLstmDescriptor& descriptor,
2220 const LstmInputParamsInfo& paramsInfo,
2221 Optional<std::string&> reasonIfUnsupported) const
2222{
2223 IgnoreUnused(input);
2224 IgnoreUnused(previousOutputIn);
2225 IgnoreUnused(previousCellStateIn);
2226 IgnoreUnused(outputStateOut);
2227 IgnoreUnused(cellStateOut);
2228 IgnoreUnused(output);
2229 IgnoreUnused(descriptor);
2230 IgnoreUnused(paramsInfo);
2231
2232 IgnoreUnused(reasonIfUnsupported);
2233
2234 return true;
2235}
2236
Derek Lamberti5f400d62019-03-25 15:41:58 +00002237bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2238 const TensorInfo& output,
2239 Optional<std::string&> reasonIfUnsupported) const
2240{
2241 bool supported = true;
2242
Finn Williamsfd271062019-12-04 14:27:27 +00002243 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002244 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002245 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002246 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002247 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002248 DataType::QAsymmU8,
2249 DataType::QSymmS8,
2250 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002251 };
2252
2253 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2254 "Reference quantize: input type not supported.");
2255
2256 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002257 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002258 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002259 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002260 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002261 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002262 };
2263 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2264 "Reference quantize: output type not supported.");
2265
2266 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2267 "Reference quantize: input and output shapes have different num total elements.");
2268
2269 return supported;
2270}
2271
Finn Williams2605b232020-06-10 15:53:46 +01002272bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2273 const TensorInfo& output,
2274 Optional<std::string&> reasonIfUnsupported) const
2275{
2276 IgnoreUnused(input);
2277 // Define supported output types.
2278 std::array<DataType,1> supportedOutputTypes =
2279 {
2280 DataType::Signed32,
2281 };
2282
2283 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2284 "Reference rank: input type not supported.");
2285}
2286
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002287bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2288 const TensorInfo& output,
2289 const ReduceDescriptor& descriptor,
2290 Optional<std::string&> reasonIfUnsupported) const
2291{
2292 IgnoreUnused(descriptor);
2293 bool supported = true;
2294 std::array<DataType,7> supportedTypes =
2295 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002296 DataType::Float32,
2297 DataType::Float16,
2298 DataType::QAsymmS8,
2299 DataType::QAsymmU8,
2300 DataType::QSymmS16,
2301 DataType::Signed32
2302 };
2303
2304 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2305 "Reference Reduce: input type not supported");
2306
2307 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2308 "Reference Reduce: output type not supported");
2309
2310 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2311 "Reference Reduce: input and output types not matching");
2312
2313 return supported;
2314}
2315
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002316bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002317 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002318 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002319 Optional<std::string&> reasonIfUnsupported) const
2320{
Jan Eilers8eb25602020-03-09 12:13:48 +00002321 IgnoreUnused(output);
2322 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002323 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002324 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002325 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002326 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002327 DataType::Float32,
2328 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002329 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002330 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002331 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002332 DataType::QSymmS16,
2333 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002334 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002335
Nina Drozd2f2778f2019-05-27 10:37:05 +01002336 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2337 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002338}
2339
Teresa Charlin970f43b2019-07-01 13:51:07 +01002340bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2341 const TensorInfo& output,
2342 const ResizeDescriptor& descriptor,
2343 Optional<std::string&> reasonIfUnsupported) const
2344{
Jan Eilers8eb25602020-03-09 12:13:48 +00002345 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002346 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002347 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002348 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002349 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002350 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002351 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002352 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002353 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002354 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002355 };
2356
2357 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2358 "Reference Resize: input type not supported");
2359
2360 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2361 "Reference Resize: output type not supported");
2362
2363 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2364 "Reference Resize: input and output types not matching");
2365
2366 return supported;
2367}
2368
Tracy Narinebb8d7592023-07-13 16:50:54 +01002369bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
2370 const TensorInfo& input1,
Tianle Cheng988354d2023-06-28 13:20:47 +01002371 const TensorInfo& output,
Tianle Cheng988354d2023-06-28 13:20:47 +01002372 Optional<std::string&> reasonIfUnsupported) const
2373{
Tianle Cheng988354d2023-06-28 13:20:47 +01002374 bool supported = true;
2375 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002376 std::array<DataType,8> supportedTypes =
Tianle Cheng988354d2023-06-28 13:20:47 +01002377 {
2378 DataType::BFloat16,
2379 DataType::Float32,
2380 DataType::Float16,
2381 DataType::QAsymmS8,
2382 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002383 DataType::QSymmS8,
2384 DataType::QSymmS16,
2385 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01002386 };
2387
Tracy Narinebb8d7592023-07-13 16:50:54 +01002388 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2389 "Reference ReverseV2: input0 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002390
2391 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2392 "Reference ReverseV2: output type not supported");
2393
Tracy Narinebb8d7592023-07-13 16:50:54 +01002394 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2395 "Reference ReverseV2: input0 and output types not matching");
2396
2397 std::array<DataType,6> input2SupportedTypes =
2398 {
2399 DataType::Signed32
2400 };
2401
2402 supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
2403 "Reference ReverseV2: input1 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002404
2405 return supported;
2406}
2407
Keith Davis3ae3f972021-05-21 16:33:48 +01002408bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2409 const TensorInfo& output,
2410 Optional<std::string&> reasonIfUnsupported) const
2411{
2412 IgnoreUnused(input);
2413 bool supported = true;
2414
2415 std::array<DataType, 1> supportedTypes =
2416 {
2417 DataType::Signed32
2418 };
2419
2420 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2421 "Reference Shape: output type not supported");
2422
2423 return supported;
2424}
2425
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002426bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2427 const TensorInfo& output,
2428 const SliceDescriptor& descriptor,
2429 Optional<std::string&> reasonIfUnsupported) const
2430{
Jan Eilers8eb25602020-03-09 12:13:48 +00002431 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002432 bool supported = true;
2433
Sadik Armagan303980c2020-04-17 12:45:14 +01002434 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002435 {
2436 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002437 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002438 DataType::QAsymmU8,
Ryan OShea980446b2023-06-08 16:23:28 +01002439 DataType::QSymmS16,
2440 DataType::Signed32
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002441 };
2442
2443 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2444 "Reference Slice: input type not supported");
2445
2446 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2447 "Reference Slice: output type not supported");
2448
2449 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2450 "Reference Slice: input and output types are mismatched");
2451
2452 return supported;
2453}
2454
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002455bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2456 const TensorInfo& output,
2457 const SoftmaxDescriptor& descriptor,
2458 Optional<std::string&> reasonIfUnsupported) const
2459{
Jan Eilers8eb25602020-03-09 12:13:48 +00002460 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002461 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002462 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002463 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002464 DataType::Float32,
2465 DataType::Float16,
2466 DataType::QSymmS8,
2467 DataType::QAsymmS8,
2468 DataType::QAsymmU8,
2469 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002470 };
2471
2472 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002473 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002474
2475 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002476 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002477
2478 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002479 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002480
2481 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002482}
2483
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002484bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2485 const TensorInfo& output,
2486 const SpaceToBatchNdDescriptor& descriptor,
2487 Optional<std::string&> reasonIfUnsupported) const
2488{
Jan Eilers8eb25602020-03-09 12:13:48 +00002489 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002490 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002491 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002492 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002493 DataType::Float32,
2494 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002495 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002496 DataType::QAsymmU8,
2497 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002498 };
2499
2500 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2501 "Reference SpaceToBatchNd: input type not supported");
2502
2503 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2504 "Reference SpaceToBatchNd: output type not supported");
2505
2506 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2507 "Reference SpaceToBatchNd: input and output types are mismatched");
2508
2509 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002510}
2511
Keith Davisa57eccb2019-06-14 17:33:22 +01002512bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002513 const TensorInfo& output,
2514 const SpaceToDepthDescriptor& descriptor,
2515 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002516{
2517
Jan Eilers8eb25602020-03-09 12:13:48 +00002518 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002519 bool supported = true;
2520
Sadik Armagan303980c2020-04-17 12:45:14 +01002521 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002522 {
2523 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002524 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002525 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002526 DataType::QAsymmU8,
2527 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002528 };
2529
2530 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2531 "Reference SpaceToDepth: input type not supported");
2532
2533 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2534 "Reference SpaceToDepth: output type not supported");
2535
2536 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2537 "Reference SpaceToDepth: input and output types are mismatched");
2538
2539 return supported;
2540}
2541
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002542bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002543 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2544 const ViewsDescriptor& descriptor,
2545 Optional<std::string&> reasonIfUnsupported) const
2546{
Jan Eilers8eb25602020-03-09 12:13:48 +00002547 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002548 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002549 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002550 {
2551 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002552 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002553 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002554 DataType::QAsymmU8,
2555 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002556 };
2557
2558 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2559 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002560 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002561 {
2562 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2563 "Reference splitter: input type not supported");
2564
2565 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2566 "Reference splitter: input and output types mismatched.");
2567 }
2568
2569 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002570}
2571
Matthew Jackson81e601c2019-07-11 12:07:09 +01002572bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2573 const TensorInfo& output,
2574 const StackDescriptor& descriptor,
2575 Optional<std::string&> reasonIfUnsupported) const
2576{
Jan Eilers8eb25602020-03-09 12:13:48 +00002577 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002578
2579 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002580 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002581 {
2582 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002583 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002584 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002585 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002586 DataType::QSymmS16,
2587 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002588 };
2589
2590 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2591 "Reference stack: output type not supported");
2592 for (const TensorInfo* input : inputs)
2593 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002594 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002595 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2596 "Reference stack: input type not supported");
2597
2598 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2599 "Reference stack: input and output types mismatched.");
2600 }
2601
2602 return supported;
2603}
2604
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002605bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2606 const TensorInfo& output,
2607 const StridedSliceDescriptor& descriptor,
2608 Optional<std::string&> reasonIfUnsupported) const
2609{
Jan Eilers8eb25602020-03-09 12:13:48 +00002610 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002611 bool supported = true;
2612
Sadik Armagan303980c2020-04-17 12:45:14 +01002613 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002614 {
2615 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002616 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002617 DataType::QAsymmU8,
2618 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002619 };
2620
2621 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2622 "Reference StridedSlice: input type not supported");
2623
2624 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2625 "Reference StridedSlice: output type not supported");
2626
2627 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2628 "Reference StridedSlice: input and output types are mismatched");
2629
2630 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002631}
2632
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002633bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2634 const TensorInfo& input1,
2635 const TensorInfo& output,
2636 Optional<std::string&> reasonIfUnsupported) const
2637{
Sadik Armagan2999a022019-04-09 14:20:12 +01002638 bool supported = true;
2639
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002640 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002641 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002642 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002643 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002644 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002645 DataType::QSymmS16,
2646 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002647 };
2648
2649 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2650 "Reference subtraction: input 0 is not a supported type.");
2651
2652 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2653 "Reference subtraction: input 1 is not a supported type.");
2654
2655 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2656 "Reference subtraction: output is not a supported type.");
2657
2658 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2659 "Reference subtraction: input 0 and Input 1 types are mismatched");
2660
2661 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2662 "Reference subtraction: input and output types are mismatched");
2663
2664 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2665 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2666
2667 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002668}
2669
Matteo Martincighab9e5252019-06-13 17:27:46 +01002670bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2671 const TensorInfo& alpha,
2672 const TensorInfo& output,
2673 Optional<std::string&> reasonIfUnsupported) const
2674{
2675 bool supported = true;
2676
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002677 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002678 {
2679 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002680 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002681 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002682 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002683 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002684 };
2685
2686 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2687 "PReLU: input is not a supported type.");
2688
2689 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2690 "PReLU: alpha is not a supported type.");
2691
2692 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2693 "PReLU: output is not a supported type.");
2694
2695 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2696 "PReLU: input, alpha and output types are mismatched");
2697
2698 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2699 "PReLU: shapes are not suitable for implicit broadcast");
2700
2701 return supported;
2702}
2703
Teresa Charlin79a06a52023-07-13 17:16:45 +01002704bool RefLayerSupport::IsTileSupported(const TensorInfo& input,
2705 const TensorInfo& output,
2706 const TileDescriptor& descriptor,
2707 Optional<std::string&> reasonIfUnsupported) const
2708{
2709 IgnoreUnused(descriptor);
2710
2711 bool supported = true;
2712
2713 std::array<DataType, 7> supportedTypes
2714 {
2715 DataType::Float32,
2716 DataType::Float16,
2717 DataType::QAsymmS8,
2718 DataType::QAsymmU8,
2719 DataType::QSymmS8,
2720 DataType::QSymmS16,
2721 DataType::Signed32
2722 };
2723
2724 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2725 "Tile: input type not supported.");
2726
2727 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2728 "Tile: output type not supported");
2729
2730 return supported;
2731}
2732
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002733bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2734 const TensorInfo& output,
2735 const TransposeConvolution2dDescriptor& descriptor,
2736 const TensorInfo& weights,
2737 const Optional<TensorInfo>& biases,
2738 Optional<std::string&> reasonIfUnsupported) const
2739{
Jan Eilers8eb25602020-03-09 12:13:48 +00002740 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002741 bool supported = true;
2742
Sadik Armagan303980c2020-04-17 12:45:14 +01002743 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002744 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002745 DataType::Float32,
2746 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002747 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002748 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002749 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002750 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002751 };
2752
2753 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2754 "Reference TransposeConvolution2d: input is not a supported type.");
2755
2756 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2757 "Reference TransposeConvolution2d: output is not a supported type.");
2758
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002759 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2760 "Reference TransposeConvolution2d: input and output types mismatched.");
2761
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002762
2763 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002764 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002765 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002766 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002767 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002768 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002769 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002770 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002771 };
2772
2773 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2774 "Reference TransposeConvolution2d: weights type not supported for "
2775 "quantized input.");
2776 }
2777 else
2778 {
2779 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2780 "Reference TransposeConvolution2d: weights is not a supported type.");
2781
2782 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2783 "Reference TransposeConvolution2d: input and weights types mismatched.");
2784 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002785
2786 if (biases.has_value())
2787 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002788 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002789 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002790 DataType::Float32,
2791 DataType::Float16,
2792 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002793 };
2794 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2795 "Reference TransposeConvolution2d: biases is not a supported type.");
2796 }
2797
2798 return supported;
2799}
2800
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002801bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2802 const TensorInfo& output,
2803 const TransposeDescriptor& descriptor,
2804 Optional<std::string&> reasonIfUnsupported) const
2805{
Jan Eilers8eb25602020-03-09 12:13:48 +00002806 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002807 bool supported = true;
2808
2809 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002810 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002811 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002812 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002813 DataType::Float32,
2814 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002815 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002816 DataType::QAsymmU8,
2817 DataType::QSymmS16
2818 };
2819
2820 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2821 "Reference transpose: input is not a supported type.");
2822
2823 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2824 "Reference transpose: output is not a supported type.");
2825
2826 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2827 "Reference transpose: input and output types are mismatched.");
2828
2829 return supported;
2830}
2831
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002832bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2833 const TensorInfo& input,
2834 const TensorInfo& outputStateIn,
2835 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002836 const TensorInfo& outputStateOut,
2837 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002838 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002839 const UnidirectionalSequenceLstmDescriptor& descriptor,
2840 const LstmInputParamsInfo& paramsInfo,
2841 Optional<std::string&> reasonIfUnsupported) const
2842{
2843 IgnoreUnused(descriptor);
2844 IgnoreUnused(paramsInfo);
2845 IgnoreUnused(outputStateIn);
2846 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002847 IgnoreUnused(outputStateOut);
2848 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002849 bool supported = true;
2850
Mike Kelly12994962022-04-21 11:57:09 +01002851 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002852 {
Mike Kelly12994962022-04-21 11:57:09 +01002853 DataType::Float32,
2854 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002855 };
2856
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002857 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002858 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002859 DataType::Float32,
2860 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002861 };
2862
Mike Kelly12994962022-04-21 11:57:09 +01002863 std::array<DataType, 3> supportedBiasTypes =
2864 {
2865 DataType::Float32,
2866 DataType::QAsymmS8,
2867 DataType::Signed32
2868 };
2869
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002870 // check inputs and outputs
2871 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2872 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002873 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2874 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002875
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002876 // check layer parameters
2877 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2878 reasonIfUnsupported,
2879 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2880 "is not a supported type.");
2881 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2882 reasonIfUnsupported,
2883 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2884 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2885 reasonIfUnsupported,
2886 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2887 "is not a supported type.");
2888 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2889 reasonIfUnsupported,
2890 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2891 "is not a supported type.");
2892 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2893 reasonIfUnsupported,
2894 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2895 "is not a supported type.");
2896 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2897 reasonIfUnsupported,
2898 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2899 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002900
2901 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2902 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2903 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2904 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2905 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2906 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002907 if (!descriptor.m_CifgEnabled)
2908 {
2909 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2910 reasonIfUnsupported,
2911 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2912 "is not a supported type.");
2913 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2914 reasonIfUnsupported,
2915 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2916 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002917 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2918 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002919 if (descriptor.m_PeepholeEnabled)
2920 {
2921 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2922 reasonIfUnsupported,
2923 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2924 "is not a supported type.");
2925 }
2926 }
2927 if (descriptor.m_PeepholeEnabled)
2928 {
2929 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2930 reasonIfUnsupported,
2931 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2932 "is not a supported type.");
2933 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2934 reasonIfUnsupported,
2935 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2936 "is not a supported type.");
2937 }
2938 if (descriptor.m_ProjectionEnabled)
2939 {
2940 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2941 reasonIfUnsupported,
2942 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2943 "is not a supported type.");
2944 if (paramsInfo.m_ProjectionBias != nullptr)
2945 {
2946 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2947 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2948 "are mismatched");
2949 }
2950 }
2951 if (descriptor.m_LayerNormEnabled)
2952 {
2953 if (!descriptor.m_CifgEnabled)
2954 {
2955 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2956 reasonIfUnsupported,
2957 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2958 "is not a supported type.");
2959 }
2960 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2961 reasonIfUnsupported,
2962 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2963 "is not a supported type.");
2964 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2965 reasonIfUnsupported,
2966 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2967 "is not a supported type.");
2968 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2969 reasonIfUnsupported,
2970 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2971 "is not a supported type.");
2972 }
2973
2974 return supported;
2975}
2976
arovir011c7c81b2018-10-08 11:34:28 +01002977} // namespace armnn