blob: defdf0d807bb8c2ab50cef68a9d94bf357a4d8ed [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Mike Kelly3ec30772023-03-08 13:47:17 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010011#include <armnn/utility/NumericCast.hpp>
Cathal Corbett34b429c2021-12-24 12:24:40 +000012#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
James Conroy4d1ff582019-06-10 17:06:39 +010045namespace
46{
47
// Builds the standard reference-backend error message for a tensor whose
// rank does not match what the layer expects.
// @param expected  the rank the layer requires
// @param actual    the rank the tensor actually has
// @param layerStr  human-readable layer name inserted into the message
// @param tensorName name of the offending tensor
// @return the formatted error string
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string msg("Reference ");
    msg.append(layerStr);
    msg.append(": Expected ");
    msg.append(std::to_string(expected));
    msg.append(" dimensions but got ");
    msg.append(std::to_string(actual));
    msg.append(" dimensions instead, for the '");
    msg.append(tensorName);
    msg.append("' tensor.");
    return msg;
}
58
59} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Cathal Corbett34b429c2021-12-24 12:24:40 +000061bool RefLayerSupport::IsLayerSupported(const LayerType& type,
62 const std::vector<TensorInfo>& infos,
63 const BaseDescriptor& descriptor,
64 const Optional<LstmInputParamsInfo>& lstmParamsInfo,
65 const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
66 Optional<std::string&> reasonIfUnsupported) const
67{
68 switch (type)
69 {
70 case LayerType::Activation:
71 return IsActivationSupported(infos[0],
72 infos[1],
73 *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
74 reasonIfUnsupported);
75 case LayerType::Addition:
76 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
77 case LayerType::ArgMinMax:
78 return IsArgMinMaxSupported(infos[0],
79 infos[1],
80 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
81 reasonIfUnsupported);
Samuel Yap6b478092022-07-06 15:36:03 +010082 case LayerType::BatchMatMul:
83 return IsBatchMatMulSupported(infos[0],
84 infos[1],
85 infos[2],
86 *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
87 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +000088 case LayerType::BatchNormalization:
89 return IsBatchNormalizationSupported(infos[0],
90 infos[1],
91 infos[2],
92 infos[3],
93 infos[4],
94 infos[5],
95 *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
96 (&descriptor)),
97 reasonIfUnsupported);
98 case LayerType::BatchToSpaceNd:
99 return IsBatchToSpaceNdSupported(infos[0],
100 infos[1],
101 *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
102 reasonIfUnsupported);
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100103 case LayerType::BroadcastTo:
104 return IsBroadcastToSupported(infos[0],
105 infos[1],
106 *(PolymorphicDowncast<const BroadcastToDescriptor*>(&descriptor)),
107 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000108 case LayerType::Comparison:
109 return IsComparisonSupported(infos[0],
110 infos[1],
111 infos[2],
112 *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
113 reasonIfUnsupported);
114 case LayerType::Concat:
115 {
116 std::vector<const TensorInfo*> inputInfos;
117 for (uint32_t i = 0; i < (infos.size() - 1); i++)
118 {
119 inputInfos.push_back(&infos[i]);
120 }
121 return IsConcatSupported(inputInfos,
122 infos[infos.size() - 1],
123 *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
124 reasonIfUnsupported);
125 }
126 case LayerType::Constant:
127 return IsConstantSupported(infos[0], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000128 case LayerType::ConvertFp16ToFp32:
129 return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000130 case LayerType::ConvertFp32ToFp16:
131 return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
132 case LayerType::Convolution2d:
133 {
134 if (infos.size() != 4)
135 {
136 throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
137 "TensorInfos should be of format: {input, output, weights, biases}.");
138 }
139
140 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
141 if (infos[3] == TensorInfo())
142 {
143 return IsConvolution2dSupported(infos[0],
144 infos[1],
145 desc,
146 infos[2],
147 EmptyOptional(),
148 reasonIfUnsupported);
149 }
150 else
151 {
152 return IsConvolution2dSupported(infos[0],
153 infos[1],
154 desc,
155 infos[2],
156 infos[3],
157 reasonIfUnsupported);
158 }
159 }
160 case LayerType::DepthToSpace:
161 return IsDepthToSpaceSupported(infos[0],
162 infos[1],
163 *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
164 reasonIfUnsupported);
165 case LayerType::DepthwiseConvolution2d:
166 {
167 if (infos.size() != 4)
168 {
169 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
170 "TensorInfos should be of format: {input, output, weights, biases}.");
171 }
172
173 auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
174 if (infos[3] == TensorInfo())
175 {
176 return IsDepthwiseConvolutionSupported(infos[0],
177 infos[1],
178 desc,
179 infos[2],
180 EmptyOptional(),
181 reasonIfUnsupported);
182 }
183 else
184 {
185 return IsDepthwiseConvolutionSupported(infos[0],
186 infos[1],
187 desc,
188 infos[2],
189 infos[3],
190 reasonIfUnsupported);
191 }
192 }
193 case LayerType::Dequantize:
194 return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
195 case LayerType::Division:
196 return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Mike Kelly3ec30772023-03-08 13:47:17 +0000197 case LayerType::ElementwiseBinary:
198 {
199 std::array<DataType, 7> supportedTypes =
200 {
201 DataType::Float32,
202 DataType::Float16,
203 DataType::QAsymmS8,
204 DataType::QAsymmU8,
205 DataType::QSymmS16,
206 DataType::Signed32
207 };
208
209 bool supported = true;
210 supported &= CheckSupportRule(TypeAnyOf(infos[0], supportedTypes), reasonIfUnsupported,
211 "Reference elementwise unary: input type not supported");
212
213 supported &= CheckSupportRule(TypeAnyOf(infos[1], supportedTypes), reasonIfUnsupported,
214 "Reference elementwise unary: input type not supported");
215
216 supported &= CheckSupportRule(TypeAnyOf(infos[2], supportedTypes), reasonIfUnsupported,
217 "Reference elementwise unary: output type not supported");
218
219 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[1]), reasonIfUnsupported,
220 "Reference elementwise unary: input types not matching");
221
222 supported &= CheckSupportRule(TypesAreEqual(infos[0], infos[2]), reasonIfUnsupported,
223 "Reference elementwise unary: input and output types not matching");
224
225 return supported;
226 }
Cathal Corbett34b429c2021-12-24 12:24:40 +0000227 case LayerType::ElementwiseUnary:
228 return IsElementwiseUnarySupported(infos[0],
229 infos[1],
230 *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
231 reasonIfUnsupported);
232 case LayerType::Fill:
233 return IsFillSupported(infos[0],
234 infos[1],
235 *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
236 reasonIfUnsupported);
237 case LayerType::Floor:
238 return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
239 case LayerType::FullyConnected:
240 return IsFullyConnectedSupported(infos[0],
241 infos[1],
242 infos[2],
243 infos[3],
244 *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
245 reasonIfUnsupported);
246 case LayerType::Gather:
247 return IsGatherSupported(infos[0],
248 infos[1],
249 infos[2],
250 *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
251 reasonIfUnsupported);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100252 case LayerType::GatherNd:
253 return IsGatherNdSupported(infos[0],
254 infos[1],
255 infos[2],
256 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000257 case LayerType::Input:
258 return IsInputSupported(infos[0], reasonIfUnsupported);
259 case LayerType::InstanceNormalization:
260 return IsInstanceNormalizationSupported(infos[0],
261 infos[1],
262 *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
263 (&descriptor)),
264 reasonIfUnsupported);
265 case LayerType::L2Normalization:
266 return IsL2NormalizationSupported(infos[0],
267 infos[1],
268 *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
269 reasonIfUnsupported);
270 case LayerType::LogicalBinary:
271 return IsLogicalBinarySupported(infos[0],
272 infos[1],
273 infos[2],
274 *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
275 reasonIfUnsupported);
276 case LayerType::LogSoftmax:
277 return IsLogSoftmaxSupported(infos[0],
278 infos[1],
279 *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
280 reasonIfUnsupported);
281 case LayerType::Lstm:
282 return IsLstmSupported(infos[0],
283 infos[1],
284 infos[2],
285 infos[3],
286 infos[4],
287 infos[5],
288 infos[6],
289 *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
290 lstmParamsInfo.value(),
291 reasonIfUnsupported);
292 case LayerType::QLstm:
293 return IsQLstmSupported(infos[0],
294 infos[1],
295 infos[2],
296 infos[3],
297 infos[4],
298 infos[5],
299 *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
300 lstmParamsInfo.value(),
301 reasonIfUnsupported);
302 case LayerType::Maximum:
303 return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
304 case LayerType::Mean:
305 return IsMeanSupported(infos[0],
306 infos[1],
307 *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
308 reasonIfUnsupported);
309 case LayerType::Minimum:
310 return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
311 case LayerType::Multiplication:
312 return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
313 case LayerType::Normalization:
314 return IsNormalizationSupported(infos[0],
315 infos[1],
316 *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
317 reasonIfUnsupported);
318 case LayerType::Output:
319 return IsOutputSupported(infos[0], reasonIfUnsupported);
320 case LayerType::Pad:
321 return IsPadSupported(infos[0],
322 infos[1],
323 *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
324 reasonIfUnsupported);
325 case LayerType::Permute:
326 return IsPermuteSupported(infos[0],
327 infos[1],
328 *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
329 reasonIfUnsupported);
330 case LayerType::Pooling2d:
331 return IsPooling2dSupported(infos[0],
332 infos[1],
333 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
334 reasonIfUnsupported);
335 case LayerType::Prelu:
336 return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
337 case LayerType::Quantize:
338 return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
339 case LayerType::Reshape:
340 return IsReshapeSupported(infos[0],
341 infos[1],
342 *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
343 reasonIfUnsupported);
344 case LayerType::Resize:
345 return IsResizeSupported(infos[0],
346 infos[1],
347 *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
348 reasonIfUnsupported);
Tianle Cheng988354d2023-06-28 13:20:47 +0100349 case LayerType::ReverseV2:
350 return IsReverseV2Supported(infos[0],
351 infos[1],
Tracy Narinebb8d7592023-07-13 16:50:54 +0100352 infos[2],
Tianle Cheng988354d2023-06-28 13:20:47 +0100353 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000354 case LayerType::Reduce:
355 return IsReduceSupported(infos[0],
356 infos[1],
357 *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
358 reasonIfUnsupported);
359 case LayerType::Slice:
360 return IsSliceSupported(infos[0],
361 infos[1],
362 *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
363 reasonIfUnsupported);
364 case LayerType::Softmax:
365 return IsSoftmaxSupported(infos[0],
366 infos[1],
367 *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
368 reasonIfUnsupported);
369 case LayerType::SpaceToBatchNd:
370 return IsSpaceToBatchNdSupported(infos[0],
371 infos[1],
372 *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
373 reasonIfUnsupported);
374 case LayerType::SpaceToDepth:
375 return IsSpaceToDepthSupported(infos[0],
376 infos[1],
377 *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
378 reasonIfUnsupported);
379 case LayerType::Splitter:
380 {
381 std::vector<TensorInfo> outputInfos;
382 for (uint32_t i = 1; i < infos.size(); i++)
383 {
384 outputInfos.push_back(infos[i]);
385 }
386 return IsSplitterSupported(infos[0],
387 {outputInfos.begin(), outputInfos.end()},
388 *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
389 reasonIfUnsupported);
390 }
391 case LayerType::Stack:
392 {
393 std::vector<const TensorInfo*> inputInfos;
394 for (uint32_t i = 0; i < infos.size() - 1; i++)
395 {
396 inputInfos.push_back(&infos[i]);
397 }
398 return IsStackSupported(inputInfos,
399 infos[infos.size() - 1],
400 *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
401 reasonIfUnsupported);
402 }
403 case LayerType::StridedSlice:
404 return IsStridedSliceSupported(infos[0],
405 infos[1],
406 *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
407 reasonIfUnsupported);
408 case LayerType::Subtraction:
409 return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
Teresa Charlin79a06a52023-07-13 17:16:45 +0100410 case LayerType::Tile:
411 return IsTileSupported(infos[0],
412 infos[1],
413 *(PolymorphicDowncast<const TileDescriptor*>(&descriptor)),
414 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000415 case LayerType::Transpose:
416 return IsTransposeSupported(infos[0],
417 infos[1],
418 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
419 reasonIfUnsupported);
420 case LayerType::TransposeConvolution2d:
421 {
422 if (infos.size() != 4)
423 {
424 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
425 "TensorInfos should be of format: {input, output, weights, biases}.");
426 }
427
428 auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
429 if (infos[3] == TensorInfo())
430 {
431 return IsTransposeConvolution2dSupported(infos[0],
432 infos[1],
433 desc,
434 infos[2],
435 EmptyOptional(),
436 reasonIfUnsupported);
437 }
438 else
439 {
440 return IsTransposeConvolution2dSupported(infos[0],
441 infos[1],
442 desc,
443 infos[2],
444 infos[3],
445 reasonIfUnsupported);
446 }
447 }
448 case LayerType::Cast:
449 return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
450 case LayerType::ChannelShuffle:
451 return IsChannelShuffleSupported(infos[0],
452 infos[1],
453 *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
454 reasonIfUnsupported);
455 case LayerType::Convolution3d:
456 {
457 if (infos.size() != 4)
458 {
459 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
460 "TensorInfos should be of format: {input, output, weights, biases}.");
461 }
462
463 auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
464 if (infos[3] == TensorInfo())
465 {
466 return IsConvolution3dSupported(infos[0],
467 infos[1],
468 desc,
469 infos[2],
470 EmptyOptional(),
471 reasonIfUnsupported);
472 }
473 else
474 {
475 return IsConvolution3dSupported(infos[0],
476 infos[1],
477 desc,
478 infos[2],
479 infos[3],
480 reasonIfUnsupported);
481 }
482 }
483 case LayerType::Debug:
484 return IsDebugSupported(infos[0], infos[1], reasonIfUnsupported);
485 case LayerType::DetectionPostProcess:
486 return IsDetectionPostProcessSupported(infos[0],
487 infos[1],
488 infos[2],
489 infos[3],
490 infos[4],
491 infos[5],
492 infos[6],
493 *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>
494 (&descriptor)),
495 reasonIfUnsupported);
496 case LayerType::FakeQuantization:
497 return IsFakeQuantizationSupported(infos[0],
498 *(PolymorphicDowncast<const FakeQuantizationDescriptor*>(&descriptor)),
499 reasonIfUnsupported);
500 case LayerType::MemCopy:
501 return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
502 case LayerType::Rank:
503 return IsRankSupported(infos[0], infos[1], reasonIfUnsupported);
504 case LayerType::Shape:
505 return IsShapeSupported(infos[0], infos[1], reasonIfUnsupported);
506 case LayerType::UnidirectionalSequenceLstm:
507 {
508 if (infos.size() != 6)
509 {
510 throw InvalidArgumentException("Invalid number of UnidirectionalSequenceLstm TensorInfos. TensorInfos "
511 "should be of format: {input, outputStateIn, cellStateIn, "
512 "hiddenStateOutputVal, cellStateOutputVal, output}");
513 }
514 auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
Mike Kelly12994962022-04-21 11:57:09 +0100515 return IsUnidirectionalSequenceLstmSupported(infos[0],
516 infos[1],
517 infos[2],
518 infos[3],
519 infos[4],
520 infos[5],
521 desc,
522 lstmParamsInfo.value(),
523 reasonIfUnsupported);
Cathal Corbett34b429c2021-12-24 12:24:40 +0000524 }
525 case LayerType::Pooling3d:
526 return IsPooling3dSupported(infos[0],
527 infos[1],
528 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
529 reasonIfUnsupported);
530 case LayerType::Map:
531 return true;
532 case LayerType::Unmap:
533 return true;
534 case LayerType::MemImport:
535 return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
536 case LayerType::Merge:
537 return LayerSupportBase::IsMergeSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
538 case LayerType::QuantizedLstm:
539 return LayerSupportBase::IsQuantizedLstmSupported(infos[0],
540 infos[1],
541 infos[2],
542 infos[3],
543 infos[4],
544 quantizedLstmInputParamsInfo.value(),
545 reasonIfUnsupported);
546 default:
Teresa Charlin9145e382023-08-17 18:44:58 +0100547 // layers not supported in reference by default:
548 // precompiled, standin, switch, fused
Cathal Corbett34b429c2021-12-24 12:24:40 +0000549 return false;
550 }
551}
552
arovir011c7c81b2018-10-08 11:34:28 +0100553bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
554 const TensorInfo& output,
555 const ActivationDescriptor& descriptor,
556 Optional<std::string&> reasonIfUnsupported) const
557{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000558 bool supported = true;
559
560 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +0000561 std::array<DataType,6> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000562 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100563 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000564 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000565 DataType::QAsymmU8,
566 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000567 };
568
569 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
570 "Reference activation: input type not supported.");
571
572 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
573 "Reference activation: output type not supported.");
574
575 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
576 "Reference activation: input and output types mismatched.");
577
578 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
579 "Reference activation: input and output shapes are of different rank.");
580
581
582 struct ActivationFunctionSupported : public Rule
583 {
584 ActivationFunctionSupported(const ActivationDescriptor& desc)
585 {
586 switch(desc.m_Function)
587 {
588 case ActivationFunction::Abs:
589 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000590 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000591 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000592 case ActivationFunction::LeakyReLu:
593 case ActivationFunction::Linear:
594 case ActivationFunction::ReLu:
595 case ActivationFunction::Sigmoid:
596 case ActivationFunction::SoftReLu:
597 case ActivationFunction::Sqrt:
598 case ActivationFunction::Square:
599 case ActivationFunction::TanH:
600 {
601 m_Res = true;
602 break;
603 }
604 default:
605 {
606 m_Res = false;
607 break;
608 }
609 }
610 }
611 };
612
613 // Function is supported
614 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
615 "Reference activation: function not supported.");
616
617 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100618}
619
620bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
621 const TensorInfo& input1,
622 const TensorInfo& output,
623 Optional<std::string&> reasonIfUnsupported) const
624{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000625 bool supported = true;
626
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100627 std::array<DataType,7> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000628 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100629 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000630 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000631 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100632 DataType::QSymmS16,
633 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000634 };
635
636 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
637 "Reference addition: input 0 is not a supported type.");
638
639 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
640 "Reference addition: input 1 is not a supported type.");
641
642 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
643 "Reference addition: output is not a supported type.");
644
645 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
646 "Reference addition: input 0 and Input 1 types are mismatched");
647
648 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
649 "Reference addition: input and output types are mismatched");
650
651 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
652 "Reference addition: shapes are not suitable for implicit broadcast.");
653
654 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100655}
656
Nikhil Raj68c2c902019-09-19 11:21:11 +0100657bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
658 const armnn::ArgMinMaxDescriptor &descriptor,
659 armnn::Optional<std::string &> reasonIfUnsupported) const
660{
Jan Eilers8eb25602020-03-09 12:13:48 +0000661 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100662
Mike Kelly1f140f72021-04-06 12:25:55 +0100663 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100664 {
Teresa Charline300b362020-05-25 10:01:03 +0100665 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100666 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100667 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000668 DataType::QAsymmU8,
669 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100670 DataType::Signed32,
671 DataType::Signed64
672 };
673
674 std::array<DataType,2> supportedOutputTypes = {
675 DataType::Signed32,
676 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100677 };
678
679 bool supported = true;
680
Mike Kelly1f140f72021-04-06 12:25:55 +0100681 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100682 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100683 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100684 "Reference ArgMinMax: output type not supported");
685
686 return supported;
687}
688
Samuel Yap6b478092022-07-06 15:36:03 +0100689bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
690 const TensorInfo& inputY,
691 const TensorInfo& output,
692 const BatchMatMulDescriptor& descriptor,
693 Optional<std::string &> reasonIfUnsupported) const
694{
695 IgnoreUnused(descriptor);
696
697 std::array<DataType, 6> supportedTypes =
698 {
Samuel Yap6b478092022-07-06 15:36:03 +0100699 DataType::Float16,
700 DataType::Float32,
701 DataType::QAsymmS8,
702 DataType::QAsymmU8,
703 DataType::QSymmS16
704 };
705
706 bool supported = true;
707
708 supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
709 "Reference batch matrix multiplication: input X is not a supported type");
710
711 supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
712 "Reference batch matrix multiplication: input Y is not a supported type");
713
714 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
715 "Reference batch matrix multiplication: output is not a supported type");
716
717 supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
718 "Reference batch matrix multiplication: input X and input Y types are mismatched");
719
720 supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
721 "Reference batch matrix multiplication: inputs and output types are mismatched");
722
723 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
724 reasonIfUnsupported,
725 "Reference batch matrix multiplication: input X is not of rank 2 or greater");
726
727 supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
728 reasonIfUnsupported,
729 "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
730
731 return supported;
732}
733
arovir011c7c81b2018-10-08 11:34:28 +0100734bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
735 const TensorInfo& output,
736 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100737 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100738 const TensorInfo& beta,
739 const TensorInfo& gamma,
740 const BatchNormalizationDescriptor& descriptor,
741 Optional<std::string&> reasonIfUnsupported) const
742{
Jan Eilers8eb25602020-03-09 12:13:48 +0000743 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100744
Sadik Armagan303980c2020-04-17 12:45:14 +0100745 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100746 {
747 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100748 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100749 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000750 DataType::QAsymmU8,
751 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100752 };
753
754 bool supported = true;
755
756 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
757 "Reference batch normalization: input is not a supported type.");
758
759 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
760 "Reference batch normalization: output is not a supported type.");
761
762 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
763 "Reference batch normalization: input and output types are mismatched");
764
765 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
766 "Reference batch normalization: mean is not a supported type.");
767
768 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
769 "Reference batch normalization: variance is not a supported type.");
770
771 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
772 "Reference batch normalization: beta is not a supported type.");
773
774 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
775 "Reference batch normalization: gamma is not a supported type.");
776
777 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100778}
779
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000780bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
781 const TensorInfo& output,
782 const BatchToSpaceNdDescriptor& descriptor,
783 Optional<std::string&> reasonIfUnsupported) const
784{
Jan Eilers8eb25602020-03-09 12:13:48 +0000785 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100786
787 bool supported = true;
788
789 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
790 std::string inputTensorStr = "input";
791 std::string outputTensorStr = "output";
792
793 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100794 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100795 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000796 DataType::Float32,
797 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100798 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000799 DataType::QAsymmU8,
800 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100801 };
802
803 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
804 "Reference BatchToSpaceNd: input type not supported.");
805
806 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
807 "Reference BatchToSpaceNd: output type not supported.");
808
809 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
810 "Reference BatchToSpaceNd: input and output types mismatched.");
811
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100812 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000813}
814
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100815bool RefLayerSupport::IsBroadcastToSupported(const TensorInfo& input,
816 const TensorInfo& output,
817 const BroadcastToDescriptor& descriptor,
818 Optional<std::string&> reasonIfUnsupported) const
819{
820 IgnoreUnused(descriptor);
821
822 bool supported = true;
823
824 std::array<DataType, 8> supportedTypes
825 {
826 DataType::Float32,
827 DataType::Float16,
828 DataType::QAsymmS8,
829 DataType::QAsymmU8,
830 DataType::QSymmS8,
831 DataType::QSymmS16,
832 DataType::Signed32,
833 DataType::Signed64
834 };
835
836 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
837 "BroadcastTo: input type not supported.");
838
839 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
840 "BroadcastTo: output type not supported");
841
842 return supported;
843}
844
mathad01b392e982021-04-07 12:07:30 +0100845bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
846 const TensorInfo& output,
847 Optional<std::string&> reasonIfUnsupported) const
848{
849 std::array<DataType, 9> supportedInputTypes =
Idriss Chaouch98e383e2023-08-28 14:28:31 +0100850 {
851 DataType::Float32,
852 DataType::Float16,
853 DataType::QSymmS8,
854 DataType::QAsymmS8,
855 DataType::QAsymmU8,
856 DataType::QSymmS16,
857 DataType::Signed32
858 };
mathad01b392e982021-04-07 12:07:30 +0100859
860 bool supported = true;
861 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
862 "Reference cast: input is not a supported type");
863
864
865 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
866 "Reference cast: output is not a supported type");
867
868 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
869 "Reference cast: input and output shapes have different number of total elements");
870
871 return supported;
872}
873
Simon Obute51f67772021-09-03 15:50:13 +0100874bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
875 const TensorInfo& output,
876 const ChannelShuffleDescriptor& descriptor,
877 Optional<std::string&> reasonIfUnsupported) const
878{
879 IgnoreUnused(descriptor);
880 bool supported = true;
881
882 // Define supported output and inputs types.
883 std::array<DataType, 7> supportedTypes =
884 {
Simon Obute51f67772021-09-03 15:50:13 +0100885 DataType::Float32,
886 DataType::Float16,
887 DataType::QAsymmS8,
888 DataType::QAsymmU8,
889 DataType::QSymmS8,
890 DataType::QSymmS16
891 };
892
893 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
894 "Reference ChannelShuffle: input is not a supported type.");
895
896 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
897 "Reference ChannelShuffle: output is not a supported type.");
898
899 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
900 "Reference ChannelShuffle: input and output types are mismatched.");
901
902 return supported;
903}
904
905
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100906bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
907 const TensorInfo& input1,
908 const TensorInfo& output,
909 const ComparisonDescriptor& descriptor,
910 Optional<std::string&> reasonIfUnsupported) const
911{
Jan Eilers8eb25602020-03-09 12:13:48 +0000912 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100913 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100914 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000915 DataType::Boolean,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100916 DataType::Float32,
917 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100918 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000919 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000920 DataType::QSymmS16,
921 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100922 };
923
924 bool supported = true;
925 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
926 "Reference comparison: input 0 is not a supported type");
927
928 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
929 "Reference comparison: input 0 and Input 1 types are mismatched");
930
931 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
932 "Reference comparison: output is not of type Boolean");
933
934 return supported;
935}
936
Jim Flynn906f9462019-05-10 13:55:21 +0100937bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
938 const TensorInfo& output,
Cathal Corbett34b429c2021-12-24 12:24:40 +0000939 const OriginsDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100940 Optional<std::string&> reasonIfUnsupported) const
941{
Jan Eilers8eb25602020-03-09 12:13:48 +0000942 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100943
944 bool supported = true;
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000945 std::array<DataType,7> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100946 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000947 DataType::Float32,
948 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000949 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100950 DataType::QAsymmU8,
Teresa Charlin6abc7ee2022-02-22 17:32:27 +0000951 DataType::QSymmS16,
952 DataType::Signed32
Jim Flynne242f2d2019-05-22 14:24:13 +0100953 };
954
955 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
956 "Reference concatenation: output type not supported");
957 for (const TensorInfo* input : inputs)
958 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100959 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100960 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
961 "Reference concatenation: input type not supported");
962
963 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
964 "Reference concatenation: input and output types mismatched.");
965 }
966
967 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100968}
969
arovir011c7c81b2018-10-08 11:34:28 +0100970bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
971 Optional<std::string&> reasonIfUnsupported) const
972{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100973 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100974 {
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100975 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100976 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000977 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100978 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000979 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100980 DataType::QSymmS16,
981 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100982 };
983
984 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
985 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100986}
987
988bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
989 const TensorInfo& output,
990 Optional<std::string&> reasonIfUnsupported) const
991{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100992 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
993 input.GetDataType(),
994 &TrueFunc<>,
995 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000996 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000997 &FalseFuncI32<>,
998 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100999 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1000 output.GetDataType(),
1001 &FalseOutputFuncF16<>,
1002 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001003 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001004 &FalseFuncI32<>,
1005 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001006}
1007
1008bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
1009 const TensorInfo& output,
1010 Optional<std::string&> reasonIfUnsupported) const
1011{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001012 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1013 input.GetDataType(),
1014 &FalseInputFuncF16<>,
1015 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +00001016 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001017 &FalseFuncI32<>,
1018 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001019 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1020 output.GetDataType(),
1021 &TrueFunc<>,
1022 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +00001023 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +00001024 &FalseFuncI32<>,
1025 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +01001026}
1027
1028bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
1029 const TensorInfo& output,
1030 const Convolution2dDescriptor& descriptor,
1031 const TensorInfo& weights,
1032 const Optional<TensorInfo>& biases,
1033 Optional<std::string&> reasonIfUnsupported) const
1034{
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001035 bool supported = true;
1036
1037 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001038 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001039 {
1040 DataType::Float32,
1041 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001042 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001043 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001044 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001045 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001046 };
1047
1048 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001049 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001050
1051 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001052 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001053
Ryan OShea31441592022-11-07 16:20:48 +00001054 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1055 "Reference Convolution2d: input and output types mismatched.");
1056
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001057
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001058 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001059 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001060 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001061 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001062 {
Sadik Armagan303980c2020-04-17 12:45:14 +01001063 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001064 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01001065 DataType::QSymmS8
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001066 };
1067
1068 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001069 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001070 }
1071 else
1072 {
1073 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001074 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001075
1076 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001077 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001078 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001079
1080 if (biases.has_value())
1081 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001082 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001083 {
1084 DataType::Float32,
1085 DataType::Float16,
1086 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001087 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +00001088
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001089 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001090 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001091 }
Jan Eilers8eb25602020-03-09 12:13:48 +00001092 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +01001093
1094 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001095}
1096
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001097bool RefLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
1098 const TensorInfo& output,
1099 const Convolution3dDescriptor& descriptor,
1100 const TensorInfo& weights,
1101 const Optional<TensorInfo>& biases,
1102 Optional<std::string&> reasonIfUnsupported) const
1103{
1104 bool supported = true;
1105
1106 // Define supported types.
1107 std::array<DataType,7> supportedTypes =
1108 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001109 DataType::Float32,
1110 DataType::Float16,
1111 DataType::QAsymmS8,
1112 DataType::QAsymmU8,
1113 DataType::QSymmS8,
1114 DataType::QSymmS16
1115 };
1116
1117 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1118 "Reference Convolution3d: input is not a supported type.");
1119
1120 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1121 "Reference Convolution3d: output is not a supported type.");
1122
1123 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1124 "Reference Convolution3d: input and output types mismatched.");
1125
1126 const DataType inputType = input.GetDataType();
1127 if (IsQuantized8BitType(inputType))
1128 {
1129 std::array<DataType, 3> supportedWeightTypes =
1130 {
1131 DataType::QAsymmS8,
1132 DataType::QAsymmU8,
1133 DataType::QSymmS8
1134 };
1135
1136 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1137 "Reference Convolution3d: weights type not supported for quantized input.");
1138 }
1139 else
1140 {
1141 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1142 "Reference Convolution3d: weights is not a supported type.");
1143
1144 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1145 "Reference Convolution3d: input and weights types mismatched.");
1146 }
1147
1148 if (biases.has_value())
1149 {
1150 std::array<DataType,4> biasesSupportedTypes =
1151 {
Matthew Sloyanb63a3112021-09-08 13:05:51 +01001152 DataType::Float32,
1153 DataType::Float16,
1154 DataType::Signed32
1155 };
1156
1157 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1158 "Reference Convolution3d: biases is not a supported type.");
1159 }
1160 IgnoreUnused(descriptor);
1161
1162 return supported;
1163}
1164
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001165bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
1166 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001167 Optional<std::string&> reasonIfUnsupported) const
1168{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001169 bool supported = true;
1170
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001171 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001172 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +00001173 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +00001174 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001175 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001176 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001177 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +00001178 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +00001179 DataType::QSymmS16,
1180 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001181 };
1182
1183 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001184 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001185
1186 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001187 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001188
1189 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001190 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +01001191
1192 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +00001193}
1194
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001195bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
1196 const TensorInfo& output,
1197 const DepthToSpaceDescriptor& descriptor,
1198 Optional<std::string&> reasonIfUnsupported) const
1199{
Jan Eilers8eb25602020-03-09 12:13:48 +00001200 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001201 bool supported = true;
1202
Sadik Armagan303980c2020-04-17 12:45:14 +01001203 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001204 {
1205 DataType::Float32,
1206 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001207 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001208 DataType::QAsymmU8,
1209 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +01001210 };
1211
1212 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1213 "Reference DepthToSpace: input type not supported");
1214
1215 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1216 "Reference DepthToSpace: output type not supported");
1217
1218 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1219 "Reference DepthToSpace: input and output types are mismatched");
1220
1221 return supported;
1222}
1223
arovir011c7c81b2018-10-08 11:34:28 +01001224bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
1225 const TensorInfo& output,
1226 const DepthwiseConvolution2dDescriptor& descriptor,
1227 const TensorInfo& weights,
1228 const Optional<TensorInfo>& biases,
1229 Optional<std::string&> reasonIfUnsupported) const
1230{
Sadik Armagan303980c2020-04-17 12:45:14 +01001231 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001232 bool supported = true;
1233
1234 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001235 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001236 {
1237 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001238 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001239 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001240 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001241 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001242 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001243 };
1244
1245 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1246 "Reference DepthwiseConvolution2d: input is not a supported type.");
1247
1248 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1249 "Reference DepthwiseConvolution2d: output is not a supported type.");
1250
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001251 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1252 "Reference DepthwiseConvolution2d: input and output types mismatched.");
1253
Teresa Charlind8df0262019-11-11 12:28:15 +00001254 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +00001255 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +00001256 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01001257 std::array<DataType, 3> supportedWeightTypes =
Sadik Armagan303980c2020-04-17 12:45:14 +01001258 {
1259 DataType::QAsymmS8,
1260 DataType::QAsymmU8,
1261 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001262 };
Teresa Charlind8df0262019-11-11 12:28:15 +00001263
1264 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +01001265 "Reference DepthwiseConvolution2d: weights type not supported for "
1266 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +00001267 }
1268 else
1269 {
1270 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1271 "Reference DepthwiseConvolution2d: weights is not a supported type.");
1272
1273 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1274 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
1275 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001276
1277 if (biases.has_value())
1278 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001279 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001280 {
1281 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001282 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001283 DataType::Signed32
1284 };
1285 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1286 "Reference DepthwiseConvolution2d: biases is not a supported type.");
1287 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001288
1289 return supported;
1290
arovir011c7c81b2018-10-08 11:34:28 +01001291}
1292
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001293bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1294 const TensorInfo& output,
1295 Optional<std::string&> reasonIfUnsupported) const
1296{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001297 bool supported = true;
1298
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001299 std::array<DataType,5> supportedInputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001300 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001301 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001302 DataType::QSymmS8,
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001303 DataType::QSymmS16,
1304 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001305 };
1306
1307 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001308 "Reference for Dequantize layer: input type not supported.");
1309
Derek Lambertid466a542020-01-22 15:37:29 +00001310 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +01001311 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +00001312
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001313 std::array<DataType,3> supportedOutputTypes = {
Jan Eilersf7107932019-11-01 11:09:36 +00001314 DataType::Float32,
1315 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001316 };
1317
1318 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001319 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001320
1321 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +00001322 "Reference for Dequantize layer: input/output shapes have different num total "
1323 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +01001324
1325 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +00001326}
1327
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001328bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
1329 const TensorInfo& scores,
1330 const TensorInfo& anchors,
1331 const TensorInfo& detectionBoxes,
1332 const TensorInfo& detectionClasses,
1333 const TensorInfo& detectionScores,
1334 const TensorInfo& numDetections,
1335 const DetectionPostProcessDescriptor& descriptor,
1336 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001337{
Jan Eilers8eb25602020-03-09 12:13:48 +00001338 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +00001339
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001340 bool supported = true;
1341
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001342 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001343 {
1344 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +00001345 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001346 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001347 DataType::QAsymmU8,
1348 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001349 };
1350
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001351 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001352 "Reference DetectionPostProcess: input 0 is not a supported type.");
1353
Derek Lamberti6a5e5e82019-12-05 14:41:20 +00001354 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +01001355 "Reference DetectionPostProcess: input 1 is not a supported type.");
1356
1357 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001358}
1359
// Reference backend support check for dilated depthwise convolution.
// The reference DepthwiseConvolution2d workload already honours the
// dilation fields in the descriptor, so this simply forwards every
// argument to IsDepthwiseConvolutionSupported unchanged.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
1369
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +01001370bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +01001371 const TensorInfo& input1,
1372 const TensorInfo& output,
1373 Optional<std::string&> reasonIfUnsupported) const
1374{
Sadik Armagan2999a022019-04-09 14:20:12 +01001375 bool supported = true;
1376
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001377 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001378 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001379 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001380 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001381 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001382 DataType::QSymmS16,
1383 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001384 };
1385
1386 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1387 "Reference division: input 0 is not a supported type.");
1388
1389 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1390 "Reference division: input 1 is not a supported type.");
1391
1392 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1393 "Reference division: output is not a supported type.");
1394
1395 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1396 "Reference division: input 0 and Input 1 types are mismatched");
1397
1398 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1399 "Reference division: input and output types are mismatched");
1400
1401 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1402 "Reference division: shapes are not suitable for implicit broadcast.");
1403
1404 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001405}
1406
josh minor4a3c6102020-01-06 16:40:46 -06001407bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1408 const TensorInfo& output,
1409 const ElementwiseUnaryDescriptor& descriptor,
1410 Optional<std::string&> reasonIfUnsupported) const
1411{
Jan Eilers8eb25602020-03-09 12:13:48 +00001412 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -06001413
Sadik Armagan303980c2020-04-17 12:45:14 +01001414 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -06001415 {
1416 DataType::Float32,
1417 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001418 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -06001419 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +00001420 DataType::QSymmS16,
1421 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -06001422 };
1423
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001424 std::array<DataType, 1> logicalSupportedTypes =
1425 {
1426 DataType::Boolean
1427 };
1428
josh minor4a3c6102020-01-06 16:40:46 -06001429 bool supported = true;
1430
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001431 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
1432 {
1433 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
1434 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -06001435
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001436 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
1437 "Reference elementwise unary: output type not supported");
1438 }
1439 else
1440 {
1441 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1442 "Reference elementwise unary: input type not supported");
1443
1444 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1445 "Reference elementwise unary: output type not supported");
1446 }
josh minor4a3c6102020-01-06 16:40:46 -06001447
1448 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1449 "Reference elementwise unary: input and output types not matching");
1450
1451 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1452 "Reference elementwise unary: input and output shapes"
1453 "have different number of total elements");
1454
1455 return supported;
1456}
1457
arovir011c7c81b2018-10-08 11:34:28 +01001458bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
1459 const FakeQuantizationDescriptor& descriptor,
1460 Optional<std::string&> reasonIfUnsupported) const
1461{
Jan Eilers8eb25602020-03-09 12:13:48 +00001462 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001463 bool supported = true;
1464
1465 std::array<DataType,1> supportedTypes =
1466 {
1467 DataType::Float32
1468 };
1469
1470 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1471 "Reference fake quantization: input type not supported.");
1472
1473 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001474}
1475
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001476bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
1477 const TensorInfo& output,
1478 const FillDescriptor& descriptor,
1479 Optional<std::string&> reasonIfUnsupported) const
1480{
1481 IgnoreUnused(descriptor);
1482 IgnoreUnused(output);
1483
1484 bool supported = true;
1485
Sadik Armagana792a052020-06-23 16:22:23 +01001486 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001487 {
1488 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +01001489 DataType::Float16,
1490 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001491 };
1492
Teresa Charlin4b10fef2020-07-29 09:36:41 +01001493 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001494 "Reference Fill: input type not supported.");
1495
Teresa Charlin44088502020-07-27 11:27:19 +01001496 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1497 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +01001498 return supported;
1499}
1500
arovir011c7c81b2018-10-08 11:34:28 +01001501bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
1502 const TensorInfo& output,
1503 Optional<std::string&> reasonIfUnsupported) const
1504{
Jan Eilers8eb25602020-03-09 12:13:48 +00001505 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +01001506 bool supported = true;
1507
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001508 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +01001509 {
James Conroyb40d7102019-06-04 12:32:09 +01001510 DataType::Float32,
Teresa Charlin3a3a6bf2022-05-05 15:26:27 +01001511 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +01001512 };
1513
1514 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1515 "Reference Floor: input type not supported.");
1516
1517 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1518 "Reference Floor: output type not supported.");
1519
1520 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001521}
1522
1523bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1524 const TensorInfo& output,
1525 const TensorInfo& weights,
1526 const TensorInfo& biases,
1527 const FullyConnectedDescriptor& descriptor,
1528 Optional<std::string&> reasonIfUnsupported) const
1529{
Francis Murtagh46c09d02019-05-28 08:15:28 +01001530 bool supported = true;
1531
1532 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001533 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +01001534 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001535 DataType::Float32,
1536 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001537 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001538 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001539 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +01001540 };
1541
1542 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1543 "Reference Fully Connected: input type not supported.");
1544
1545 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1546 "Reference Fully Connected: output type not supported.");
1547
Francis Murtagh46c09d02019-05-28 08:15:28 +01001548 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1549 "Reference Fully Connected: weights type not supported.");
1550
Ryan OShea31441592022-11-07 16:20:48 +00001551 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1552 "Reference Fully Connected: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +00001553
Jan Eilers1f45dc32020-06-15 11:43:03 +01001554 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1555 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001556
Jan Eilers1f45dc32020-06-15 11:43:03 +01001557 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1558 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001559
1560 if (descriptor.m_BiasEnabled)
1561 {
1562 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001563 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001564 supportedBiasTypes =
1565 {
1566 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001567 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001568 DataType::Signed32,
1569 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001570 };
1571
1572 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1573 "Reference Fully Connected: bias type not supported.");
1574
1575 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1576 "Reference Fully Connected: bias and weight types mismatch.");
1577
1578 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1579 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1580
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001581 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1582 "Reference Fully Connected: bias must have 1 dimension.");
1583
Francis Murtagh46c09d02019-05-28 08:15:28 +01001584 }
1585
1586 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001587}
1588
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001589bool RefLayerSupport::IsGatherNdSupported(const armnn::TensorInfo& input0,
1590 const armnn::TensorInfo& input1,
1591 const armnn::TensorInfo& output,
1592 armnn::Optional<std::string&> reasonIfUnsupported) const
1593{
1594 bool supported = true;
1595 std::array<DataType,7> supportedTypes =
1596 {
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001597 DataType::Float32,
1598 DataType::Float16,
1599 DataType::QAsymmS8,
1600 DataType::QAsymmU8,
1601 DataType::QSymmS16,
1602 DataType::Signed32
1603 };
1604
1605 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1606 "Reference GatherNd: input type not supported");
1607
1608 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1609 "Reference GatherNd: output type not supported");
1610
1611 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1612 "Reference GatherNd: indices (input1) type not supported");
1613
1614 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1615 "Reference GatherNd: input and output types not matching");
1616
1617 return supported;
1618}
1619
narpra014951d842019-01-18 16:53:53 +00001620bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1621 const armnn::TensorInfo& input1,
1622 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001623 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001624 armnn::Optional<std::string&> reasonIfUnsupported) const
1625{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001626 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001627 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001628 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001629 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001630 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001631 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001632 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001633 DataType::QSymmS16,
1634 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001635 };
1636
Nikhil Raj369d8fc2022-11-24 13:12:36 +00001637 IgnoreUnused(descriptor);
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001638 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1639 "Reference Gather: input type not supported");
1640
1641 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1642 "Reference Gather: output type not supported");
1643
1644 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1645 "Reference Gather: indices (input1) type not supported");
1646
1647 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1648 "Reference Gather: input and output types not matching");
1649
1650 return supported;
narpra014951d842019-01-18 16:53:53 +00001651}
1652
// Input layers perform no computation in the reference backend, so any tensor
// configuration is accepted unconditionally.
bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
1658
Kevin May09ca49c2019-10-09 12:37:34 +01001659bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1660 const TensorInfo& output,
1661 const InstanceNormalizationDescriptor& descriptor,
1662 Optional<std::string&> reasonIfUnsupported) const
1663{
Jan Eilers8eb25602020-03-09 12:13:48 +00001664 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001665 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001666 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001667 {
1668 DataType::Float32,
1669 DataType::Float16
1670 };
1671
1672 bool supported = true;
1673
1674 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1675 "Reference Instance Normalization: input type not supported.");
1676
1677 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1678 "Reference Instance Normalization: output type not supported.");
1679
1680 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1681 "Reference Instance Normalization: input and output types mismatched.");
1682
1683 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1684 "Reference Instance Normalization: input and output shapes have different "
1685 "num total elements.");
1686
1687 return supported;
1688}
1689
arovir011c7c81b2018-10-08 11:34:28 +01001690bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1691 const TensorInfo& output,
1692 const L2NormalizationDescriptor& descriptor,
1693 Optional<std::string&> reasonIfUnsupported) const
1694{
Jan Eilers8eb25602020-03-09 12:13:48 +00001695 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001696 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001697 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001698 {
1699 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001700 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001701 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001702 DataType::QAsymmU8,
1703 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001704 };
1705
1706 bool supported = true;
1707
1708 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1709 "Reference L2normalization: input type not supported.");
1710
1711 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1712 "Reference L2normalization: output type not supported.");
1713
1714 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1715 "Reference L2normalization: input and output types mismatched.");
1716
1717 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1718 "Reference L2normalization: input and output shapes have different "
1719 "num total elements.");
1720
1721 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001722}
1723
James Conroyaba90cd2020-11-06 16:28:18 +00001724bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1725 const TensorInfo& input1,
1726 const TensorInfo& output,
1727 const LogicalBinaryDescriptor& descriptor,
1728 Optional<std::string&> reasonIfUnsupported) const
1729{
1730 IgnoreUnused(descriptor);
1731
1732 std::array<DataType, 1> supportedTypes =
1733 {
1734 DataType::Boolean
1735 };
1736
1737 bool supported = true;
1738 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1739 "Reference LogicalBinary: input 0 type not supported");
1740 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1741 "Reference LogicalBinary: input 1 type not supported");
1742
1743 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1744 "Reference LogicalBinary: input and output types do not match");
1745
1746 return supported;
1747}
1748
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001749bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1750 const TensorInfo& output,
1751 const LogSoftmaxDescriptor& descriptor,
1752 Optional<std::string&> reasonIfUnsupported) const
1753{
Jan Eilers8eb25602020-03-09 12:13:48 +00001754 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001755
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001756 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001757 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001758 DataType::Float32,
1759 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001760 };
1761
1762 bool supported = true;
1763 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1764 "Reference LogSoftmax: input type not supported");
1765
1766 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1767 "Reference LogSoftmax: output type not supported");
1768
1769 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1770 "Reference LogSoftmax: input and output types do not match");
1771
1772 return supported;
1773}
1774
// Validates that the reference backend can execute an LSTM layer with the given
// tensors and parameters: every state/output tensor and every weight/bias in
// paramsInfo must share the input's data type, with optional parameter groups
// checked only when the corresponding descriptor flag enables them.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);

    bool supported = true;

    // NOTE(review): the declared size is 3 but only two initializers are given, so
    // the trailing element is value-initialized to DataType(0), widening the set
    // of accepted input types beyond the two listed here. Confirm whether that
    // third type is intentional before shrinking the array to 2.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QSymmS16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    // All state and output tensors must match the input's data type.
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    // Input-gate parameters exist only when CIFG (coupled input-forget gate) is off.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    // Peephole connections add cell-to-gate weights.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    // Projection layer parameters; the projection bias itself is optional.
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    // Layer-normalization weights; the input-gate set again depends on CIFG being off.
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}
1885
saoste012df12b32018-11-28 16:57:20 +00001886bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1887 const TensorInfo& input1,
1888 const TensorInfo& output,
1889 Optional<std::string&> reasonIfUnsupported) const
1890{
Sadik Armagan2999a022019-04-09 14:20:12 +01001891 bool supported = true;
1892
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001893 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001894 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001895 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001896 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001897 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001898 DataType::QSymmS16,
1899 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001900 };
1901
1902 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1903 "Reference maximum: input 0 is not a supported type.");
1904
1905 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1906 "Reference maximum: input 1 is not a supported type.");
1907
1908 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1909 "Reference maximum: output is not a supported type.");
1910
1911 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1912 "Reference maximum: input 0 and Input 1 types are mismatched");
1913
1914 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1915 "Reference maximum: input and output types are mismatched");
1916
1917 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1918 "Reference maximum: shapes are not suitable for implicit broadcast.");
1919
1920 return supported;
saoste012df12b32018-11-28 16:57:20 +00001921}
1922
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001923bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1924 const TensorInfo& output,
1925 const MeanDescriptor& descriptor,
1926 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001927{
James Conroy4d1ff582019-06-10 17:06:39 +01001928 bool supported = true;
1929 std::string meanLayerStr = "Mean";
1930 std::string outputTensorStr = "output";
1931
Sadik Armagan303980c2020-04-17 12:45:14 +01001932 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001933 {
1934 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001935 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001936 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001937 DataType::QAsymmU8,
1938 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001939 };
1940
1941 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1942 "Reference Mean: input type not supported.");
1943
1944 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1945 "Reference Mean: input and output types are mismatched");
1946
1947 if (descriptor.m_KeepDims)
1948 {
1949 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1950 reasonIfUnsupported,
1951 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1952 output.GetNumDimensions(),
1953 meanLayerStr, outputTensorStr).data());
1954 }
1955 else if (descriptor.m_Axis.empty())
1956 {
1957 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1958 reasonIfUnsupported,
1959 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1960 meanLayerStr, outputTensorStr).data());
1961 }
1962 else
1963 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001964 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001965
1966 if (outputDim > 0)
1967 {
1968 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1969 reasonIfUnsupported,
1970 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1971 meanLayerStr, outputTensorStr).data());
1972 }
1973 else
1974 {
1975 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1976 reasonIfUnsupported,
1977 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1978 meanLayerStr, outputTensorStr).data());
1979 }
1980 }
1981
1982 return supported;
narpra0132b90462018-09-13 11:07:48 +01001983}
1984
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001985bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1986 const TensorInfo &output,
1987 Optional<std::string &> reasonIfUnsupported) const
1988{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001989 bool supported = true;
1990
Sadik Armagan303980c2020-04-17 12:45:14 +01001991 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001992 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001993 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001994 DataType::Float32,
1995 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001996 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001997 DataType::QAsymmU8,
1998 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001999 DataType::Boolean
2000 };
2001
2002 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2003 "Reference MemCopy: input type not supported");
2004
2005 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2006 "Reference MemCopy: output type not supported");
2007
2008 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2009 "Reference MemCopy: input and output types are mismatched");
2010
2011 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002012}
2013
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002014bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
2015 const TensorInfo& input1,
2016 const TensorInfo& output,
2017 Optional<std::string&> reasonIfUnsupported) const
2018{
Sadik Armagan2999a022019-04-09 14:20:12 +01002019 bool supported = true;
2020
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002021 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002022 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002023 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002024 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002025 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002026 DataType::QSymmS16,
2027 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002028 };
2029
2030 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2031 "Reference minimum: input 0 is not a supported type.");
2032
2033 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2034 "Reference minimum: input 1 is not a supported type.");
2035
2036 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2037 "Reference minimum: output is not a supported type.");
2038
2039 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2040 "Reference minimum: input 0 and Input 1 types are mismatched");
2041
2042 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2043 "Reference minimum: input and output types are mismatched");
2044
2045 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2046 "Reference minimum: shapes are not suitable for implicit broadcast.");
2047
2048 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00002049}
2050
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002051bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
2052 const TensorInfo& input1,
2053 const TensorInfo& output,
2054 Optional<std::string&> reasonIfUnsupported) const
2055{
Sadik Armagan2999a022019-04-09 14:20:12 +01002056 bool supported = true;
2057
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002058 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002059 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002060 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002061 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002062 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002063 DataType::QSymmS16,
2064 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002065 };
2066
2067 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2068 "Reference multiplication: input 0 is not a supported type.");
2069
2070 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2071 "Reference multiplication: input 1 is not a supported type.");
2072
2073 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2074 "Reference multiplication: output is not a supported type.");
2075
2076 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2077 "Reference multiplication: input 0 and Input 1 types are mismatched");
2078
2079 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2080 "Reference multiplication: input and output types are mismatched");
2081
2082 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2083 "Reference multiplication: shapes are not suitable for implicit broadcast.");
2084
2085 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002086}
2087
2088bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
2089 const TensorInfo& output,
2090 const NormalizationDescriptor& descriptor,
2091 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01002092{
Jan Eilers8eb25602020-03-09 12:13:48 +00002093 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002094
2095 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01002096 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002097 {
2098 DataType::Float16,
2099 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002100 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002101 DataType::QAsymmU8,
2102 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01002103 };
2104
2105 bool supported = true;
2106
2107 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2108 "Reference normalization: input type not supported.");
2109
2110 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2111 "Reference normalization: output type not supported.");
2112
2113 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2114 "Reference normalization: input and output shapes have different "
2115 "num total elements.");
2116
2117 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002118}
2119
// The reference backend accepts any Output layer unconditionally: no
// data-type or shape restrictions are applied, so neither parameter is
// examined (hence both are intentionally unnamed).
bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
                                        Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}
2125
2126bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
2127 const TensorInfo& output,
2128 const PadDescriptor& descriptor,
2129 Optional<std::string&> reasonIfUnsupported) const
2130{
Jan Eilers8eb25602020-03-09 12:13:48 +00002131 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002132 bool supported = true;
2133
2134 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002135 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002136 {
2137 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002138 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002139 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002140 DataType::QAsymmU8,
2141 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01002142 };
2143
2144 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2145 "Reference pad: input is not a supported type.");
2146
2147 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2148 "Reference pad: output is not a supported type.");
2149
2150 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2151 "Reference pad: input and output types are mismatched.");
2152
2153 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01002154}
2155
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002156bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
2157 const TensorInfo& output,
2158 const PermuteDescriptor& descriptor,
2159 Optional<std::string&> reasonIfUnsupported) const
2160{
Jan Eilers8eb25602020-03-09 12:13:48 +00002161 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002162 bool supported = true;
2163
2164 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002165 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002166 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002167 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002168 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002169 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002170 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002171 DataType::QAsymmU8,
2172 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01002173 };
2174
2175 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2176 "Reference permute: input is not a supported type.");
2177
2178 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2179 "Reference permute: output is not a supported type.");
2180
2181 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2182 "Reference permute: input and output types are mismatched.");
2183
2184 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002185}
2186
2187bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
2188 const TensorInfo& output,
2189 const Pooling2dDescriptor& descriptor,
2190 Optional<std::string&> reasonIfUnsupported) const
2191{
Jan Eilers8eb25602020-03-09 12:13:48 +00002192 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01002193 bool supported = true;
2194
2195 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002196 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01002197 {
2198 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01002199 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002200 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002201 DataType::QAsymmU8,
2202 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01002203 };
2204
2205 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2206 "Reference poolind2d: input is not a supported type.");
2207
2208 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2209 "Reference poolind2d: output is not a supported type.");
2210
2211 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2212 "Reference poolind2d: input and output types are mismatched.");
2213
2214 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002215}
2216
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002217bool RefLayerSupport::IsPooling3dSupported(const TensorInfo& input,
2218 const TensorInfo& output,
2219 const Pooling3dDescriptor& descriptor,
2220 Optional<std::string&> reasonIfUnsupported) const
2221{
2222 IgnoreUnused(descriptor);
2223 bool supported = true;
2224
2225 // Define supported output and inputs types.
2226 std::array<DataType,6> supportedTypes =
2227 {
Tamás Nyíri7b885b32021-10-26 14:47:57 +01002228 DataType::Float32,
2229 DataType::Float16,
2230 DataType::QAsymmS8,
2231 DataType::QAsymmU8,
2232 DataType::QSymmS16
2233 };
2234
2235 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2236 "Reference poolind3d: input is not a supported type.");
2237
2238 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2239 "Reference poolind3d: output is not a supported type.");
2240
2241 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2242 "Reference poolind3d: input and output types are mismatched.");
2243
2244 return supported;
2245}
2246
2247
James Conroy4f1f8992020-04-29 20:01:10 +01002248bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
2249 const TensorInfo& previousOutputIn,
2250 const TensorInfo& previousCellStateIn,
2251 const TensorInfo& outputStateOut,
2252 const TensorInfo& cellStateOut,
2253 const TensorInfo& output,
2254 const QLstmDescriptor& descriptor,
2255 const LstmInputParamsInfo& paramsInfo,
2256 Optional<std::string&> reasonIfUnsupported) const
2257{
2258 IgnoreUnused(input);
2259 IgnoreUnused(previousOutputIn);
2260 IgnoreUnused(previousCellStateIn);
2261 IgnoreUnused(outputStateOut);
2262 IgnoreUnused(cellStateOut);
2263 IgnoreUnused(output);
2264 IgnoreUnused(descriptor);
2265 IgnoreUnused(paramsInfo);
2266
2267 IgnoreUnused(reasonIfUnsupported);
2268
2269 return true;
2270}
2271
Derek Lamberti5f400d62019-03-25 15:41:58 +00002272bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
2273 const TensorInfo& output,
2274 Optional<std::string&> reasonIfUnsupported) const
2275{
2276 bool supported = true;
2277
Finn Williamsfd271062019-12-04 14:27:27 +00002278 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002279 std::array<DataType,7> supportedInputTypes = {
Keith Davis5e51cd82020-01-29 16:52:59 +00002280 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00002281 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00002282 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00002283 DataType::QAsymmU8,
2284 DataType::QSymmS8,
2285 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002286 };
2287
2288 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
2289 "Reference quantize: input type not supported.");
2290
2291 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00002292 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00002293 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002294 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00002295 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002296 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00002297 };
2298 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2299 "Reference quantize: output type not supported.");
2300
2301 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
2302 "Reference quantize: input and output shapes have different num total elements.");
2303
2304 return supported;
2305}
2306
Finn Williams2605b232020-06-10 15:53:46 +01002307bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
2308 const TensorInfo& output,
2309 Optional<std::string&> reasonIfUnsupported) const
2310{
2311 IgnoreUnused(input);
2312 // Define supported output types.
2313 std::array<DataType,1> supportedOutputTypes =
2314 {
2315 DataType::Signed32,
2316 };
2317
2318 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
2319 "Reference rank: input type not supported.");
2320}
2321
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002322bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
2323 const TensorInfo& output,
2324 const ReduceDescriptor& descriptor,
2325 Optional<std::string&> reasonIfUnsupported) const
2326{
2327 IgnoreUnused(descriptor);
2328 bool supported = true;
2329 std::array<DataType,7> supportedTypes =
2330 {
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00002331 DataType::Float32,
2332 DataType::Float16,
2333 DataType::QAsymmS8,
2334 DataType::QAsymmU8,
2335 DataType::QSymmS16,
2336 DataType::Signed32
2337 };
2338
2339 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2340 "Reference Reduce: input type not supported");
2341
2342 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2343 "Reference Reduce: output type not supported");
2344
2345 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2346 "Reference Reduce: input and output types not matching");
2347
2348 return supported;
2349}
2350
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002351bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00002352 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00002353 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002354 Optional<std::string&> reasonIfUnsupported) const
2355{
Jan Eilers8eb25602020-03-09 12:13:48 +00002356 IgnoreUnused(output);
2357 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01002358 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002359 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01002360 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002361 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01002362 DataType::Float32,
2363 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01002364 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00002365 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002366 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00002367 DataType::QSymmS16,
2368 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01002369 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00002370
Nina Drozd2f2778f2019-05-27 10:37:05 +01002371 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
2372 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002373}
2374
Teresa Charlin970f43b2019-07-01 13:51:07 +01002375bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
2376 const TensorInfo& output,
2377 const ResizeDescriptor& descriptor,
2378 Optional<std::string&> reasonIfUnsupported) const
2379{
Jan Eilers8eb25602020-03-09 12:13:48 +00002380 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01002381 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002382 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01002383 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002384 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01002385 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002386 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00002387 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002388 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002389 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01002390 };
2391
2392 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2393 "Reference Resize: input type not supported");
2394
2395 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2396 "Reference Resize: output type not supported");
2397
2398 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2399 "Reference Resize: input and output types not matching");
2400
2401 return supported;
2402}
2403
Tracy Narinebb8d7592023-07-13 16:50:54 +01002404bool RefLayerSupport::IsReverseV2Supported(const TensorInfo& input0,
2405 const TensorInfo& input1,
Tianle Cheng988354d2023-06-28 13:20:47 +01002406 const TensorInfo& output,
Tianle Cheng988354d2023-06-28 13:20:47 +01002407 Optional<std::string&> reasonIfUnsupported) const
2408{
Tianle Cheng988354d2023-06-28 13:20:47 +01002409 bool supported = true;
2410 // ReverseV2 is data type agnostic so it can support all the types in the Reference backend
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002411 std::array<DataType,8> supportedTypes =
Tianle Cheng988354d2023-06-28 13:20:47 +01002412 {
2413 DataType::BFloat16,
2414 DataType::Float32,
2415 DataType::Float16,
2416 DataType::QAsymmS8,
2417 DataType::QAsymmU8,
Declan-ARM1bf56cd2023-07-20 17:32:57 +01002418 DataType::QSymmS8,
2419 DataType::QSymmS16,
2420 DataType::Signed32
Tianle Cheng988354d2023-06-28 13:20:47 +01002421 };
2422
Tracy Narinebb8d7592023-07-13 16:50:54 +01002423 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2424 "Reference ReverseV2: input0 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002425
2426 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2427 "Reference ReverseV2: output type not supported");
2428
Tracy Narinebb8d7592023-07-13 16:50:54 +01002429 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2430 "Reference ReverseV2: input0 and output types not matching");
2431
2432 std::array<DataType,6> input2SupportedTypes =
2433 {
2434 DataType::Signed32
2435 };
2436
2437 supported &= CheckSupportRule(TypeAnyOf(input1, input2SupportedTypes), reasonIfUnsupported,
2438 "Reference ReverseV2: input1 type not supported");
Tianle Cheng988354d2023-06-28 13:20:47 +01002439
2440 return supported;
2441}
2442
Keith Davis3ae3f972021-05-21 16:33:48 +01002443bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
2444 const TensorInfo& output,
2445 Optional<std::string&> reasonIfUnsupported) const
2446{
2447 IgnoreUnused(input);
2448 bool supported = true;
2449
2450 std::array<DataType, 1> supportedTypes =
2451 {
2452 DataType::Signed32
2453 };
2454
2455 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2456 "Reference Shape: output type not supported");
2457
2458 return supported;
2459}
2460
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002461bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
2462 const TensorInfo& output,
2463 const SliceDescriptor& descriptor,
2464 Optional<std::string&> reasonIfUnsupported) const
2465{
Jan Eilers8eb25602020-03-09 12:13:48 +00002466 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002467 bool supported = true;
2468
Sadik Armagan303980c2020-04-17 12:45:14 +01002469 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002470 {
2471 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002472 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002473 DataType::QAsymmU8,
Ryan OShea980446b2023-06-08 16:23:28 +01002474 DataType::QSymmS16,
2475 DataType::Signed32
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01002476 };
2477
2478 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2479 "Reference Slice: input type not supported");
2480
2481 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2482 "Reference Slice: output type not supported");
2483
2484 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2485 "Reference Slice: input and output types are mismatched");
2486
2487 return supported;
2488}
2489
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002490bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
2491 const TensorInfo& output,
2492 const SoftmaxDescriptor& descriptor,
2493 Optional<std::string&> reasonIfUnsupported) const
2494{
Jan Eilers8eb25602020-03-09 12:13:48 +00002495 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01002496 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002497 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01002498 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002499 DataType::Float32,
2500 DataType::Float16,
2501 DataType::QSymmS8,
2502 DataType::QAsymmS8,
2503 DataType::QAsymmU8,
2504 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01002505 };
2506
2507 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002508 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002509
2510 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002511 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002512
2513 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01002514 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01002515
2516 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002517}
2518
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002519bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
2520 const TensorInfo& output,
2521 const SpaceToBatchNdDescriptor& descriptor,
2522 Optional<std::string&> reasonIfUnsupported) const
2523{
Jan Eilers8eb25602020-03-09 12:13:48 +00002524 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01002525 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002526 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01002527 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002528 DataType::Float32,
2529 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002530 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002531 DataType::QAsymmU8,
2532 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01002533 };
2534
2535 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2536 "Reference SpaceToBatchNd: input type not supported");
2537
2538 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2539 "Reference SpaceToBatchNd: output type not supported");
2540
2541 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2542 "Reference SpaceToBatchNd: input and output types are mismatched");
2543
2544 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002545}
2546
Keith Davisa57eccb2019-06-14 17:33:22 +01002547bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002548 const TensorInfo& output,
2549 const SpaceToDepthDescriptor& descriptor,
2550 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002551{
2552
Jan Eilers8eb25602020-03-09 12:13:48 +00002553 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002554 bool supported = true;
2555
Sadik Armagan303980c2020-04-17 12:45:14 +01002556 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002557 {
2558 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002559 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002560 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002561 DataType::QAsymmU8,
2562 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002563 };
2564
2565 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2566 "Reference SpaceToDepth: input type not supported");
2567
2568 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2569 "Reference SpaceToDepth: output type not supported");
2570
2571 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2572 "Reference SpaceToDepth: input and output types are mismatched");
2573
2574 return supported;
2575}
2576
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002577bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002578 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2579 const ViewsDescriptor& descriptor,
2580 Optional<std::string&> reasonIfUnsupported) const
2581{
Jan Eilers8eb25602020-03-09 12:13:48 +00002582 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002583 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002584 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002585 {
2586 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002587 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002588 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002589 DataType::QAsymmU8,
2590 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002591 };
2592
2593 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2594 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002595 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002596 {
2597 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2598 "Reference splitter: input type not supported");
2599
2600 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2601 "Reference splitter: input and output types mismatched.");
2602 }
2603
2604 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002605}
2606
Matthew Jackson81e601c2019-07-11 12:07:09 +01002607bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2608 const TensorInfo& output,
2609 const StackDescriptor& descriptor,
2610 Optional<std::string&> reasonIfUnsupported) const
2611{
Jan Eilers8eb25602020-03-09 12:13:48 +00002612 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002613
2614 bool supported = true;
Sadik Armagan529195f2022-01-14 12:56:35 +00002615 std::array<DataType,7> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002616 {
2617 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002618 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002620 DataType::QAsymmU8,
Sadik Armagan529195f2022-01-14 12:56:35 +00002621 DataType::QSymmS16,
2622 DataType::Signed32
Matthew Jackson81e601c2019-07-11 12:07:09 +01002623 };
2624
2625 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2626 "Reference stack: output type not supported");
2627 for (const TensorInfo* input : inputs)
2628 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002629 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002630 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2631 "Reference stack: input type not supported");
2632
2633 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2634 "Reference stack: input and output types mismatched.");
2635 }
2636
2637 return supported;
2638}
2639
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002640bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2641 const TensorInfo& output,
2642 const StridedSliceDescriptor& descriptor,
2643 Optional<std::string&> reasonIfUnsupported) const
2644{
Jan Eilers8eb25602020-03-09 12:13:48 +00002645 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002646 bool supported = true;
2647
Sadik Armagan303980c2020-04-17 12:45:14 +01002648 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002649 {
2650 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002652 DataType::QAsymmU8,
2653 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002654 };
2655
2656 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2657 "Reference StridedSlice: input type not supported");
2658
2659 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2660 "Reference StridedSlice: output type not supported");
2661
2662 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2663 "Reference StridedSlice: input and output types are mismatched");
2664
2665 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002666}
2667
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002668bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2669 const TensorInfo& input1,
2670 const TensorInfo& output,
2671 Optional<std::string&> reasonIfUnsupported) const
2672{
Sadik Armagan2999a022019-04-09 14:20:12 +01002673 bool supported = true;
2674
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002675 std::array<DataType,7> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01002676 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002677 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002678 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002679 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002680 DataType::QSymmS16,
2681 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002682 };
2683
2684 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2685 "Reference subtraction: input 0 is not a supported type.");
2686
2687 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2688 "Reference subtraction: input 1 is not a supported type.");
2689
2690 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2691 "Reference subtraction: output is not a supported type.");
2692
2693 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2694 "Reference subtraction: input 0 and Input 1 types are mismatched");
2695
2696 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2697 "Reference subtraction: input and output types are mismatched");
2698
2699 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2700 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2701
2702 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002703}
2704
Matteo Martincighab9e5252019-06-13 17:27:46 +01002705bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2706 const TensorInfo& alpha,
2707 const TensorInfo& output,
2708 Optional<std::string&> reasonIfUnsupported) const
2709{
2710 bool supported = true;
2711
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002712 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002713 {
2714 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002715 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002716 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002717 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002718 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002719 };
2720
2721 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2722 "PReLU: input is not a supported type.");
2723
2724 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2725 "PReLU: alpha is not a supported type.");
2726
2727 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2728 "PReLU: output is not a supported type.");
2729
2730 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2731 "PReLU: input, alpha and output types are mismatched");
2732
2733 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2734 "PReLU: shapes are not suitable for implicit broadcast");
2735
2736 return supported;
2737}
2738
Teresa Charlin79a06a52023-07-13 17:16:45 +01002739bool RefLayerSupport::IsTileSupported(const TensorInfo& input,
2740 const TensorInfo& output,
2741 const TileDescriptor& descriptor,
2742 Optional<std::string&> reasonIfUnsupported) const
2743{
2744 IgnoreUnused(descriptor);
2745
2746 bool supported = true;
2747
2748 std::array<DataType, 7> supportedTypes
2749 {
2750 DataType::Float32,
2751 DataType::Float16,
2752 DataType::QAsymmS8,
2753 DataType::QAsymmU8,
2754 DataType::QSymmS8,
2755 DataType::QSymmS16,
2756 DataType::Signed32
2757 };
2758
2759 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2760 "Tile: input type not supported.");
2761
2762 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2763 "Tile: output type not supported");
2764
2765 return supported;
2766}
2767
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002768bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2769 const TensorInfo& output,
2770 const TransposeConvolution2dDescriptor& descriptor,
2771 const TensorInfo& weights,
2772 const Optional<TensorInfo>& biases,
2773 Optional<std::string&> reasonIfUnsupported) const
2774{
Jan Eilers8eb25602020-03-09 12:13:48 +00002775 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002776 bool supported = true;
2777
Sadik Armagan303980c2020-04-17 12:45:14 +01002778 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002779 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002780 DataType::Float32,
2781 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002782 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002783 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002784 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002785 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002786 };
2787
2788 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2789 "Reference TransposeConvolution2d: input is not a supported type.");
2790
2791 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2792 "Reference TransposeConvolution2d: output is not a supported type.");
2793
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002794 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2795 "Reference TransposeConvolution2d: input and output types mismatched.");
2796
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002797
2798 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002799 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002800 {
Jan Eilers1b2654f2021-09-24 15:45:46 +01002801 std::array<DataType, 3> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002802 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002803 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002804 DataType::QAsymmU8,
Jan Eilers1b2654f2021-09-24 15:45:46 +01002805 DataType::QSymmS8
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002806 };
2807
2808 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2809 "Reference TransposeConvolution2d: weights type not supported for "
2810 "quantized input.");
2811 }
2812 else
2813 {
2814 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2815 "Reference TransposeConvolution2d: weights is not a supported type.");
2816
2817 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2818 "Reference TransposeConvolution2d: input and weights types mismatched.");
2819 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002820
2821 if (biases.has_value())
2822 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002823 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002824 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002825 DataType::Float32,
2826 DataType::Float16,
2827 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002828 };
2829 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2830 "Reference TransposeConvolution2d: biases is not a supported type.");
2831 }
2832
2833 return supported;
2834}
2835
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002836bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2837 const TensorInfo& output,
2838 const TransposeDescriptor& descriptor,
2839 Optional<std::string&> reasonIfUnsupported) const
2840{
Jan Eilers8eb25602020-03-09 12:13:48 +00002841 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002842 bool supported = true;
2843
2844 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002845 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002846 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002847 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002848 DataType::Float32,
2849 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002850 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002851 DataType::QAsymmU8,
2852 DataType::QSymmS16
2853 };
2854
2855 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2856 "Reference transpose: input is not a supported type.");
2857
2858 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2859 "Reference transpose: output is not a supported type.");
2860
2861 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2862 "Reference transpose: input and output types are mismatched.");
2863
2864 return supported;
2865}
2866
// Checks whether the reference backend can run a UnidirectionalSequenceLstm
// layer with the given tensors, descriptor and weight/bias parameters.
// Returns true only if every applicable type rule passes; each failing rule
// may report an explanatory message through reasonIfUnsupported.
// Note: only data types are validated here — shapes and the state tensors
// (outputStateIn/Out, cellStateIn/Out) are not checked.
bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
    const TensorInfo& input,
    const TensorInfo& outputStateIn,
    const TensorInfo& cellStateIn,
    const TensorInfo& outputStateOut,
    const TensorInfo& cellStateOut,
    const TensorInfo& output,
    const UnidirectionalSequenceLstmDescriptor& descriptor,
    const LstmInputParamsInfo& paramsInfo,
    Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    IgnoreUnused(paramsInfo);
    IgnoreUnused(outputStateIn);
    IgnoreUnused(cellStateIn);
    IgnoreUnused(outputStateOut);
    IgnoreUnused(cellStateOut);
    bool supported = true;

    // Accepted types for the sequence input and output tensors.
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8
    };

    // Accepted types for all weight tensors.
    std::array<DataType, 2> supportedWeightTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8
    };

    // Accepted types for bias tensors (Signed32 covers quantized paths).
    std::array<DataType, 3> supportedBiasTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::Signed32
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: output is not a supported type.");

    // check layer parameters
    // Mandatory gate weights: forget, cell and output gates must always be present.
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
                                  "is not a supported type.");

    // Mandatory gate biases.
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
    // Input-gate parameters exist only when CIFG (coupled input/forget gate) is disabled.
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: InputToInputWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
        // CellToInput peephole weights only apply when the input gate exists.
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
                                          reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: CellToInputWeights "
                                          "is not a supported type.");
        }
    }
    // Peephole connections from the cell state to forget/output gates.
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
                                      "is not a supported type.");
    }
    // Optional output projection; the projection bias is itself optional and,
    // when present, must match the input's data type.
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: ProjectionWeights "
                                      "is not a supported type.");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
                                          "are mismatched");
        }
    }
    // Per-gate layer normalisation weights; the input-gate set again depends on CIFG.
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
                                          reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
                                          "is not a supported type.");
        }
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
                                      "is not a supported type.");
    }

    return supported;
}
3011
arovir011c7c81b2018-10-08 11:34:28 +01003012} // namespace armnn