blob: 40d243e10a9af3baddb4aa265f1ea9766edf11dc [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
// Dispatches a layer-support check based on the tensor data type: runs
// floatFuncPtr for Float32, uint8FuncPtr for QAsymmU8, and rejects every
// other data type by passing FalseFunc for the remaining callback slots.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    // NOTE(review): the callback order is fixed by IsSupportedForDataTypeGeneric
    // (LayerSupportCommon.hpp, not visible here); the FalseFunc slots appear to
    // cover the non-Float32/non-QAsymmU8 types — confirm against that header
    // before relying on the exact slot-to-type mapping.
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
// Builds the standard "wrong tensor rank" diagnostic used by the reference
// backend's support checks.
//
// @param expected   Rank the layer requires.
// @param actual     Rank the caller supplied.
// @param layerStr   Human-readable layer name inserted into the message.
// @param tensorName Name of the offending tensor (e.g. "input").
// @return Formatted error message.
//
// Parameters are taken by const reference: the strings are only read, and
// const& additionally allows temporaries/literals at the call site.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100103 case LayerType::BroadcastTo:
104 return IsBroadcastToSupported(infos[0],
105 infos[1],
106 *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
107 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000108 case LayerType::Comparison:
109 return IsComparisonSupported(infos[0],
110 infos[1],
111 infos[2],
112 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
113 reasonIfUnsupported);
114 case LayerType::Concat:
115 {
116 std::vector<const TensorInfo*> inputInfos;
117 for (uint32_t i = 0; i < (infos.size() - 1); i++)
118 {
119 inputInfos.push_back(&infos[i]);
120 }
121 return IsConcatSupported(inputInfos,
122 infos[infos.size() - 1],
123 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
124 reasonIfUnsupported);
125 }
126 case LayerType::Constant:
127 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000128 case LayerType::ConvertFp16ToFp32:
129 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000130 case LayerType::ConvertFp32ToFp16:
131 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
132 case LayerType::Convolution2d:
133 {
134 if (infos.size() != 4)
135 {
136 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
137 "TensorInfos should be of format: {input, output, weights, biases}.");
138 }
139
140 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
141 if (infos[3] == TensorInfo())
142 {
143 return IsConvolution2dSupported(infos[0],
144 infos[1],
145 desc,
146 infos[2],
147 EmptyOptional(),
148 reasonIfUnsupported);
149 }
150 else
151 {
152 return IsConvolution2dSupported(infos[0],
153 infos[1],
154 desc,
155 infos[2],
156 infos[3],
157 reasonIfUnsupported);
158 }
159 }
160 case LayerType::DepthToSpace:
161 return IsDepthToSpaceSupported(infos[0],
162 infos[1],
163 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
164 reasonIfUnsupported);
165 case LayerType::DepthwiseConvolution2d:
166 {
167 if (infos.size() != 4)
168 {
169 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
170 "TensorInfos should be of format: {input, output, weights, biases}.");
171 }
172
173 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
174 if (infos[3] == TensorInfo())
175 {
176 return IsDepthwiseConvolutionSupported(infos[0],
177 infos[1],
178 desc,
179 infos[2],
180 EmptyOptional(),
181 reasonIfUnsupported);
182 }
183 else
184 {
185 return IsDepthwiseConvolutionSupported(infos[0],
186 infos[1],
187 desc,
188 infos[2],
189 infos[3],
190 reasonIfUnsupported);
191 }
192 }
193 case LayerType::Dequantize:
194 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
195 case LayerType::Division:
196 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000197 case LayerType::ElementwiseBinary:
198 {
199 std::array<DataType, 7> supportedTypes =
200 {
201 DataType::Float32,
202 DataType::Float16,
203 DataType::QAsymmS8,
204 DataType::QAsymmU8,
205 DataType::QSymmS16,
206 DataType::Signed32
207 };
208
209 bool supported = true;
210 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
211 "Reference elementwise unary: input type not supported");
212
213 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
214 "Reference elementwise unary: input type not supported");
215
216 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
217 "Reference elementwise unary: output type not supported");
218
219 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
220 "Reference elementwise unary: input types not matching");
221
222 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
223 "Reference elementwise unary: input and output types not matching");
224
225 return supported;
226 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000227 case LayerType::ElementwiseUnary:
228 return IsElementwiseUnarySupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Fill:
233 return IsFillSupported(infos[0],
234 infos[1],
235 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
236 reasonIfUnsupported);
237 case LayerType::Floor:
238 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
239 case LayerType::FullyConnected:
240 return IsFullyConnectedSupported(infos[0],
241 infos[1],
242 infos[2],
243 infos[3],
244 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Gather:
247 return IsGatherSupported(infos[0],
248 infos[1],
249 infos[2],
250 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
251 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100252 case LayerType::GatherNd:
253 return IsGatherNdSupported(infos[0],
254 infos[1],
255 infos[2],
256 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000257 case LayerType::Input:
258 return IsInputSupported(infos[0], reasonIfUnsupported);
259 case LayerType::InstanceNormalization:
260 return IsInstanceNormalizationSupported(infos[0],
261 infos[1],
262 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
263 (&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::L2Normalization:
266 return IsL2NormalizationSupported(infos[0],
267 infos[1],
268 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
269 reasonIfUnsupported);
270 case LayerType::LogicalBinary:
271 return IsLogicalBinarySupported(infos[0],
272 infos[1],
273 infos[2],
274 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::LogSoftmax:
277 return IsLogSoftmaxSupported(infos[0],
278 infos[1],
279 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
280 reasonIfUnsupported);
281 case LayerType::Lstm:
282 return IsLstmSupported(infos[0],
283 infos[1],
284 infos[2],
285 infos[3],
286 infos[4],
287 infos[5],
288 infos[6],
289 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
290 lstmParamsInfo.value(),
291 reasonIfUnsupported);
292 case LayerType::QLstm:
293 return IsQLstmSupported(infos[0],
294 infos[1],
295 infos[2],
296 infos[3],
297 infos[4],
298 infos[5],
299 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
300 lstmParamsInfo.value(),
301 reasonIfUnsupported);
302 case LayerType::Maximum:
303 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
304 case LayerType::Mean:
305 return IsMeanSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Minimum:
310 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
311 case LayerType::Multiplication:
312 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
313 case LayerType::Normalization:
314 return IsNormalizationSupported(infos[0],
315 infos[1],
316 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
317 reasonIfUnsupported);
318 case LayerType::Output:
319 return IsOutputSupported(infos[0], reasonIfUnsupported);
320 case LayerType::Pad:
321 return IsPadSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Permute:
326 return IsPermuteSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Pooling2d:
331 return IsPooling2dSupported(infos[0],
332 infos[1],
333 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
334 reasonIfUnsupported);
335 case LayerType::Prelu:
336 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
337 case LayerType::Quantize:
338 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
339 case LayerType::Reshape:
340 return IsReshapeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Resize:
345 return IsResizeSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100349 case LayerType::ReverseV2:
350 return IsReverseV2Supported(infos[0],
351 infos[1],
Tracy Narinebb8d7592023-07-13 16:50:54 +0100352 infos[2],
Tianle Cheng988354d2023-06-28 13:20:47 +0100353 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000354 case LayerType::Reduce:
355 return IsReduceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::Slice:
360 return IsSliceSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::Softmax:
365 return IsSoftmaxSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::SpaceToBatchNd:
370 return IsSpaceToBatchNdSupported(infos[0],
371 infos[1],
372 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
373 reasonIfUnsupported);
374 case LayerType::SpaceToDepth:
375 return IsSpaceToDepthSupported(infos[0],
376 infos[1],
377 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
378 reasonIfUnsupported);
379 case LayerType::Splitter:
380 {
381 std::vector<TensorInfo> outputInfos;
382 for (uint32_t i = 1; i < infos.size(); i++)
383 {
384 outputInfos.push_back(infos[i]);
385 }
386 return IsSplitterSupported(infos[0],
387 {outputInfos.begin(), outputInfos.end()},
388 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
389 reasonIfUnsupported);
390 }
391 case LayerType::Stack:
392 {
393 std::vector<const TensorInfo*> inputInfos;
394 for (uint32_t i = 0; i < infos.size() - 1; i++)
395 {
396 inputInfos.push_back(&infos[i]);
397 }
398 return IsStackSupported(inputInfos,
399 infos[infos.size() - 1],
400 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
401 reasonIfUnsupported);
402 }
403 case LayerType::StridedSlice:
404 return IsStridedSliceSupported(infos[0],
405 infos[1],
406 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
407 reasonIfUnsupported);
408 case LayerType::Subtraction:
409 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Teresa Charlin79a06a52023-07-13 17:16:45 +0100410 case LayerType::Tile:
411 return IsTileSupported(infos[0],
412 infos[1],
413 *(PolymorphicDowncast<const TileDescriptor*>(&descriptor)),
414 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000415 case LayerType::Transpose:
416 return IsTransposeSupported(infos[0],
417 infos[1],
418 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
419 reasonIfUnsupported);
420 case LayerType::TransposeConvolution2d:
421 {
422 if (infos.size() != 4)
423 {
424 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
425 "TensorInfos should be of format: {input, output, weights, biases}.");
426 }
427
428 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
429 if (infos[3] == TensorInfo())
430 {
431 return IsTransposeConvolution2dSupported(infos[0],
432 infos[1],
433 desc,
434 infos[2],
435 EmptyOptional(),
436 reasonIfUnsupported);
437 }
438 else
439 {
440 return IsTransposeConvolution2dSupported(infos[0],
441 infos[1],
442 desc,
443 infos[2],
444 infos[3],
445 reasonIfUnsupported);
446 }
447 }
448 case LayerType::Cast:
449 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
450 case LayerType::ChannelShuffle:
451 return IsChannelShuffleSupported(infos[0],
452 infos[1],
453 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
454 reasonIfUnsupported);
455 case LayerType::Convolution3d:
456 {
457 if (infos.size() != 4)
458 {
459 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
460 "TensorInfos should be of format: {input, output, weights, biases}.");
461 }
462
463 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
464 if (infos[3] == TensorInfo())
465 {
466 return IsConvolution3dSupported(infos[0],
467 infos[1],
468 desc,
469 infos[2],
470 EmptyOptional(),
471 reasonIfUnsupported);
472 }
473 else
474 {
475 return IsConvolution3dSupported(infos[0],
476 infos[1],
477 desc,
478 infos[2],
479 infos[3],
480 reasonIfUnsupported);
481 }
482 }
483 case LayerType::Debug:
484 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
485 case LayerType::DetectionPostProcess:
486 return IsDetectionPostProcessSupported(infos[0],
487 infos[1],
488 infos[2],
489 infos[3],
490 infos[4],
491 infos[5],
492 infos[6],
493 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
494 (&descriptor)),
495 reasonIfUnsupported);
496 case LayerType::FakeQuantization:
497 return IsFakeQuantizationSupported(infos[0],
498 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
499 reasonIfUnsupported);
500 case LayerType::MemCopy:
501 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
502 case LayerType::Rank:
503 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
504 case LayerType::Shape:
505 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
506 case LayerType::UnidirectionalSequenceLstm:
507 {
508 if (infos.size() != 6)
509 {
510 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
511 "should be of format: {input, outputStateIn, cellStateIn, "
512 "hiddenStateOutputVal, cellStateOutputVal, output}");
513 }
514 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100515 return IsUnidirectionalSequenceLstmSupported(infos[0],
516 infos[1],
517 infos[2],
518 infos[3],
519 infos[4],
520 infos[5],
521 desc,
522 lstmParamsInfo.value(),
523 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000524 }
525 case LayerType::Pooling3d:
526 return IsPooling3dSupported(infos[0],
527 infos[1],
528 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
529 reasonIfUnsupported);
530 case LayerType::Map:
531 return true;
532 case LayerType::Unmap:
533 return true;
534 case LayerType::MemImport:
535 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
536 case LayerType::Merge:
537 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
538 case LayerType::QuantizedLstm:
539 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
540 infos[1],
541 infos[2],
542 infos[3],
543 infos[4],
544 quantizedLstmInputParamsInfo.value(),
545 reasonIfUnsupported);
546 default:
Teresa Charlin9145e382023-08-17 18:44:58 +0100547 // layers not supported in reference by default:
548 // precompiled, standin, switch, fused
Cathal Corbett34b429c2021-12-24 12:24:40 +0000549 return false;
550 }
551}
552
arovir011c7c81b2018-10-08 11:34:28 +0100553bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
554 const TensorInfo& output,
555 const ActivationDescriptor& descriptor,
556 Optional<std::string&> reasonIfUnsupported) const
557{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000558 bool supported = true;
559
560 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000561 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000562 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100563 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000564 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000565 DataType::QAsymmU8,
566 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000567 };
568
569 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
570 "Reference activation: input type not supported.");
571
572 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
573 "Reference activation: output type not supported.");
574
575 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
576 "Reference activation: input and output types mismatched.");
577
578 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
579 "Reference activation: input and output shapes are of different rank.");
580
581
582 struct ActivationFunctionSupported : public Rule
583 {
584 ActivationFunctionSupported(const ActivationDescriptor& desc)
585 {
586 switch(desc.m_Function)
587 {
588 case ActivationFunction::Abs:
589 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000590 case ActivationFunction::Elu:
Teresa Charlin077cddb2023-09-15 15:19:21 +0100591 case ActivationFunction::Gelu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000592 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000593 case ActivationFunction::LeakyReLu:
594 case ActivationFunction::Linear:
595 case ActivationFunction::ReLu:
596 case ActivationFunction::Sigmoid:
597 case ActivationFunction::SoftReLu:
598 case ActivationFunction::Sqrt:
599 case ActivationFunction::Square:
600 case ActivationFunction::TanH:
601 {
602 m_Res = true;
603 break;
604 }
605 default:
606 {
607 m_Res = false;
608 break;
609 }
610 }
611 }
612 };
613
614 // Function is supported
615 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
616 "Reference activation: function not supported.");
617
618 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100619}
620
621bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
622 const TensorInfo& input1,
623 const TensorInfo& output,
624 Optional<std::string&> reasonIfUnsupported) const
625{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000626 bool supported = true;
627
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100628 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000629 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100630 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000631 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000632 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100633 DataType::QSymmS16,
634 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000635 };
636
637 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
638 "Reference addition: input 0 is not a supported type.");
639
640 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
641 "Reference addition: input 1 is not a supported type.");
642
643 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
644 "Reference addition: output is not a supported type.");
645
646 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
647 "Reference addition: input 0 and Input 1 types are mismatched");
648
649 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
650 "Reference addition: input and output types are mismatched");
651
652 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
653 "Reference addition: shapes are not suitable for implicit broadcast.");
654
655 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100656}
657
Nikhil Raj68c2c902019-09-19 11:21:11 +0100658bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
659 const armnn::ArgMinMaxDescriptor &descriptor,
660 armnn::Optional<std::string &> reasonIfUnsupported) const
661{
Jan Eilers8eb25602020-03-09 12:13:48 +0000662 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100663
Mike Kelly1f140f72021-04-06 12:25:55 +0100664 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100665 {
Teresa Charline300b362020-05-25 10:01:03 +0100666 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100667 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100668 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000669 DataType::QAsymmU8,
670 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100671 DataType::Signed32,
672 DataType::Signed64
673 };
674
675 std::array<DataType,2> supportedOutputTypes = {
676 DataType::Signed32,
677 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100678 };
679
680 bool supported = true;
681
Mike Kelly1f140f72021-04-06 12:25:55 +0100682 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100683 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100684 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100685 "Reference ArgMinMax: output type not supported");
686
687 return supported;
688}
689
Samuel Yap6b478092022-07-06 15:36:03 +0100690bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
691 const TensorInfo& inputY,
692 const TensorInfo& output,
693 const BatchMatMulDescriptor& descriptor,
694 Optional<std::string &> reasonIfUnsupported) const
695{
696 IgnoreUnused(descriptor);
697
698 std::array<DataType, 6> supportedTypes =
699 {
Samuel Yap6b478092022-07-06 15:36:03 +0100700 DataType::Float16,
701 DataType::Float32,
702 DataType::QAsymmS8,
703 DataType::QAsymmU8,
704 DataType::QSymmS16
705 };
706
707 bool supported = true;
708
709 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
710 "Reference batch matrix multiplication: input X is not a supported type");
711
712 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
713 "Reference batch matrix multiplication: input Y is not a supported type");
714
715 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
716 "Reference batch matrix multiplication: output is not a supported type");
717
718 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
719 "Reference batch matrix multiplication: input X and input Y types are mismatched");
720
721 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
722 "Reference batch matrix multiplication: inputs and output types are mismatched");
723
724 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
725 reasonIfUnsupported,
726 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
727
728 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
729 reasonIfUnsupported,
730 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
731
732 return supported;
733}
734
arovir011c7c81b2018-10-08 11:34:28 +0100735bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
736 const TensorInfo& output,
737 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100738 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100739 const TensorInfo& beta,
740 const TensorInfo& gamma,
741 const BatchNormalizationDescriptor& descriptor,
742 Optional<std::string&> reasonIfUnsupported) const
743{
Jan Eilers8eb25602020-03-09 12:13:48 +0000744 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100745
Sadik Armagan303980c2020-04-17 12:45:14 +0100746 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100747 {
748 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100749 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100750 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000751 DataType::QAsymmU8,
752 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100753 };
754
755 bool supported = true;
756
757 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
758 "Reference batch normalization: input is not a supported type.");
759
760 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
761 "Reference batch normalization: output is not a supported type.");
762
763 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
764 "Reference batch normalization: input and output types are mismatched");
765
766 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
767 "Reference batch normalization: mean is not a supported type.");
768
769 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
770 "Reference batch normalization: variance is not a supported type.");
771
772 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
773 "Reference batch normalization: beta is not a supported type.");
774
775 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
776 "Reference batch normalization: gamma is not a supported type.");
777
778 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100779}
780
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000781bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
782 const TensorInfo& output,
783 const BatchToSpaceNdDescriptor& descriptor,
784 Optional<std::string&> reasonIfUnsupported) const
785{
Jan Eilers8eb25602020-03-09 12:13:48 +0000786 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100787
788 bool supported = true;
789
790 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
791 std::string inputTensorStr = "input";
792 std::string outputTensorStr = "output";
793
794 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100795 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100796 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000797 DataType::Float32,
798 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100799 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000800 DataType::QAsymmU8,
801 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100802 };
803
804 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
805 "Reference BatchToSpaceNd: input type not supported.");
806
807 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
808 "Reference BatchToSpaceNd: output type not supported.");
809
810 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
811 "Reference BatchToSpaceNd: input and output types mismatched.");
812
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100813 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000814}
815
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100816bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
817 const TensorInfo& output,
818 const BroadcastToDescriptor& descriptor,
819 Optional<std::string&> reasonIfUnsupported) const
820{
821 IgnoreUnused(descriptor);
822
823 bool supported = true;
824
825 std::array<DataType, 8> supportedTypes
826 {
827 DataType::Float32,
828 DataType::Float16,
829 DataType::QAsymmS8,
830 DataType::QAsymmU8,
831 DataType::QSymmS8,
832 DataType::QSymmS16,
833 DataType::Signed32,
834 DataType::Signed64
835 };
836
837 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
838 "BroadcastTo: input type not supported.");
839
840 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
841 "BroadcastTo: output type not supported");
842
843 return supported;
844}
845
mathad01b392e982021-04-07 12:07:30 +0100846bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
847 const TensorInfo& output,
848 Optional<std::string&> reasonIfUnsupported) const
849{
Teresa Charlin5306dc82023-10-30 22:29:58 +0000850 std::array<DataType, 10> supportedInputTypes =
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100851 {
852 DataType::Float32,
853 DataType::Float16,
854 DataType::QSymmS8,
855 DataType::QAsymmS8,
856 DataType::QAsymmU8,
857 DataType::QSymmS16,
Teresa Charlin5306dc82023-10-30 22:29:58 +0000858 DataType::Signed32,
859 DataType::Signed64
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100860 };
mathad01b392e982021-04-07 12:07:30 +0100861
862 bool supported = true;
863 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
864 "Reference cast: input is not a supported type");
865
866
867 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
868 "Reference cast: output is not a supported type");
869
870 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
871 "Reference cast: input and output shapes have different number of total elements");
872
873 return supported;
874}
875
Simon Obute51f67772021-09-03 15:50:13 +0100876bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
877 const TensorInfo& output,
878 const ChannelShuffleDescriptor& descriptor,
879 Optional<std::string&> reasonIfUnsupported) const
880{
881 IgnoreUnused(descriptor);
882 bool supported = true;
883
884 // Define supported output and inputs types.
885 std::array<DataType, 7> supportedTypes =
886 {
Simon Obute51f67772021-09-03 15:50:13 +0100887 DataType::Float32,
888 DataType::Float16,
889 DataType::QAsymmS8,
890 DataType::QAsymmU8,
891 DataType::QSymmS8,
892 DataType::QSymmS16
893 };
894
895 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
896 "Reference ChannelShuffle: input is not a supported type.");
897
898 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
899 "Reference ChannelShuffle: output is not a supported type.");
900
901 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
902 "Reference ChannelShuffle: input and output types are mismatched.");
903
904 return supported;
905}
906
907
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100908bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
909 const TensorInfo& input1,
910 const TensorInfo& output,
911 const ComparisonDescriptor& descriptor,
912 Optional<std::string&> reasonIfUnsupported) const
913{
Jan Eilers8eb25602020-03-09 12:13:48 +0000914 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100915 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100916 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000917 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100918 DataType::Float32,
919 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100920 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000921 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000922 DataType::QSymmS16,
923 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100924 };
925
926 bool supported = true;
927 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
928 "Reference comparison: input 0 is not a supported type");
929
930 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
931 "Reference comparison: input 0 and Input 1 types are mismatched");
932
933 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
934 "Reference comparison: output is not of type Boolean");
935
936 return supported;
937}
938
Jim Flynn906f9462019-05-10 13:55:21 +0100939bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
940 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000941 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100942 Optional<std::string&> reasonIfUnsupported) const
943{
Jan Eilers8eb25602020-03-09 12:13:48 +0000944 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100945
946 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000947 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100948 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000949 DataType::Float32,
950 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000951 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100952 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000953 DataType::QSymmS16,
954 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100955 };
956
957 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
958 "Reference concatenation: output type not supported");
959 for (const TensorInfo* input : inputs)
960 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100961 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100962 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
963 "Reference concatenation: input type not supported");
964
965 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
966 "Reference concatenation: input and output types mismatched.");
967 }
968
969 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100970}
971
arovir011c7c81b2018-10-08 11:34:28 +0100972bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
973 Optional<std::string&> reasonIfUnsupported) const
974{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100975 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100976 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100977 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100978 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000979 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100980 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000981 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100982 DataType::QSymmS16,
983 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100984 };
985
986 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
987 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100988}
989
990bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
991 const TensorInfo& output,
992 Optional<std::string&> reasonIfUnsupported) const
993{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100994 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
995 input.GetDataType(),
996 &TrueFunc<>,
997 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000998 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000999 &FalseFuncI32<>,
1000 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001001 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1002 output.GetDataType(),
1003 &FalseOutputFuncF16<>,
1004 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001005 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001006 &FalseFuncI32<>,
1007 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001008}
1009
1010bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
1011 const TensorInfo& output,
1012 Optional<std::string&> reasonIfUnsupported) const
1013{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001014 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1015 input.GetDataType(),
1016 &FalseInputFuncF16<>,
1017 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001018 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001019 &FalseFuncI32<>,
1020 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001021 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1022 output.GetDataType(),
1023 &TrueFunc<>,
1024 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001025 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001026 &FalseFuncI32<>,
1027 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001028}
1029
1030bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1031 const TensorInfo& output,
1032 const Convolution2dDescriptor& descriptor,
1033 const TensorInfo& weights,
1034 const Optional<TensorInfo>& biases,
1035 Optional<std::string&> reasonIfUnsupported) const
1036{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001037 bool supported = true;
1038
1039 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001040 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001041 {
1042 DataType::Float32,
1043 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001044 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001045 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001046 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001047 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001048 };
1049
1050 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001051 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001052
1053 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001054 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001055
Ryan OShea31441592022-11-07 16:20:48 +00001056 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1057 "Reference Convolution2d: input and output types mismatched.");
1058
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001059
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001060 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001061 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001062 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001063 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001064 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001065 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001066 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001067 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001068 };
1069
1070 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001071 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001072 }
1073 else
1074 {
1075 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001076 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001077
1078 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001079 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001080 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001081
1082 if (biases.has_value())
1083 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001084 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001085 {
1086 DataType::Float32,
1087 DataType::Float16,
1088 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001089 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001090
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001091 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001092 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001093 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001094 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001095
1096 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001097}
1098
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001099bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1100 const TensorInfo& output,
1101 const Convolution3dDescriptor& descriptor,
1102 const TensorInfo& weights,
1103 const Optional<TensorInfo>& biases,
1104 Optional<std::string&> reasonIfUnsupported) const
1105{
1106 bool supported = true;
1107
1108 // Define supported types.
1109 std::array<DataType,7> supportedTypes =
1110 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001111 DataType::Float32,
1112 DataType::Float16,
1113 DataType::QAsymmS8,
1114 DataType::QAsymmU8,
1115 DataType::QSymmS8,
1116 DataType::QSymmS16
1117 };
1118
1119 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1120 "Reference Convolution3d: input is not a supported type.");
1121
1122 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1123 "Reference Convolution3d: output is not a supported type.");
1124
1125 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1126 "Reference Convolution3d: input and output types mismatched.");
1127
1128 const DataType inputType = input.GetDataType();
1129 if (IsQuantized8BitType(inputType))
1130 {
1131 std::array<DataType, 3> supportedWeightTypes =
1132 {
1133 DataType::QAsymmS8,
1134 DataType::QAsymmU8,
1135 DataType::QSymmS8
1136 };
1137
1138 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1139 "Reference Convolution3d: weights type not supported for quantized input.");
1140 }
1141 else
1142 {
1143 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1144 "Reference Convolution3d: weights is not a supported type.");
1145
1146 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1147 "Reference Convolution3d: input and weights types mismatched.");
1148 }
1149
1150 if (biases.has_value())
1151 {
1152 std::array<DataType,4> biasesSupportedTypes =
1153 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001154 DataType::Float32,
1155 DataType::Float16,
1156 DataType::Signed32
1157 };
1158
1159 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1160 "Reference Convolution3d: biases is not a supported type.");
1161 }
1162 IgnoreUnused(descriptor);
1163
1164 return supported;
1165}
1166
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001167bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1168 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001169 Optional<std::string&> reasonIfUnsupported) const
1170{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001171 bool supported = true;
1172
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001173 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001174 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001175 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001176 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001177 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001178 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001179 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001180 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001181 DataType::QSymmS16,
1182 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001183 };
1184
1185 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001186 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001187
1188 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001189 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001190
1191 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001192 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001193
1194 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001195}
1196
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001197bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1198 const TensorInfo& output,
1199 const DepthToSpaceDescriptor& descriptor,
1200 Optional<std::string&> reasonIfUnsupported) const
1201{
Jan Eilers8eb25602020-03-09 12:13:48 +00001202 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001203 bool supported = true;
1204
Sadik Armagan303980c2020-04-17 12:45:14 +01001205 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001206 {
1207 DataType::Float32,
1208 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001209 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001210 DataType::QAsymmU8,
1211 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001212 };
1213
1214 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1215 "Reference DepthToSpace: input type not supported");
1216
1217 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1218 "Reference DepthToSpace: output type not supported");
1219
1220 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1221 "Reference DepthToSpace: input and output types are mismatched");
1222
1223 return supported;
1224}
1225
arovir011c7c81b2018-10-08 11:34:28 +01001226bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1227 const TensorInfo& output,
1228 const DepthwiseConvolution2dDescriptor& descriptor,
1229 const TensorInfo& weights,
1230 const Optional<TensorInfo>& biases,
1231 Optional<std::string&> reasonIfUnsupported) const
1232{
Sadik Armagan303980c2020-04-17 12:45:14 +01001233 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001234 bool supported = true;
1235
1236 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001237 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001238 {
1239 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001240 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001241 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001242 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001243 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001244 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001245 };
1246
1247 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1248 "Reference DepthwiseConvolution2d: input is not a supported type.");
1249
1250 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1251 "Reference DepthwiseConvolution2d: output is not a supported type.");
1252
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001253 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1254 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1255
Teresa Charlind8df0262019-11-11 12:28:15 +00001256 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001257 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001258 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001259 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001260 {
1261 DataType::QAsymmS8,
1262 DataType::QAsymmU8,
1263 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001264 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001265
1266 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001267 "Reference DepthwiseConvolution2d: weights type not supported for "
1268 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001269 }
1270 else
1271 {
1272 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1273 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1274
1275 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1276 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1277 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001278
1279 if (biases.has_value())
1280 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001281 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001282 {
1283 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001284 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001285 DataType::Signed32
1286 };
1287 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1288 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1289 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001290
1291 return supported;
1292
arovir011c7c81b2018-10-08 11:34:28 +01001293}
1294
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001295bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1296 const TensorInfo& output,
1297 Optional<std::string&> reasonIfUnsupported) const
1298{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001299 bool supported = true;
1300
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001301 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001302 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001303 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001304 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001305 DataType::QSymmS16,
1306 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001307 };
1308
1309 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001310 "Reference for Dequantize layer: input type not supported.");
1311
Derek Lambertid466a542020-01-22 15:37:29 +00001312 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001313 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001314
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001315 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001316 DataType::Float32,
1317 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001318 };
1319
1320 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001321 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001322
1323 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001324 "Reference for Dequantize layer: input/output shapes have different num total "
1325 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001326
1327 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001328}
1329
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001330bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1331 const TensorInfo& scores,
1332 const TensorInfo& anchors,
1333 const TensorInfo& detectionBoxes,
1334 const TensorInfo& detectionClasses,
1335 const TensorInfo& detectionScores,
1336 const TensorInfo& numDetections,
1337 const DetectionPostProcessDescriptor& descriptor,
1338 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001339{
Jan Eilers8eb25602020-03-09 12:13:48 +00001340 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001341
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001342 bool supported = true;
1343
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001344 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001345 {
1346 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001347 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001348 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001349 DataType::QAsymmU8,
1350 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001351 };
1352
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001353 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001354 "Reference DetectionPostProcess: input 0 is not a supported type.");
1355
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001356 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001357 "Reference DetectionPostProcess: input 1 is not a supported type.");
1358
1359 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001360}
1361
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // The dilated case is delegated to the regular depthwise check: the support
    // criteria are identical (dilation comes in via the same descriptor type).
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1371
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001372bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001373 const TensorInfo& input1,
1374 const TensorInfo& output,
1375 Optional<std::string&> reasonIfUnsupported) const
1376{
Sadik Armagan2999a022019-04-09 14:20:12 +01001377 bool supported = true;
1378
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001379 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001380 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001381 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001382 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001383 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001384 DataType::QSymmS16,
1385 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001386 };
1387
1388 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1389 "Reference division: input 0 is not a supported type.");
1390
1391 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1392 "Reference division: input 1 is not a supported type.");
1393
1394 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1395 "Reference division: output is not a supported type.");
1396
1397 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1398 "Reference division: input 0 and Input 1 types are mismatched");
1399
1400 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1401 "Reference division: input and output types are mismatched");
1402
1403 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1404 "Reference division: shapes are not suitable for implicit broadcast.");
1405
1406 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001407}
1408
josh minor4a3c6102020-01-06 16:40:46 -06001409bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1410 const TensorInfo& output,
1411 const ElementwiseUnaryDescriptor& descriptor,
1412 Optional<std::string&> reasonIfUnsupported) const
1413{
Jan Eilers8eb25602020-03-09 12:13:48 +00001414 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001415
Sadik Armagan303980c2020-04-17 12:45:14 +01001416 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001417 {
1418 DataType::Float32,
1419 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001420 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001421 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001422 DataType::QSymmS16,
1423 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001424 };
1425
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001426 std::array<DataType, 1> logicalSupportedTypes =
1427 {
1428 DataType::Boolean
1429 };
1430
josh minor4a3c6102020-01-06 16:40:46 -06001431 bool supported = true;
1432
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001433 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1434 {
1435 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1436 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001437
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001438 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1439 "Reference elementwise unary: output type not supported");
1440 }
1441 else
1442 {
1443 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1444 "Reference elementwise unary: input type not supported");
1445
1446 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1447 "Reference elementwise unary: output type not supported");
1448 }
josh minor4a3c6102020-01-06 16:40:46 -06001449
1450 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1451 "Reference elementwise unary: input and output types not matching");
1452
1453 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1454 "Reference elementwise unary: input and output shapes"
1455 "have different number of total elements");
1456
1457 return supported;
1458}
1459
arovir011c7c81b2018-10-08 11:34:28 +01001460bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1461 const FakeQuantizationDescriptor& descriptor,
1462 Optional<std::string&> reasonIfUnsupported) const
1463{
Jan Eilers8eb25602020-03-09 12:13:48 +00001464 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001465 bool supported = true;
1466
1467 std::array<DataType,1> supportedTypes =
1468 {
1469 DataType::Float32
1470 };
1471
1472 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1473 "Reference fake quantization: input type not supported.");
1474
1475 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001476}
1477
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001478bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1479 const TensorInfo& output,
1480 const FillDescriptor& descriptor,
1481 Optional<std::string&> reasonIfUnsupported) const
1482{
1483 IgnoreUnused(descriptor);
1484 IgnoreUnused(output);
1485
1486 bool supported = true;
1487
Sadik Armagana792a052020-06-23 16:22:23 +01001488 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001489 {
1490 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001491 DataType::Float16,
1492 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001493 };
1494
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001495 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001496 "Reference Fill: input type not supported.");
1497
Teresa Charlin44088502020-07-27 11:27:19 +01001498 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1499 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001500 return supported;
1501}
1502
arovir011c7c81b2018-10-08 11:34:28 +01001503bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1504 const TensorInfo& output,
1505 Optional<std::string&> reasonIfUnsupported) const
1506{
Jan Eilers8eb25602020-03-09 12:13:48 +00001507 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001508 bool supported = true;
1509
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001510 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001511 {
James Conroyb40d7102019-06-04 12:32:09 +01001512 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001513 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001514 };
1515
1516 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1517 "Reference Floor: input type not supported.");
1518
1519 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1520 "Reference Floor: output type not supported.");
1521
1522 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001523}
1524
1525bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1526 const TensorInfo& output,
1527 const TensorInfo& weights,
1528 const TensorInfo& biases,
1529 const FullyConnectedDescriptor& descriptor,
1530 Optional<std::string&> reasonIfUnsupported) const
1531{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001532 bool supported = true;
1533
1534 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001536 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001537 DataType::Float32,
1538 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001539 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001540 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001541 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001542 };
1543
1544 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1545 "Reference Fully Connected: input type not supported.");
1546
1547 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1548 "Reference Fully Connected: output type not supported.");
1549
Francis Murtagh46c09d02019-05-28 08:15:28 +01001550 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1551 "Reference Fully Connected: weights type not supported.");
1552
Ryan OShea31441592022-11-07 16:20:48 +00001553 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1554 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001555
Jan Eilers1f45dc32020-06-15 11:43:03 +01001556 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1557 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001558
Jan Eilers1f45dc32020-06-15 11:43:03 +01001559 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1560 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001561
1562 if (descriptor.m_BiasEnabled)
1563 {
1564 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001565 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001566 supportedBiasTypes =
1567 {
1568 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001569 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001570 DataType::Signed32,
1571 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001572 };
1573
1574 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1575 "Reference Fully Connected: bias type not supported.");
1576
1577 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1578 "Reference Fully Connected: bias and weight types mismatch.");
1579
1580 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1581 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1582
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001583 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1584 "Reference Fully Connected: bias must have 1 dimension.");
1585
Francis Murtagh46c09d02019-05-28 08:15:28 +01001586 }
1587
1588 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001589}
1590
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001591bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1592 const armnn::TensorInfo& input1,
1593 const armnn::TensorInfo& output,
1594 armnn::Optional<std::string&> reasonIfUnsupported) const
1595{
1596 bool supported = true;
1597 std::array<DataType,7> supportedTypes =
1598 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001599 DataType::Float32,
1600 DataType::Float16,
1601 DataType::QAsymmS8,
1602 DataType::QAsymmU8,
1603 DataType::QSymmS16,
1604 DataType::Signed32
1605 };
1606
1607 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1608 "Reference GatherNd: input type not supported");
1609
1610 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1611 "Reference GatherNd: output type not supported");
1612
1613 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1614 "Reference GatherNd: indices (input1) type not supported");
1615
1616 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1617 "Reference GatherNd: input and output types not matching");
1618
1619 return supported;
1620}
1621
narpra014951d842019-01-18 16:53:53 +00001622bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1623 const armnn::TensorInfo& input1,
1624 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001625 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001626 armnn::Optional<std::string&> reasonIfUnsupported) const
1627{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001628 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001629 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001630 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001631 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001632 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001633 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001634 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001635 DataType::QSymmS16,
1636 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001637 };
1638
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001639 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001640 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1641 "Reference Gather: input type not supported");
1642
1643 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1644 "Reference Gather: output type not supported");
1645
1646 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1647 "Reference Gather: indices (input1) type not supported");
1648
1649 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1650 "Reference Gather: input and output types not matching");
1651
1652 return supported;
narpra014951d842019-01-18 16:53:53 +00001653}
1654
Derek Lamberti901ea112019-12-10 22:07:09 +00001655bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1656 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001657{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001658 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001659}
1660
Kevin May09ca49c2019-10-09 12:37:34 +01001661bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1662 const TensorInfo& output,
1663 const InstanceNormalizationDescriptor& descriptor,
1664 Optional<std::string&> reasonIfUnsupported) const
1665{
Jan Eilers8eb25602020-03-09 12:13:48 +00001666 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001667 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001668 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001669 {
1670 DataType::Float32,
1671 DataType::Float16
1672 };
1673
1674 bool supported = true;
1675
1676 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1677 "Reference Instance Normalization: input type not supported.");
1678
1679 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1680 "Reference Instance Normalization: output type not supported.");
1681
1682 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1683 "Reference Instance Normalization: input and output types mismatched.");
1684
1685 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1686 "Reference Instance Normalization: input and output shapes have different "
1687 "num total elements.");
1688
1689 return supported;
1690}
1691
arovir011c7c81b2018-10-08 11:34:28 +01001692bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1693 const TensorInfo& output,
1694 const L2NormalizationDescriptor& descriptor,
1695 Optional<std::string&> reasonIfUnsupported) const
1696{
Jan Eilers8eb25602020-03-09 12:13:48 +00001697 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001698 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001699 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001700 {
1701 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001702 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001703 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001704 DataType::QAsymmU8,
1705 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001706 };
1707
1708 bool supported = true;
1709
1710 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1711 "Reference L2normalization: input type not supported.");
1712
1713 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1714 "Reference L2normalization: output type not supported.");
1715
1716 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1717 "Reference L2normalization: input and output types mismatched.");
1718
1719 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1720 "Reference L2normalization: input and output shapes have different "
1721 "num total elements.");
1722
1723 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001724}
1725
James Conroyaba90cd2020-11-06 16:28:18 +00001726bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1727 const TensorInfo& input1,
1728 const TensorInfo& output,
1729 const LogicalBinaryDescriptor& descriptor,
1730 Optional<std::string&> reasonIfUnsupported) const
1731{
1732 IgnoreUnused(descriptor);
1733
1734 std::array<DataType, 1> supportedTypes =
1735 {
1736 DataType::Boolean
1737 };
1738
1739 bool supported = true;
1740 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1741 "Reference LogicalBinary: input 0 type not supported");
1742 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1743 "Reference LogicalBinary: input 1 type not supported");
1744
1745 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1746 "Reference LogicalBinary: input and output types do not match");
1747
1748 return supported;
1749}
1750
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001751bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1752 const TensorInfo& output,
1753 const LogSoftmaxDescriptor& descriptor,
1754 Optional<std::string&> reasonIfUnsupported) const
1755{
Jan Eilers8eb25602020-03-09 12:13:48 +00001756 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001757
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001758 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001759 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001760 DataType::Float32,
1761 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001762 };
1763
1764 bool supported = true;
1765 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1766 "Reference LogSoftmax: input type not supported");
1767
1768 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1769 "Reference LogSoftmax: output type not supported");
1770
1771 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1772 "Reference LogSoftmax: input and output types do not match");
1773
1774 return supported;
1775}
1776
arovir011c7c81b2018-10-08 11:34:28 +01001777bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1778 const TensorInfo& outputStateIn,
1779 const TensorInfo& cellStateIn,
1780 const TensorInfo& scratchBuffer,
1781 const TensorInfo& outputStateOut,
1782 const TensorInfo& cellStateOut,
1783 const TensorInfo& output,
1784 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001785 const LstmInputParamsInfo& paramsInfo,
1786 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001787{
Jan Eilers8eb25602020-03-09 12:13:48 +00001788 IgnoreUnused(descriptor);
1789 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001790
1791 bool supported = true;
1792
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001793 std::array<DataType,3> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +01001794 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001795 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001796 };
1797
Jan Eilersd01a83c2019-07-03 18:20:40 +01001798 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001799 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1800 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001801 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1802 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001803 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1804 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001805 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1806 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001807 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1808 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001809 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1810 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001811
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001812 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1813 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001814 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001815 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001816 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001817 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001818 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001819 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001820 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001821 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001822 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001823 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001824 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001825 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001826 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001827 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001828 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001829 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001830 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001831 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001832 "Reference Lstm: input and OutputGateBias types are mismatched");
1833 if (!descriptor.m_CifgEnabled)
1834 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001835 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001836 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001837 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001838 reasonIfUnsupported,
1839 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001840 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001841 "Reference Lstm: input and InputGateBias types are mismatched");
1842 if (descriptor.m_PeepholeEnabled)
1843 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001844 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001845 reasonIfUnsupported,
1846 "Reference Lstm: input and CellToInputWeights types are mismatched");
1847 }
1848 }
1849 if (descriptor.m_PeepholeEnabled)
1850 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001851 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001852 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001853 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001854 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1855 }
1856 if (descriptor.m_ProjectionEnabled)
1857 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001858 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001859 "Reference Lstm: input and mProjectionWeights types are mismatched");
1860 if (paramsInfo.m_ProjectionBias != nullptr)
1861 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001862 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001863 "Reference Lstm: input and ProjectionBias types are mismatched");
1864 }
1865 }
1866 if (descriptor.m_LayerNormEnabled)
1867 {
1868 if (!descriptor.m_CifgEnabled)
1869 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001870 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001871 reasonIfUnsupported,
1872 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1873 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001874 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001875 reasonIfUnsupported,
1876 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001877 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001878 reasonIfUnsupported,
1879 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001880 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001881 reasonIfUnsupported,
1882 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1883 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001884
1885 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001886}
1887
saoste012df12b32018-11-28 16:57:20 +00001888bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1889 const TensorInfo& input1,
1890 const TensorInfo& output,
1891 Optional<std::string&> reasonIfUnsupported) const
1892{
Sadik Armagan2999a022019-04-09 14:20:12 +01001893 bool supported = true;
1894
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001895 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001896 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001897 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001898 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001899 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001900 DataType::QSymmS16,
1901 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001902 };
1903
1904 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1905 "Reference maximum: input 0 is not a supported type.");
1906
1907 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1908 "Reference maximum: input 1 is not a supported type.");
1909
1910 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1911 "Reference maximum: output is not a supported type.");
1912
1913 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1914 "Reference maximum: input 0 and Input 1 types are mismatched");
1915
1916 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1917 "Reference maximum: input and output types are mismatched");
1918
1919 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1920 "Reference maximum: shapes are not suitable for implicit broadcast.");
1921
1922 return supported;
saoste012df12b32018-11-28 16:57:20 +00001923}
1924
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001925bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1926 const TensorInfo& output,
1927 const MeanDescriptor& descriptor,
1928 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001929{
James Conroy4d1ff582019-06-10 17:06:39 +01001930 bool supported = true;
1931 std::string meanLayerStr = "Mean";
1932 std::string outputTensorStr = "output";
1933
Sadik Armagan303980c2020-04-17 12:45:14 +01001934 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001935 {
1936 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001937 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001938 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001939 DataType::QAsymmU8,
1940 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001941 };
1942
1943 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1944 "Reference Mean: input type not supported.");
1945
1946 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1947 "Reference Mean: input and output types are mismatched");
1948
1949 if (descriptor.m_KeepDims)
1950 {
1951 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1952 reasonIfUnsupported,
1953 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1954 output.GetNumDimensions(),
1955 meanLayerStr, outputTensorStr).data());
1956 }
1957 else if (descriptor.m_Axis.empty())
1958 {
1959 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1960 reasonIfUnsupported,
1961 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1962 meanLayerStr, outputTensorStr).data());
1963 }
1964 else
1965 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001966 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001967
1968 if (outputDim > 0)
1969 {
1970 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1971 reasonIfUnsupported,
1972 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1973 meanLayerStr, outputTensorStr).data());
1974 }
1975 else
1976 {
1977 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1978 reasonIfUnsupported,
1979 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1980 meanLayerStr, outputTensorStr).data());
1981 }
1982 }
1983
1984 return supported;
narpra0132b90462018-09-13 11:07:48 +01001985}
1986
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001987bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1988 const TensorInfo &output,
1989 Optional<std::string &> reasonIfUnsupported) const
1990{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001991 bool supported = true;
1992
Sadik Armagan303980c2020-04-17 12:45:14 +01001993 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001994 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001995 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001996 DataType::Float32,
1997 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001998 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001999 DataType::QAsymmU8,
2000 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002001 DataType::Boolean
2002 };
2003
2004 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2005 "Reference MemCopy: input type not supported");
2006
2007 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2008 "Reference MemCopy: output type not supported");
2009
2010 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2011 "Reference MemCopy: input and output types are mismatched");
2012
2013 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002014}
2015
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002016bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2017 const TensorInfo& input1,
2018 const TensorInfo& output,
2019 Optional<std::string&> reasonIfUnsupported) const
2020{
Sadik Armagan2999a022019-04-09 14:20:12 +01002021 bool supported = true;
2022
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002023 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002024 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002025 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002026 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002027 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002028 DataType::QSymmS16,
2029 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002030 };
2031
2032 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2033 "Reference minimum: input 0 is not a supported type.");
2034
2035 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2036 "Reference minimum: input 1 is not a supported type.");
2037
2038 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2039 "Reference minimum: output is not a supported type.");
2040
2041 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2042 "Reference minimum: input 0 and Input 1 types are mismatched");
2043
2044 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2045 "Reference minimum: input and output types are mismatched");
2046
2047 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2048 "Reference minimum: shapes are not suitable for implicit broadcast.");
2049
2050 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002051}
2052
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002053bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2054 const TensorInfo& input1,
2055 const TensorInfo& output,
2056 Optional<std::string&> reasonIfUnsupported) const
2057{
Sadik Armagan2999a022019-04-09 14:20:12 +01002058 bool supported = true;
2059
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002060 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002061 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002062 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002063 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002064 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002065 DataType::QSymmS16,
2066 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002067 };
2068
2069 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2070 "Reference multiplication: input 0 is not a supported type.");
2071
2072 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2073 "Reference multiplication: input 1 is not a supported type.");
2074
2075 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2076 "Reference multiplication: output is not a supported type.");
2077
2078 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2079 "Reference multiplication: input 0 and Input 1 types are mismatched");
2080
2081 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2082 "Reference multiplication: input and output types are mismatched");
2083
2084 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2085 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2086
2087 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002088}
2089
2090bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2091 const TensorInfo& output,
2092 const NormalizationDescriptor& descriptor,
2093 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002094{
Jan Eilers8eb25602020-03-09 12:13:48 +00002095 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002096
2097 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002098 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002099 {
2100 DataType::Float16,
2101 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002102 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002103 DataType::QAsymmU8,
2104 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002105 };
2106
2107 bool supported = true;
2108
2109 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2110 "Reference normalization: input type not supported.");
2111
2112 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2113 "Reference normalization: output type not supported.");
2114
2115 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2116 "Reference normalization: input and output shapes have different "
2117 "num total elements.");
2118
2119 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002120}
2121
Derek Lamberti901ea112019-12-10 22:07:09 +00002122bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
2123 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002124{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01002125 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002126}
2127
2128bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2129 const TensorInfo& output,
2130 const PadDescriptor& descriptor,
2131 Optional<std::string&> reasonIfUnsupported) const
2132{
Jan Eilers8eb25602020-03-09 12:13:48 +00002133 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002134 bool supported = true;
2135
2136 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002137 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002138 {
2139 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002140 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002141 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002142 DataType::QAsymmU8,
2143 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002144 };
2145
2146 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2147 "Reference pad: input is not a supported type.");
2148
2149 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2150 "Reference pad: output is not a supported type.");
2151
2152 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2153 "Reference pad: input and output types are mismatched.");
2154
2155 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002156}
2157
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002158bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2159 const TensorInfo& output,
2160 const PermuteDescriptor& descriptor,
2161 Optional<std::string&> reasonIfUnsupported) const
2162{
Jan Eilers8eb25602020-03-09 12:13:48 +00002163 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002164 bool supported = true;
2165
2166 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002167 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002168 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002169 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002170 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002171 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002172 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002173 DataType::QAsymmU8,
2174 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002175 };
2176
2177 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2178 "Reference permute: input is not a supported type.");
2179
2180 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2181 "Reference permute: output is not a supported type.");
2182
2183 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2184 "Reference permute: input and output types are mismatched.");
2185
2186 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002187}
2188
2189bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2190 const TensorInfo& output,
2191 const Pooling2dDescriptor& descriptor,
2192 Optional<std::string&> reasonIfUnsupported) const
2193{
Jan Eilers8eb25602020-03-09 12:13:48 +00002194 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002195 bool supported = true;
2196
2197 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002198 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002199 {
2200 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002201 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002202 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002203 DataType::QAsymmU8,
2204 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002205 };
2206
2207 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2208 "Reference poolind2d: input is not a supported type.");
2209
2210 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2211 "Reference poolind2d: output is not a supported type.");
2212
2213 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2214 "Reference poolind2d: input and output types are mismatched.");
2215
2216 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002217}
2218
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002219bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2220 const TensorInfo& output,
2221 const Pooling3dDescriptor& descriptor,
2222 Optional<std::string&> reasonIfUnsupported) const
2223{
2224 IgnoreUnused(descriptor);
2225 bool supported = true;
2226
2227 // Define supported output and inputs types.
2228 std::array<DataType,6> supportedTypes =
2229 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002230 DataType::Float32,
2231 DataType::Float16,
2232 DataType::QAsymmS8,
2233 DataType::QAsymmU8,
2234 DataType::QSymmS16
2235 };
2236
2237 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2238 "Reference poolind3d: input is not a supported type.");
2239
2240 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2241 "Reference poolind3d: output is not a supported type.");
2242
2243 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2244 "Reference poolind3d: input and output types are mismatched.");
2245
2246 return supported;
2247}
2248
2249
James Conroy4f1f8992020-04-29 20:01:10 +01002250bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2251 const TensorInfo& previousOutputIn,
2252 const TensorInfo& previousCellStateIn,
2253 const TensorInfo& outputStateOut,
2254 const TensorInfo& cellStateOut,
2255 const TensorInfo& output,
2256 const QLstmDescriptor& descriptor,
2257 const LstmInputParamsInfo& paramsInfo,
2258 Optional<std::string&> reasonIfUnsupported) const
2259{
2260 IgnoreUnused(input);
2261 IgnoreUnused(previousOutputIn);
2262 IgnoreUnused(previousCellStateIn);
2263 IgnoreUnused(outputStateOut);
2264 IgnoreUnused(cellStateOut);
2265 IgnoreUnused(output);
2266 IgnoreUnused(descriptor);
2267 IgnoreUnused(paramsInfo);
2268
2269 IgnoreUnused(reasonIfUnsupported);
2270
2271 return true;
2272}
2273
Derek Lamberti5f400d62019-03-25 15:41:58 +00002274bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2275 const TensorInfo& output,
2276 Optional<std::string&> reasonIfUnsupported) const
2277{
2278 bool supported = true;
2279
Finn Williamsfd271062019-12-04 14:27:27 +00002280 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002281 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002282 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002283 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002284 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002285 DataType::QAsymmU8,
2286 DataType::QSymmS8,
2287 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002288 };
2289
2290 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2291 "Reference quantize: input type not supported.");
2292
2293 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002294 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002295 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002296 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002297 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002298 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002299 };
2300 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2301 "Reference quantize: output type not supported.");
2302
2303 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2304 "Reference quantize: input and output shapes have different num total elements.");
2305
2306 return supported;
2307}
2308
Finn Williams2605b232020-06-10 15:53:46 +01002309bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2310 const TensorInfo& output,
2311 Optional<std::string&> reasonIfUnsupported) const
2312{
2313 IgnoreUnused(input);
2314 // Define supported output types.
2315 std::array<DataType,1> supportedOutputTypes =
2316 {
2317 DataType::Signed32,
2318 };
2319
2320 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2321 "Reference rank: input type not supported.");
2322}
2323
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002324bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2325 const TensorInfo& output,
2326 const ReduceDescriptor& descriptor,
2327 Optional<std::string&> reasonIfUnsupported) const
2328{
2329 IgnoreUnused(descriptor);
2330 bool supported = true;
2331 std::array<DataType,7> supportedTypes =
2332 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002333 DataType::Float32,
2334 DataType::Float16,
2335 DataType::QAsymmS8,
2336 DataType::QAsymmU8,
2337 DataType::QSymmS16,
2338 DataType::Signed32
2339 };
2340
2341 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2342 "Reference Reduce: input type not supported");
2343
2344 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2345 "Reference Reduce: output type not supported");
2346
2347 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2348 "Reference Reduce: input and output types not matching");
2349
2350 return supported;
2351}
2352
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002353bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002354 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002355 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002356 Optional<std::string&> reasonIfUnsupported) const
2357{
Jan Eilers8eb25602020-03-09 12:13:48 +00002358 IgnoreUnused(output);
2359 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002360 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002361 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002362 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002363 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002364 DataType::Float32,
2365 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002366 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002367 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002368 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002369 DataType::QSymmS16,
2370 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002371 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002372
Nina Drozd2f2778f2019-05-27 10:37:05 +01002373 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2374 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002375}
2376
Teresa Charlin970f43b2019-07-01 13:51:07 +01002377bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2378 const TensorInfo& output,
2379 const ResizeDescriptor& descriptor,
2380 Optional<std::string&> reasonIfUnsupported) const
2381{
Jan Eilers8eb25602020-03-09 12:13:48 +00002382 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002383 bool supported = true;
Teresa Charlince655882023-11-21 15:44:13 +00002384 std::array<DataType,7> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002385 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002386 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002387 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002388 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002389 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002390 DataType::QAsymmU8,
Teresa Charlince655882023-11-21 15:44:13 +00002391 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002392 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002393 };
2394
2395 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2396 "Reference Resize: input type not supported");
2397
2398 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2399 "Reference Resize: output type not supported");
2400
2401 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2402 "Reference Resize: input and output types not matching");
2403
2404 return supported;
2405}
2406
Tracy Narinebb8d7592023-07-13 16:50:54 +01002407bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
2408 const TensorInfo& input1,
Tianle Cheng988354d2023-06-28 13:20:47 +01002409 const TensorInfo& output,
Tianle Cheng988354d2023-06-28 13:20:47 +01002410 Optional<std::string&> reasonIfUnsupported) const
2411{
Tianle Cheng988354d2023-06-28 13:20:47 +01002412 bool supported = true;
2413 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002414 std::array<DataType,8> supportedTypes =
Tianle Cheng988354d2023-06-28 13:20:47 +01002415 {
2416 DataType::BFloat16,
2417 DataType::Float32,
2418 DataType::Float16,
2419 DataType::QAsymmS8,
2420 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002421 DataType::QSymmS8,
2422 DataType::QSymmS16,
2423 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01002424 };
2425
Tracy Narinebb8d7592023-07-13 16:50:54 +01002426 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2427 "Reference ReverseV2: input0 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002428
2429 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2430 "Reference ReverseV2: output type not supported");
2431
Tracy Narinebb8d7592023-07-13 16:50:54 +01002432 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2433 "Reference ReverseV2: input0 and output types not matching");
2434
2435 std::array<DataType,6> input2SupportedTypes =
2436 {
2437 DataType::Signed32
2438 };
2439
2440 supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
2441 "Reference ReverseV2: input1 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002442
2443 return supported;
2444}
2445
Keith Davis3ae3f972021-05-21 16:33:48 +01002446bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2447 const TensorInfo& output,
2448 Optional<std::string&> reasonIfUnsupported) const
2449{
2450 IgnoreUnused(input);
2451 bool supported = true;
2452
2453 std::array<DataType, 1> supportedTypes =
2454 {
2455 DataType::Signed32
2456 };
2457
2458 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2459 "Reference Shape: output type not supported");
2460
2461 return supported;
2462}
2463
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002464bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2465 const TensorInfo& output,
2466 const SliceDescriptor& descriptor,
2467 Optional<std::string&> reasonIfUnsupported) const
2468{
Jan Eilers8eb25602020-03-09 12:13:48 +00002469 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002470 bool supported = true;
2471
Sadik Armagan303980c2020-04-17 12:45:14 +01002472 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002473 {
2474 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002475 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002476 DataType::QAsymmU8,
Ryan OShea980446b2023-06-08 16:23:28 +01002477 DataType::QSymmS16,
2478 DataType::Signed32
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002479 };
2480
2481 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2482 "Reference Slice: input type not supported");
2483
2484 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2485 "Reference Slice: output type not supported");
2486
2487 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2488 "Reference Slice: input and output types are mismatched");
2489
2490 return supported;
2491}
2492
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002493bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2494 const TensorInfo& output,
2495 const SoftmaxDescriptor& descriptor,
2496 Optional<std::string&> reasonIfUnsupported) const
2497{
Jan Eilers8eb25602020-03-09 12:13:48 +00002498 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002499 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002500 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002501 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002502 DataType::Float32,
2503 DataType::Float16,
2504 DataType::QSymmS8,
2505 DataType::QAsymmS8,
2506 DataType::QAsymmU8,
2507 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002508 };
2509
2510 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002511 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002512
2513 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002514 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002515
2516 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002517 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002518
2519 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002520}
2521
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002522bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2523 const TensorInfo& output,
2524 const SpaceToBatchNdDescriptor& descriptor,
2525 Optional<std::string&> reasonIfUnsupported) const
2526{
Jan Eilers8eb25602020-03-09 12:13:48 +00002527 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002528 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002529 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002530 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002531 DataType::Float32,
2532 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002533 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002534 DataType::QAsymmU8,
2535 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002536 };
2537
2538 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2539 "Reference SpaceToBatchNd: input type not supported");
2540
2541 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2542 "Reference SpaceToBatchNd: output type not supported");
2543
2544 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2545 "Reference SpaceToBatchNd: input and output types are mismatched");
2546
2547 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002548}
2549
Keith Davisa57eccb2019-06-14 17:33:22 +01002550bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002551 const TensorInfo& output,
2552 const SpaceToDepthDescriptor& descriptor,
2553 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002554{
2555
Jan Eilers8eb25602020-03-09 12:13:48 +00002556 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002557 bool supported = true;
2558
Sadik Armagan303980c2020-04-17 12:45:14 +01002559 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002560 {
2561 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002562 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002563 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002564 DataType::QAsymmU8,
2565 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002566 };
2567
2568 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2569 "Reference SpaceToDepth: input type not supported");
2570
2571 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2572 "Reference SpaceToDepth: output type not supported");
2573
2574 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2575 "Reference SpaceToDepth: input and output types are mismatched");
2576
2577 return supported;
2578}
2579
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002580bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002581 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2582 const ViewsDescriptor& descriptor,
2583 Optional<std::string&> reasonIfUnsupported) const
2584{
Jan Eilers8eb25602020-03-09 12:13:48 +00002585 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002586 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002587 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002588 {
2589 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002590 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002591 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002592 DataType::QAsymmU8,
2593 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002594 };
2595
2596 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2597 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002598 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002599 {
2600 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2601 "Reference splitter: input type not supported");
2602
2603 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2604 "Reference splitter: input and output types mismatched.");
2605 }
2606
2607 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002608}
2609
Matthew Jackson81e601c2019-07-11 12:07:09 +01002610bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2611 const TensorInfo& output,
2612 const StackDescriptor& descriptor,
2613 Optional<std::string&> reasonIfUnsupported) const
2614{
Jan Eilers8eb25602020-03-09 12:13:48 +00002615 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002616
2617 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002618 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002619 {
2620 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002621 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002623 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002624 DataType::QSymmS16,
2625 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002626 };
2627
2628 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2629 "Reference stack: output type not supported");
2630 for (const TensorInfo* input : inputs)
2631 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002632 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002633 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2634 "Reference stack: input type not supported");
2635
2636 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2637 "Reference stack: input and output types mismatched.");
2638 }
2639
2640 return supported;
2641}
2642
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002643bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2644 const TensorInfo& output,
2645 const StridedSliceDescriptor& descriptor,
2646 Optional<std::string&> reasonIfUnsupported) const
2647{
Jan Eilers8eb25602020-03-09 12:13:48 +00002648 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002649 bool supported = true;
2650
Sadik Armagan303980c2020-04-17 12:45:14 +01002651 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002652 {
2653 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002654 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002655 DataType::QAsymmU8,
2656 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002657 };
2658
2659 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2660 "Reference StridedSlice: input type not supported");
2661
2662 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2663 "Reference StridedSlice: output type not supported");
2664
2665 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2666 "Reference StridedSlice: input and output types are mismatched");
2667
2668 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002669}
2670
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002671bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2672 const TensorInfo& input1,
2673 const TensorInfo& output,
2674 Optional<std::string&> reasonIfUnsupported) const
2675{
Sadik Armagan2999a022019-04-09 14:20:12 +01002676 bool supported = true;
2677
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002678 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002679 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002680 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002681 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002682 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002683 DataType::QSymmS16,
2684 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002685 };
2686
2687 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2688 "Reference subtraction: input 0 is not a supported type.");
2689
2690 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2691 "Reference subtraction: input 1 is not a supported type.");
2692
2693 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2694 "Reference subtraction: output is not a supported type.");
2695
2696 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2697 "Reference subtraction: input 0 and Input 1 types are mismatched");
2698
2699 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2700 "Reference subtraction: input and output types are mismatched");
2701
2702 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2703 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2704
2705 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002706}
2707
Matteo Martincighab9e5252019-06-13 17:27:46 +01002708bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2709 const TensorInfo& alpha,
2710 const TensorInfo& output,
2711 Optional<std::string&> reasonIfUnsupported) const
2712{
2713 bool supported = true;
2714
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002715 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002716 {
2717 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002718 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002719 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002720 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002721 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002722 };
2723
2724 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2725 "PReLU: input is not a supported type.");
2726
2727 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2728 "PReLU: alpha is not a supported type.");
2729
2730 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2731 "PReLU: output is not a supported type.");
2732
2733 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2734 "PReLU: input, alpha and output types are mismatched");
2735
2736 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2737 "PReLU: shapes are not suitable for implicit broadcast");
2738
2739 return supported;
2740}
2741
Teresa Charlin79a06a52023-07-13 17:16:45 +01002742bool RefLayerSupport::IsTileSupported(const TensorInfo& input,
2743 const TensorInfo& output,
2744 const TileDescriptor& descriptor,
2745 Optional<std::string&> reasonIfUnsupported) const
2746{
2747 IgnoreUnused(descriptor);
2748
2749 bool supported = true;
2750
2751 std::array<DataType, 7> supportedTypes
2752 {
2753 DataType::Float32,
2754 DataType::Float16,
2755 DataType::QAsymmS8,
2756 DataType::QAsymmU8,
2757 DataType::QSymmS8,
2758 DataType::QSymmS16,
2759 DataType::Signed32
2760 };
2761
2762 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2763 "Tile: input type not supported.");
2764
2765 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2766 "Tile: output type not supported");
2767
2768 return supported;
2769}
2770
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002771bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2772 const TensorInfo& output,
2773 const TransposeConvolution2dDescriptor& descriptor,
2774 const TensorInfo& weights,
2775 const Optional<TensorInfo>& biases,
2776 Optional<std::string&> reasonIfUnsupported) const
2777{
Jan Eilers8eb25602020-03-09 12:13:48 +00002778 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002779 bool supported = true;
2780
Sadik Armagan303980c2020-04-17 12:45:14 +01002781 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002782 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002783 DataType::Float32,
2784 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002785 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002786 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002787 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002788 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002789 };
2790
2791 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2792 "Reference TransposeConvolution2d: input is not a supported type.");
2793
2794 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2795 "Reference TransposeConvolution2d: output is not a supported type.");
2796
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002797 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2798 "Reference TransposeConvolution2d: input and output types mismatched.");
2799
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002800
2801 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002802 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002803 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002804 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002805 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002806 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002807 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002808 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002809 };
2810
2811 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2812 "Reference TransposeConvolution2d: weights type not supported for "
2813 "quantized input.");
2814 }
2815 else
2816 {
2817 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2818 "Reference TransposeConvolution2d: weights is not a supported type.");
2819
2820 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2821 "Reference TransposeConvolution2d: input and weights types mismatched.");
2822 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002823
2824 if (biases.has_value())
2825 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002826 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002827 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002828 DataType::Float32,
2829 DataType::Float16,
2830 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002831 };
2832 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2833 "Reference TransposeConvolution2d: biases is not a supported type.");
2834 }
2835
2836 return supported;
2837}
2838
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002839bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2840 const TensorInfo& output,
2841 const TransposeDescriptor& descriptor,
2842 Optional<std::string&> reasonIfUnsupported) const
2843{
Jan Eilers8eb25602020-03-09 12:13:48 +00002844 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002845 bool supported = true;
2846
2847 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002848 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002849 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002850 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002851 DataType::Float32,
2852 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002853 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002854 DataType::QAsymmU8,
2855 DataType::QSymmS16
2856 };
2857
2858 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2859 "Reference transpose: input is not a supported type.");
2860
2861 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2862 "Reference transpose: output is not a supported type.");
2863
2864 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2865 "Reference transpose: input and output types are mismatched.");
2866
2867 return supported;
2868}
2869
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002870bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2871 const TensorInfo& input,
2872 const TensorInfo& outputStateIn,
2873 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +01002874 const TensorInfo& outputStateOut,
2875 const TensorInfo& cellStateOut,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002876 const TensorInfo& output,
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002877 const UnidirectionalSequenceLstmDescriptor& descriptor,
2878 const LstmInputParamsInfo& paramsInfo,
2879 Optional<std::string&> reasonIfUnsupported) const
2880{
2881 IgnoreUnused(descriptor);
2882 IgnoreUnused(paramsInfo);
2883 IgnoreUnused(outputStateIn);
2884 IgnoreUnused(cellStateIn);
Mike Kelly12994962022-04-21 11:57:09 +01002885 IgnoreUnused(outputStateOut);
2886 IgnoreUnused(cellStateOut);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002887 bool supported = true;
2888
Mike Kelly12994962022-04-21 11:57:09 +01002889 std::array<DataType, 2> supportedTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002890 {
Mike Kelly12994962022-04-21 11:57:09 +01002891 DataType::Float32,
2892 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002893 };
2894
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002895 std::array<DataType, 2> supportedWeightTypes =
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002896 {
Narumol Prangnawaratbd575b22021-08-31 16:53:54 +01002897 DataType::Float32,
2898 DataType::QAsymmS8
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002899 };
2900
Mike Kelly12994962022-04-21 11:57:09 +01002901 std::array<DataType, 3> supportedBiasTypes =
2902 {
2903 DataType::Float32,
2904 DataType::QAsymmS8,
2905 DataType::Signed32
2906 };
2907
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002908 // check inputs and outputs
2909 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2910 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002911 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2912 "Reference UnidirectionalSequenceLstm: output is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002913
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002914 // check layer parameters
2915 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2916 reasonIfUnsupported,
2917 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2918 "is not a supported type.");
2919 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2920 reasonIfUnsupported,
2921 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2922 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2923 reasonIfUnsupported,
2924 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2925 "is not a supported type.");
2926 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2927 reasonIfUnsupported,
2928 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2929 "is not a supported type.");
2930 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2931 reasonIfUnsupported,
2932 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2933 "is not a supported type.");
2934 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2935 reasonIfUnsupported,
2936 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2937 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002938
2939 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
2940 "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
2941 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
2942 "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
2943 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2944 "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002945 if (!descriptor.m_CifgEnabled)
2946 {
2947 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2948 reasonIfUnsupported,
2949 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2950 "is not a supported type.");
2951 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2952 reasonIfUnsupported,
2953 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2954 "is not a supported type.");
Mike Kelly12994962022-04-21 11:57:09 +01002955 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
2956 "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002957 if (descriptor.m_PeepholeEnabled)
2958 {
2959 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2960 reasonIfUnsupported,
2961 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2962 "is not a supported type.");
2963 }
2964 }
2965 if (descriptor.m_PeepholeEnabled)
2966 {
2967 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2968 reasonIfUnsupported,
2969 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2970 "is not a supported type.");
2971 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2972 reasonIfUnsupported,
2973 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2974 "is not a supported type.");
2975 }
2976 if (descriptor.m_ProjectionEnabled)
2977 {
2978 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2979 reasonIfUnsupported,
2980 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2981 "is not a supported type.");
2982 if (paramsInfo.m_ProjectionBias != nullptr)
2983 {
2984 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2985 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2986 "are mismatched");
2987 }
2988 }
2989 if (descriptor.m_LayerNormEnabled)
2990 {
2991 if (!descriptor.m_CifgEnabled)
2992 {
2993 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2994 reasonIfUnsupported,
2995 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2996 "is not a supported type.");
2997 }
2998 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2999 reasonIfUnsupported,
3000 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
3001 "is not a supported type.");
3002 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
3003 reasonIfUnsupported,
3004 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
3005 "is not a supported type.");
3006 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
3007 reasonIfUnsupported,
3008 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
3009 "is not a supported type.");
3010 }
3011
3012 return supported;
3013}
3014
arovir011c7c81b2018-10-08 11:34:28 +01003015} // namespace armnn