//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <LayerSupportCommon.hpp>
#include <backendsCommon/LayerSupportRules.hpp>

#include <vector>
#include <array>

namespace armnn
{

namespace
{

template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

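// Each Is<Layer>Supported method below follows the same pattern: it accumulates the results of
// CheckSupportRule(...) calls into a single 'supported' flag, appending a human-readable reason
// to reasonIfUnsupported whenever an individual rule (data type, shape or descriptor check) fails.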
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return IsElementwiseUnarySupported(input,
                                       output,
                                       ElementwiseUnaryDescriptor(UnaryOperation::Abs),
                                       reasonIfUnsupported);
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

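    // Local rule: sets m_Res to true only for activation functions that have a reference implementation.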
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,7> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32,
        DataType::Signed64
    };

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Signed32,
        DataType::Signed64
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS8,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference cast: input is not a supported type");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
                                  "Reference cast: output is not a supported type");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference cast: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const ChannelShuffleDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    // Define supported input and output types.
    std::array<DataType, 7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ChannelShuffle: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference ChannelShuffle: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference ChannelShuffle: input and output types are mismatched.");

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::Boolean,
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        ARMNN_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,8> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
                                  "Reference for ConvertBf16ToFp32 layer: input type not supported");

    supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
                                  "Reference for ConvertBf16ToFp32 layer: output type not supported");

    return supported;
}

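// Note: the Fp16 <-> Fp32 conversion checks below use the older IsSupportedForDataTypeGeneric
// helper (per-data-type function pointers) rather than the CheckSupportRule pattern used elsewhere.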
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
                                  "Reference for ConvertFp32ToBf16 layer: input type not supported");

    supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
                                  "Reference for ConvertFp32ToBf16 layer: output type not supported");

    return supported;
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: output is not a supported type.");

    // For Convolution2d, BFloat16 input with Float32 output is allowed as an optimization.
    if (input.GetDataType() == DataType::BFloat16)
    {
        if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
        {
            reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
            supported = false;
        }
    }
    else
    {
        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                      "Reference Convolution2d: input and output types mismatched.");
    }

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference Convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: biases is not a supported type.");
    }
    IgnoreUnused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 8> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference for Debug layer: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedInputTypes = {
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input type not supported.");

    supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
                                  "Reference for Dequantize layer: per-axis quantized input not supported.");

    std::array<DataType,3> supportedOutputTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input/output shapes have different num total "
                                  "elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,6> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,7> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  const ElementwiseUnaryDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    std::array<DataType, 1> logicalSupportedTypes =
    {
        DataType::Boolean
    };

    bool supported = true;

    if (descriptor.m_Operation == UnaryOperation::LogicalNot)
    {
        supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: input type not supported");

        supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: output type not supported");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: input type not supported");

        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: output type not supported");
    }

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output shapes "
                                  "have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const FillDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    IgnoreUnused(output);

    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Fill: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fill: output type not supported.");
    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    // For FullyConnected, BFloat16 input with Float32 output is allowed as an optimization.
    if (input.GetDataType() == DataType::BFloat16)
    {
        if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
        {
            reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
            supported = false;
        }
    }
    else
    {
        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                      "Reference Fully Connected: input and output types mismatched.");
    }

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weights types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Define supported types for bias.
        std::array<DataType, 5> supportedBiasTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32,
            DataType::QAsymmS8
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
                                      "Reference Fully Connected: bias must have 1 dimension.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        const GatherDescriptor& descriptor,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    if (descriptor.m_Axis != 0)
    {
        reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
        supported &= false;
    }
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

arovir011c7c81b2018-10-08 11:34:28 +01001156bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1157 const TensorInfo& output,
1158 const L2NormalizationDescriptor& descriptor,
1159 Optional<std::string&> reasonIfUnsupported) const
1160{
Jan Eilers8eb25602020-03-09 12:13:48 +00001161 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001162 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001163 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001164 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001165 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001166 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001167 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001168 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001169 DataType::QAsymmU8,
1170 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001171 };
1172
1173 bool supported = true;
1174
1175 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1176 "Reference L2normalization: input type not supported.");
1177
1178 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1179 "Reference L2normalization: output type not supported.");
1180
1181 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1182 "Reference L2normalization: input and output types mismatched.");
1183
1184 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1185 "Reference L2normalization: input and output shapes have different "
1186 "num total elements.");
1187
1188 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001189}
1190
James Conroyaba90cd2020-11-06 16:28:18 +00001191bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1192 const TensorInfo& input1,
1193 const TensorInfo& output,
1194 const LogicalBinaryDescriptor& descriptor,
1195 Optional<std::string&> reasonIfUnsupported) const
1196{
1197 IgnoreUnused(descriptor);
1198
1199 std::array<DataType, 1> supportedTypes =
1200 {
1201 DataType::Boolean
1202 };
1203
1204 bool supported = true;
1205 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1206 "Reference LogicalBinary: input 0 type not supported");
1207 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1208 "Reference LogicalBinary: input 1 type not supported");
1209
1210 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1211 "Reference LogicalBinary: input and output types do not match");
1212
1213 return supported;
1214}
1215
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001216bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1217 const TensorInfo& output,
1218 const LogSoftmaxDescriptor& descriptor,
1219 Optional<std::string&> reasonIfUnsupported) const
1220{
Jan Eilers8eb25602020-03-09 12:13:48 +00001221 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001222
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001223 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001224 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001225 DataType::BFloat16,
1226 DataType::Float32,
1227 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001228 };
1229
1230 bool supported = true;
1231 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1232 "Reference LogSoftmax: input type not supported");
1233
1234 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1235 "Reference LogSoftmax: output type not supported");
1236
1237 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1238 "Reference LogSoftmax: input and output types do not match");
1239
1240 return supported;
1241}
1242
arovir011c7c81b2018-10-08 11:34:28 +01001243bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1244 const TensorInfo& outputStateIn,
1245 const TensorInfo& cellStateIn,
1246 const TensorInfo& scratchBuffer,
1247 const TensorInfo& outputStateOut,
1248 const TensorInfo& cellStateOut,
1249 const TensorInfo& output,
1250 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001251 const LstmInputParamsInfo& paramsInfo,
1252 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001253{
Jan Eilers8eb25602020-03-09 12:13:48 +00001254 IgnoreUnused(descriptor);
1255 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001256
1257 bool supported = true;
1258
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001259 std::array<DataType,3> supportedTypes = {
1260 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001261 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001262 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001263 };
1264
Jan Eilersd01a83c2019-07-03 18:20:40 +01001265 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001266 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1267 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001268 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1269 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001270 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1271 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001272 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1273 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001274 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1275 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001276 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1277 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001278
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001279 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1280 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001281 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001282 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001283 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001284 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001285 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001286 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001287 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001288 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001289 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001290 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001291 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001292 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001293 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001294 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001295 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001296 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001297 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001298 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001299 "Reference Lstm: input and OutputGateBias types are mismatched");
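// Without CIFG the input gate is present, so its weights and bias must also match the input type.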
1300 if (!descriptor.m_CifgEnabled)
1301 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001302 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001303 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001304 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001305 reasonIfUnsupported,
1306 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001307 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001308 "Reference Lstm: input and InputGateBias types are mismatched");
1309 if (descriptor.m_PeepholeEnabled)
1310 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001311 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001312 reasonIfUnsupported,
1313 "Reference Lstm: input and CellToInputWeights types are mismatched");
1314 }
1315 }
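// Peephole connections add cell-to-gate weights that must match the input type.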
1316 if (descriptor.m_PeepholeEnabled)
1317 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001318 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001319 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001320 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001321 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1322 }
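// Projection adds weights and an optional bias that must match the input type.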
1323 if (descriptor.m_ProjectionEnabled)
1324 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001325 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001326 "Reference Lstm: input and ProjectionWeights types are mismatched");
1327 if (paramsInfo.m_ProjectionBias != nullptr)
1328 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001329 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001330 "Reference Lstm: input and ProjectionBias types are mismatched");
1331 }
1332 }
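// Layer normalization weights are checked per gate; the input gate weights are skipped when CIFG is enabled.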
1333 if (descriptor.m_LayerNormEnabled)
1334 {
1335 if (!descriptor.m_CifgEnabled)
1336 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001337 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001338 reasonIfUnsupported,
1339 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1340 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001341 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001342 reasonIfUnsupported,
1343 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001344 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001345 reasonIfUnsupported,
1346 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001347 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001348 reasonIfUnsupported,
1349 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1350 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001351
1352 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001353}
1354
saoste012df12b32018-11-28 16:57:20 +00001355bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1356 const TensorInfo& input1,
1357 const TensorInfo& output,
1358 Optional<std::string&> reasonIfUnsupported) const
1359{
Sadik Armagan2999a022019-04-09 14:20:12 +01001360 bool supported = true;
1361
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001362 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001363 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001364 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001365 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001366 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001367 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001368 DataType::QSymmS16,
1369 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001370 };
1371
1372 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1373 "Reference maximum: input 0 is not a supported type.");
1374
1375 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1376 "Reference maximum: input 1 is not a supported type.");
1377
1378 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1379 "Reference maximum: output is not a supported type.");
1380
1381 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1382 "Reference maximum: input 0 and Input 1 types are mismatched");
1383
1384 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1385 "Reference maximum: input and output types are mismatched");
1386
1387 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1388 "Reference maximum: shapes are not suitable for implicit broadcast.");
1389
1390 return supported;
saoste012df12b32018-11-28 16:57:20 +00001391}
1392
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001393bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1394 const TensorInfo& output,
1395 const MeanDescriptor& descriptor,
1396 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001397{
James Conroy4d1ff582019-06-10 17:06:39 +01001398 bool supported = true;
1399 std::string meanLayerStr = "Mean";
1400 std::string outputTensorStr = "output";
1401
Sadik Armagan303980c2020-04-17 12:45:14 +01001402 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001403 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001404 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001405 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001406 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001407 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001408 DataType::QAsymmU8,
1409 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001410 };
1411
1412 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1413 "Reference Mean: input type not supported.");
1414
1415 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1416 "Reference Mean: input and output types are mismatched");
1417
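// Expected output rank: unchanged when KeepDims is set, reduced by the number of axes otherwise
// (falling back to rank 1 when every dimension is reduced).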
1418 if (descriptor.m_KeepDims)
1419 {
1420 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1421 reasonIfUnsupported,
1422 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1423 output.GetNumDimensions(),
1424 meanLayerStr, outputTensorStr).data());
1425 }
1426 else if (descriptor.m_Axis.empty())
1427 {
1428 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1429 reasonIfUnsupported,
1430 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1431 meanLayerStr, outputTensorStr).data());
1432 }
1433 else
1434 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001435 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001436
1437 if (outputDim > 0)
1438 {
1439 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1440 reasonIfUnsupported,
1441 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1442 meanLayerStr, outputTensorStr).data());
1443 }
1444 else
1445 {
1446 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1447 reasonIfUnsupported,
1448 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1449 meanLayerStr, outputTensorStr).data());
1450 }
1451 }
1452
1453 return supported;
narpra0132b90462018-09-13 11:07:48 +01001454}
1455
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001456bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001457 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001458 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001459 Optional<std::string&> reasonIfUnsupported) const
1460{
Jim Flynne242f2d2019-05-22 14:24:13 +01001461 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001462}
1463
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001464bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1465 const TensorInfo &output,
1466 Optional<std::string &> reasonIfUnsupported) const
1467{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001468 bool supported = true;
1469
Sadik Armagan303980c2020-04-17 12:45:14 +01001470 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001471 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001472 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001473 DataType::Float32,
1474 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001475 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001476 DataType::QAsymmU8,
1477 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001478 DataType::Boolean
1479 };
1480
1481 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1482 "Reference MemCopy: input type not supported");
1483
1484 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1485 "Reference MemCopy: output type not supported");
1486
1487 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1488 "Reference MemCopy: input and output types are mismatched");
1489
1490 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001491}
1492
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001493bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1494 const TensorInfo& input1,
1495 const TensorInfo& output,
1496 Optional<std::string&> reasonIfUnsupported) const
1497{
Sadik Armagan2999a022019-04-09 14:20:12 +01001498 bool supported = true;
1499
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001500 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001501 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001502 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001503 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001504 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001505 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001506 DataType::QSymmS16,
1507 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001508 };
1509
1510 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1511 "Reference minimum: input 0 is not a supported type.");
1512
1513 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1514 "Reference minimum: input 1 is not a supported type.");
1515
1516 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1517 "Reference minimum: output is not a supported type.");
1518
1519 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1520 "Reference minimum: input 0 and Input 1 types are mismatched");
1521
1522 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1523 "Reference minimum: input and output types are mismatched");
1524
1525 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1526 "Reference minimum: shapes are not suitable for implicit broadcast.");
1527
1528 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001529}
1530
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001531bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1532 const TensorInfo& input1,
1533 const TensorInfo& output,
1534 Optional<std::string&> reasonIfUnsupported) const
1535{
Sadik Armagan2999a022019-04-09 14:20:12 +01001536 bool supported = true;
1537
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001538 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001539 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001540 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001541 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001542 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001543 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001544 DataType::QSymmS16,
1545 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001546 };
1547
1548 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1549 "Reference multiplication: input 0 is not a supported type.");
1550
1551 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1552 "Reference multiplication: input 1 is not a supported type.");
1553
1554 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1555 "Reference multiplication: output is not a supported type.");
1556
1557 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1558 "Reference multiplication: input 0 and Input 1 types are mismatched");
1559
1560 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1561 "Reference multiplication: input and output types are mismatched");
1562
1563 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1564 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1565
1566 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001567}
1568
1569bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1570 const TensorInfo& output,
1571 const NormalizationDescriptor& descriptor,
1572 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001573{
Jan Eilers8eb25602020-03-09 12:13:48 +00001574 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001575
1576 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001577 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001578 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001579 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001580 DataType::Float16,
1581 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001582 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001583 DataType::QAsymmU8,
1584 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001585 };
1586
1587 bool supported = true;
1588
1589 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1590 "Reference normalization: input type not supported.");
1591
1592 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1593 "Reference normalization: output type not supported.");
1594
1595 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1596 "Reference normalization: input and output shapes have different "
1597 "num total elements.");
1598
1599 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001600}
1601
Derek Lamberti901ea112019-12-10 22:07:09 +00001602bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1603 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001604{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001605 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001606}
1607
1608bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1609 const TensorInfo& output,
1610 const PadDescriptor& descriptor,
1611 Optional<std::string&> reasonIfUnsupported) const
1612{
Jan Eilers8eb25602020-03-09 12:13:48 +00001613 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001614 bool supported = true;
1615
1616 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001617 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001618 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001619 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001620 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001621 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001623 DataType::QAsymmU8,
1624 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001625 };
1626
1627 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1628 "Reference pad: input is not a supported type.");
1629
1630 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1631 "Reference pad: output is not a supported type.");
1632
1633 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1634 "Reference pad: input and output types are mismatched.");
1635
1636 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001637}
1638
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001639bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1640 const TensorInfo& output,
1641 const PermuteDescriptor& descriptor,
1642 Optional<std::string&> reasonIfUnsupported) const
1643{
Jan Eilers8eb25602020-03-09 12:13:48 +00001644 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001645 bool supported = true;
1646
1647 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001648 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001649 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001650 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001651 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001652 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001653 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001654 DataType::QAsymmU8,
1655 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001656 };
1657
1658 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1659 "Reference permute: input is not a supported type.");
1660
1661 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1662 "Reference permute: output is not a supported type.");
1663
1664 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1665 "Reference permute: input and output types are mismatched.");
1666
1667 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001668}
1669
1670bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1671 const TensorInfo& output,
1672 const Pooling2dDescriptor& descriptor,
1673 Optional<std::string&> reasonIfUnsupported) const
1674{
Jan Eilers8eb25602020-03-09 12:13:48 +00001675 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001676 bool supported = true;
1677
1678 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001679 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001680 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001681 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001682 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001683 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001684 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001685 DataType::QAsymmU8,
1686 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001687 };
1688
1689 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1690 "Reference poolind2d: input is not a supported type.");
1691
1692 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1693 "Reference poolind2d: output is not a supported type.");
1694
1695 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1696 "Reference poolind2d: input and output types are mismatched.");
1697
1698 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001699}
1700
James Conroy4f1f8992020-04-29 20:01:10 +01001701bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1702 const TensorInfo& previousOutputIn,
1703 const TensorInfo& previousCellStateIn,
1704 const TensorInfo& outputStateOut,
1705 const TensorInfo& cellStateOut,
1706 const TensorInfo& output,
1707 const QLstmDescriptor& descriptor,
1708 const LstmInputParamsInfo& paramsInfo,
1709 Optional<std::string&> reasonIfUnsupported) const
1710{
1711 IgnoreUnused(input);
1712 IgnoreUnused(previousOutputIn);
1713 IgnoreUnused(previousCellStateIn);
1714 IgnoreUnused(outputStateOut);
1715 IgnoreUnused(cellStateOut);
1716 IgnoreUnused(output);
1717 IgnoreUnused(descriptor);
1718 IgnoreUnused(paramsInfo);
1719
1720 IgnoreUnused(reasonIfUnsupported);
1721
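// No checks are performed here: every argument is ignored and QLstm is reported as supported unconditionally.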
1722 return true;
1723}
1724
Derek Lamberti5f400d62019-03-25 15:41:58 +00001725bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1726 const TensorInfo& output,
1727 Optional<std::string&> reasonIfUnsupported) const
1728{
1729 bool supported = true;
1730
Finn Williamsfd271062019-12-04 14:27:27 +00001731 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001732 std::array<DataType,7> supportedInputTypes = {
1733 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001734 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001735 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001736 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001737 DataType::QAsymmU8,
1738 DataType::QSymmS8,
1739 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001740 };
1741
1742 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1743 "Reference quantize: input type not supported.");
1744
1745 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001746 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001747 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001748 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001749 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001750 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001751 };
1752 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1753 "Reference quantize: output type not supported.");
1754
1755 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1756 "Reference quantize: input and output shapes have different num total elements.");
1757
1758 return supported;
1759}
1760
Finn Williams2605b232020-06-10 15:53:46 +01001761bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1762 const TensorInfo& output,
1763 Optional<std::string&> reasonIfUnsupported) const
1764{
1765 IgnoreUnused(input);
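// Rank accepts any input type; only the Signed32 output is validated.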
1766 // Define supported output types.
1767 std::array<DataType,1> supportedOutputTypes =
1768 {
1769 DataType::Signed32,
1770 };
1771
1772 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1773 "Reference rank: input type not supported.");
1774}
1775
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001776bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
1777 const TensorInfo& output,
1778 const ReduceDescriptor& descriptor,
1779 Optional<std::string&> reasonIfUnsupported) const
1780{
1781 IgnoreUnused(descriptor);
1782 bool supported = true;
1783 std::array<DataType,7> supportedTypes =
1784 {
1785 DataType::BFloat16,
1786 DataType::Float32,
1787 DataType::Float16,
1788 DataType::QAsymmS8,
1789 DataType::QAsymmU8,
1790 DataType::QSymmS16,
1791 DataType::Signed32
1792 };
1793
1794 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1795 "Reference Reduce: input type not supported");
1796
1797 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1798 "Reference Reduce: output type not supported");
1799
1800 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1801 "Reference Reduce: input and output types not matching");
1802
1803 return supported;
1804}
1805
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001806bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001807 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001808 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001809 Optional<std::string&> reasonIfUnsupported) const
1810{
Jan Eilers8eb25602020-03-09 12:13:48 +00001811 IgnoreUnused(output);
1812 IgnoreUnused(descriptor);
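// Reshape leaves the data type unchanged, so only the input type is validated here.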
Nina Drozd2f2778f2019-05-27 10:37:05 +01001813 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001814 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001815 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001816 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001817 DataType::Float32,
1818 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001819 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001820 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001821 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001822 DataType::QSymmS16,
1823 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001824 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001825
Nina Drozd2f2778f2019-05-27 10:37:05 +01001826 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1827 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001828}
1829
1830bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001831 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001832 Optional<std::string&> reasonIfUnsupported) const
1833{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001834 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001835 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001836 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001837 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001838 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001839 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001840 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001841 DataType::QAsymmU8,
1842 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001843 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001844
1845 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1846 "Reference ResizeBilinear: input type not supported");
1847
1848 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1849 "Reference ResizeBilinear: output type not supported");
1850
1851 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1852 "Reference ResizeBilinear: input and output types not matching");
1853
1854 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001855}
1856
Teresa Charlin970f43b2019-07-01 13:51:07 +01001857bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1858 const TensorInfo& output,
1859 const ResizeDescriptor& descriptor,
1860 Optional<std::string&> reasonIfUnsupported) const
1861{
Jan Eilers8eb25602020-03-09 12:13:48 +00001862 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001863 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001864 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001865 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001866 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001867 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001868 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001869 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001870 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001871 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001872 };
1873
1874 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1875 "Reference Resize: input type not supported");
1876
1877 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1878 "Reference Resize: output type not supported");
1879
1880 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1881 "Reference Resize: input and output types not matching");
1882
1883 return supported;
1884}
1885
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001886bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1887 const TensorInfo& output,
1888 Optional<std::string&> reasonIfUnsupported) const
1889{
josh minor4a3c6102020-01-06 16:40:46 -06001890 return IsElementwiseUnarySupported(input,
1891 output,
1892 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1893 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001894}
1895
Keith Davis3ae3f972021-05-21 16:33:48 +01001896bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
1897 const TensorInfo& output,
1898 Optional<std::string&> reasonIfUnsupported) const
1899{
1900 IgnoreUnused(input);
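// The Shape layer emits tensor dimensions, so only a Signed32 output is accepted; the input type is unrestricted.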
1901 bool supported = true;
1902
1903 std::array<DataType, 1> supportedTypes =
1904 {
1905 DataType::Signed32
1906 };
1907
1908 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1909 "Reference Shape: output type not supported");
1910
1911 return supported;
1912}
1913
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001914bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1915 const TensorInfo& output,
1916 const SliceDescriptor& descriptor,
1917 Optional<std::string&> reasonIfUnsupported) const
1918{
Jan Eilers8eb25602020-03-09 12:13:48 +00001919 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001920 bool supported = true;
1921
Sadik Armagan303980c2020-04-17 12:45:14 +01001922 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001923 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001924 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001925 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001926 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001927 DataType::QAsymmU8,
1928 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001929 };
1930
1931 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1932 "Reference Slice: input type not supported");
1933
1934 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1935 "Reference Slice: output type not supported");
1936
1937 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1938 "Reference Slice: input and output types are mismatched");
1939
1940 return supported;
1941}
1942
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001943bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1944 const TensorInfo& output,
1945 const SoftmaxDescriptor& descriptor,
1946 Optional<std::string&> reasonIfUnsupported) const
1947{
Jan Eilers8eb25602020-03-09 12:13:48 +00001948 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001949 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001950 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001951 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001952 DataType::BFloat16,
1953 DataType::Float32,
1954 DataType::Float16,
1955 DataType::QSymmS8,
1956 DataType::QAsymmS8,
1957 DataType::QAsymmU8,
1958 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001959 };
1960
1961 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001962 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001963
1964 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001965 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001966
1967 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001968 "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001969
1970 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001971}
1972
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001973bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1974 const TensorInfo& output,
1975 const SpaceToBatchNdDescriptor& descriptor,
1976 Optional<std::string&> reasonIfUnsupported) const
1977{
Jan Eilers8eb25602020-03-09 12:13:48 +00001978 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001979 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001980 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001981 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001982 DataType::BFloat16,
1983 DataType::Float32,
1984 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001985 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001986 DataType::QAsymmU8,
1987 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001988 };
1989
1990 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1991 "Reference SpaceToBatchNd: input type not supported");
1992
1993 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1994 "Reference SpaceToBatchNd: output type not supported");
1995
1996 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1997 "Reference SpaceToBatchNd: input and output types are mismatched");
1998
1999 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00002000}
2001
Keith Davisa57eccb2019-06-14 17:33:22 +01002002bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01002003 const TensorInfo& output,
2004 const SpaceToDepthDescriptor& descriptor,
2005 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01002006{
2007
Jan Eilers8eb25602020-03-09 12:13:48 +00002008 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01002009 bool supported = true;
2010
Sadik Armagan303980c2020-04-17 12:45:14 +01002011 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01002012 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002013 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01002014 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002015 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002016 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002017 DataType::QAsymmU8,
2018 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01002019 };
2020
2021 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2022 "Reference SpaceToDepth: input type not supported");
2023
2024 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2025 "Reference SpaceToDepth: output type not supported");
2026
2027 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2028 "Reference SpaceToDepth: input and output types are mismatched");
2029
2030 return supported;
2031}
2032
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002033bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2034 const ViewsDescriptor& descriptor,
2035 Optional<std::string&> reasonIfUnsupported) const
2036{
Jan Eilers8eb25602020-03-09 12:13:48 +00002037 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002038 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002039 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002040 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002041 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002042 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002043 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002044 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002045 DataType::QAsymmU8,
2046 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002047 };
2048
2049 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2050 "Reference splitter: input type not supported");
2051
2052 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002053}
2054
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002055bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2056 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2057 const ViewsDescriptor& descriptor,
2058 Optional<std::string&> reasonIfUnsupported) const
2059{
Jan Eilers8eb25602020-03-09 12:13:48 +00002060 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002061 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002062 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002063 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002064 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002065 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002066 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002067 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002068 DataType::QAsymmU8,
2069 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002070 };
2071
2072 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2073 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002074 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002075 {
2076 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2077 "Reference splitter: output type not supported");
2078
2079 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2080 "Reference splitter: input and output types mismatched.");
2081 }
2082
2083 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002084}
2085
Matthew Jackson81e601c2019-07-11 12:07:09 +01002086bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2087 const TensorInfo& output,
2088 const StackDescriptor& descriptor,
2089 Optional<std::string&> reasonIfUnsupported) const
2090{
Jan Eilers8eb25602020-03-09 12:13:48 +00002091 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002092
2093 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002094 std::array<DataType,6> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002095 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002096 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002097 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002098 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002099 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002100 DataType::QAsymmU8,
2101 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01002102 };
2103
2104 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2105 "Reference stack: output type not supported");
2106 for (const TensorInfo* input : inputs)
2107 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002108 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002109 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2110 "Reference stack: input type not supported");
2111
2112 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2113 "Reference stack: input and output types mismatched.");
2114 }
2115
2116 return supported;
2117}
2118
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002119bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2120 const TensorInfo& output,
2121 const StridedSliceDescriptor& descriptor,
2122 Optional<std::string&> reasonIfUnsupported) const
2123{
Jan Eilers8eb25602020-03-09 12:13:48 +00002124 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002125 bool supported = true;
2126
Sadik Armagan303980c2020-04-17 12:45:14 +01002127 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002128 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002129 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002130 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002131 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002132 DataType::QAsymmU8,
2133 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002134 };
2135
2136 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2137 "Reference StridedSlice: input type not supported");
2138
2139 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2140 "Reference StridedSlice: output type not supported");
2141
2142 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2143 "Reference StridedSlice: input and output types are mismatched");
2144
2145 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002146}
2147
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002148bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2149 const TensorInfo& input1,
2150 const TensorInfo& output,
2151 Optional<std::string&> reasonIfUnsupported) const
2152{
Sadik Armagan2999a022019-04-09 14:20:12 +01002153 bool supported = true;
2154
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002155 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002156 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002157 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002158 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002159 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002160 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002161 DataType::QSymmS16,
2162 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002163 };
2164
2165 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2166 "Reference subtraction: input 0 is not a supported type.");
2167
2168 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2169 "Reference subtraction: input 1 is not a supported type.");
2170
2171 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2172 "Reference subtraction: output is not a supported type.");
2173
2174 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2175 "Reference subtraction: input 0 and Input 1 types are mismatched");
2176
2177 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2178 "Reference subtraction: input and output types are mismatched");
2179
2180 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2181 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2182
2183 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002184}
2185
Matteo Martincighab9e5252019-06-13 17:27:46 +01002186bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2187 const TensorInfo& alpha,
2188 const TensorInfo& output,
2189 Optional<std::string&> reasonIfUnsupported) const
2190{
2191 bool supported = true;
2192
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002193 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002194 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002195 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002196 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002197 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002198 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002199 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002200 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002201 };
2202
2203 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2204 "PReLU: input is not a supported type.");
2205
2206 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2207 "PReLU: alpha is not a supported type.");
2208
2209 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2210 "PReLU: output is not a supported type.");
2211
2212 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2213 "PReLU: input, alpha and output types are mismatched");
2214
2215 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2216 "PReLU: shapes are not suitable for implicit broadcast");
2217
2218 return supported;
2219}

bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
                                                        const TensorInfo& output,
                                                        const TransposeConvolution2dDescriptor& descriptor,
                                                        const TensorInfo& weights,
                                                        const Optional<TensorInfo>& biases,
                                                        Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference TransposeConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis // Deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference TransposeConvolution2d: biases is not a supported type.");
    }

    return supported;
}
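
// Illustrative sketch (not part of this file): a Float32 transpose convolution query. The
// shapes and the TransposeConvolution2dDescriptor fields set below are hypothetical example
// values, not anything defined in this file.
//
//     armnn::TensorInfo input({ 1, 8, 8, 16 }, armnn::DataType::Float32);
//     armnn::TensorInfo weights({ 16, 3, 3, 16 }, armnn::DataType::Float32);
//     armnn::TensorInfo biases({ 16 }, armnn::DataType::Float32);
//     armnn::TensorInfo output({ 1, 16, 16, 16 }, armnn::DataType::Float32);
//
//     armnn::TransposeConvolution2dDescriptor descriptor;
//     descriptor.m_StrideX     = 2;
//     descriptor.m_StrideY     = 2;
//     descriptor.m_BiasEnabled = true;
//     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
//
//     std::string reason;
//     armnn::RefLayerSupport layerSupport;
//     bool supported = layerSupport.IsTransposeConvolution2dSupported(
//         input, output, descriptor, weights, armnn::Optional<armnn::TensorInfo>(biases),
//         armnn::Optional<std::string&>(reason));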

bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
                                           const TensorInfo& output,
                                           const TransposeDescriptor& descriptor,
                                           Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    // Define supported input and output types.
    std::array<DataType, 6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference transpose: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference transpose: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference transpose: input and output types are mismatched.");

    return supported;
}
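
// Illustrative sketch (not part of this file): a quantized transpose that reorders NHWC data to
// NCHW. The shapes, quantization parameters and the permutation passed to TransposeDescriptor
// are hypothetical example values.
//
//     armnn::TensorInfo input({ 1, 8, 8, 3 }, armnn::DataType::QAsymmU8, 0.5f, 0);
//     armnn::TensorInfo output({ 1, 3, 8, 8 }, armnn::DataType::QAsymmU8, 0.5f, 0);
//     armnn::TransposeDescriptor descriptor(armnn::PermutationVector({ 0, 3, 1, 2 }));
//     std::string reason;
//     armnn::RefLayerSupport layerSupport;
//     bool supported = layerSupport.IsTransposeSupported(input, output, descriptor,
//                                                        armnn::Optional<std::string&>(reason));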

bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
    const TensorInfo& input,
    const TensorInfo& outputStateIn,
    const TensorInfo& cellStateIn,
    const TensorInfo& output,
    const Optional<TensorInfo>& hiddenStateOutput,
    const Optional<TensorInfo>& cellStateOutput,
    const UnidirectionalSequenceLstmDescriptor& descriptor,
    const LstmInputParamsInfo& paramsInfo,
    Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // The reference workload does not produce the optional hidden state and cell state outputs yet.
    if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
    {
        if (reasonIfUnsupported)
        {
            reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
                                           "and cell state output are not supported at the moment.";
        }
        supported = false;
    }

    std::array<DataType, 1> supportedTypes =
    {
        DataType::Float32
    };

    std::array<DataType, 2> supportedWeightTypes =
    {
        DataType::Float32,
        DataType::QAsymmS8
    };

    // Check inputs and outputs.
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched.");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input and output types are mismatched.");
    // Check layer parameters.
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
                                  reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
                                  "is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
                                  "are mismatched.");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched.");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
                                  "are mismatched.");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: InputToInputWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
                                      "are mismatched.");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
                                          reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: CellToInputWeights "
                                          "is not a supported type.");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
                                      "is not a supported type.");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: ProjectionWeights "
                                      "is not a supported type.");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
                                          "are mismatched.");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
                                          reasonIfUnsupported,
                                          "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
                                          "is not a supported type.");
        }
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
                                      "is not a supported type.");
        supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
                                      reasonIfUnsupported,
                                      "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
                                      "is not a supported type.");
    }

    return supported;
}
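
// Illustrative sketch (not part of this file): a minimal Float32 query against the checks above,
// with CIFG enabled and peephole, projection and layer norm left disabled so only the
// forget/cell/output gate parameters need to be wired up. All shapes (batch 2, sequence length 3,
// input size 4, 5 hidden units) and the way LstmInputParamsInfo is populated here are hypothetical.
//
//     using namespace armnn;
//     TensorInfo input({ 2, 3, 4 }, DataType::Float32);          // [batch, time, inputSize]
//     TensorInfo outputStateIn({ 2, 5 }, DataType::Float32);     // [batch, numUnits]
//     TensorInfo cellStateIn({ 2, 5 }, DataType::Float32);
//     TensorInfo output({ 2, 3, 5 }, DataType::Float32);
//
//     TensorInfo inputWeights({ 5, 4 }, DataType::Float32);      // [numUnits, inputSize]
//     TensorInfo recurrentWeights({ 5, 5 }, DataType::Float32);  // [numUnits, numUnits]
//     TensorInfo bias({ 5 }, DataType::Float32);
//
//     LstmInputParamsInfo paramsInfo;
//     paramsInfo.m_InputToForgetWeights     = &inputWeights;
//     paramsInfo.m_InputToCellWeights       = &inputWeights;
//     paramsInfo.m_InputToOutputWeights     = &inputWeights;
//     paramsInfo.m_RecurrentToForgetWeights = &recurrentWeights;
//     paramsInfo.m_RecurrentToCellWeights   = &recurrentWeights;
//     paramsInfo.m_RecurrentToOutputWeights = &recurrentWeights;
//     paramsInfo.m_ForgetGateBias           = &bias;
//     paramsInfo.m_CellBias                 = &bias;
//     paramsInfo.m_OutputGateBias           = &bias;
//
//     UnidirectionalSequenceLstmDescriptor descriptor;
//     descriptor.m_CifgEnabled = true;   // skip the input-gate parameter checks above
//
//     std::string reason;
//     RefLayerSupport layerSupport;
//     bool supported = layerSupport.IsUnidirectionalSequenceLstmSupported(
//         input, outputStateIn, cellStateIn, output,
//         EmptyOptional(),   // hiddenStateOutput must stay empty to pass the check above
//         EmptyOptional(),   // cellStateOutput must stay empty as well
//         descriptor, paramsInfo, Optional<std::string&>(reason));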

} // namespace armnn