//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <LayerSupportCommon.hpp>
#include <backendsCommon/LayerSupportRules.hpp>

#include <vector>
#include <array>

namespace armnn
{

namespace
{

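// Helper that forwards to IsSupportedForDataTypeGeneric, letting the supplied functors decide the
// Float32 and QAsymmU8 cases and rejecting every other data type.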
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

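// Builds the error message reported when a tensor does not have the expected number of dimensions.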
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

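// Each IsXxxSupported() check below follows the same pattern: list the data types the reference
// backend implements for that layer, then AND the results of the relevant CheckSupportRule() calls
// into 'supported'.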
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    return IsElementwiseUnarySupported(input,
                                       output,
                                       ElementwiseUnaryDescriptor(UnaryOperation::Abs),
                                       reasonIfUnsupported);
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

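    // Rule that accepts only the activation functions implemented by the reference backend.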
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::Elu:
                case ActivationFunction::HardSwish:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,7> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32,
        DataType::Signed64
    };

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Signed32,
        DataType::Signed64
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType, 9> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS8,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference cast: input is not a supported type");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
                                  "Reference cast: output is not a supported type");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference cast: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    std::array<DataType, 8> supportedInputTypes =
    {
        DataType::Boolean,
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    bool supported = true;
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        ARMNN_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,8> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
                                  "Reference for ConvertBf16ToFp32 layer: input type not supported");

    supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
                                  "Reference for ConvertBf16ToFp32 layer: output type not supported");

    return supported;
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
                                  "Reference for ConvertFp32ToBf16 layer: input type not supported");

    supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
                                  "Reference for ConvertFp32ToBf16 layer: output type not supported");

    return supported;
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Convolution2d: output is not a supported type.");

    // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
    if (input.GetDataType() == DataType::BFloat16)
    {
        if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
        {
            reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
            supported = false;
        }
    }
    else
    {
        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                      "Reference Convolution2d: input and output types mismatched.");
    }

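    // Quantized 8-bit inputs require 8-bit weights (per-tensor or the deprecated per-axis type);
    // for all other input types the weights must match the input data type.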
    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference Convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference Convolution2d: biases is not a supported type.");
    }
    IgnoreUnused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 8> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference for Debug layer: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference for Debug layer: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    // Define supported types.
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (IsQuantized8BitType(inputType))
    {
        ARMNN_NO_DEPRECATE_WARN_BEGIN
        std::array<DataType, 4> supportedWeightTypes =
        {
            DataType::QAsymmS8,
            DataType::QAsymmU8,
            DataType::QSymmS8,
            DataType::QuantizedSymm8PerAxis // deprecated
        };
        ARMNN_NO_DEPRECATE_WARN_END

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,4> biasesSupportedTypes =
        {
            DataType::BFloat16,
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedInputTypes = {
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input type not supported.");

    supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
                                  "Reference for Dequantize layer: per-axis quantized input not supported.");

    std::array<DataType,3> supportedOutputTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference for Dequantize layer: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference for Dequantize layer: input/output shapes have different num total "
                                  "elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,6> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,7> supportedTypes = {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
                                                  const TensorInfo& output,
                                                  const ElementwiseUnaryDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);

    std::array<DataType, 7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    std::array<DataType, 1> logicalSupportedTypes =
    {
        DataType::Boolean
    };

    bool supported = true;

    if (descriptor.m_Operation == UnaryOperation::LogicalNot)
    {
        supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: input type not supported");

        supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: output type not supported");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: input type not supported");

        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                      "Reference elementwise unary: output type not supported");
    }

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference elementwise unary: input and output shapes "
                                  "have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const FillDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    IgnoreUnused(output);

    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::Signed32
    };

    supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Fill: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fill: output type not supported.");
    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
    if (input.GetDataType() == DataType::BFloat16)
    {
        if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
        {
            reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
            supported = false;
        }
    }
    else
    {
        supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                      "Reference Fully Connected: input and output types mismatched.");
    }

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weights types mismatched.");

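    // When bias is enabled, the bias tensor must be 1-dimensional and of a type compatible with the weights.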
1005 if (descriptor.m_BiasEnabled)
1006 {
1007 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001008 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001009 supportedBiasTypes =
1010 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001011 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001012 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001013 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001014 DataType::Signed32,
1015 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001016 };
1017
1018 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1019 "Reference Fully Connected: bias type not supported.");
1020
1021 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1022 "Reference Fully Connected: bias and weight types mismatch.");
1023
1024 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1025 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1026
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001027 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1028 "Reference Fully Connected: bias must have 1 dimension.");
1029
Francis Murtagh46c09d02019-05-28 08:15:28 +01001030 }
1031
1032 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001033}
1034
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        const GatherDescriptor& descriptor,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,7> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32
    };

    if (descriptor.m_Axis != 0)
    {
        reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
        supported &= false;
    }
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    IgnoreUnused(descriptor);
    // Define supported types
    std::array<DataType, 6> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

James Conroyaba90cd2020-11-06 16:28:18 +00001158bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1159 const TensorInfo& input1,
1160 const TensorInfo& output,
1161 const LogicalBinaryDescriptor& descriptor,
1162 Optional<std::string&> reasonIfUnsupported) const
1163{
1164 IgnoreUnused(descriptor);
1165
1166 std::array<DataType, 1> supportedTypes =
1167 {
1168 DataType::Boolean
1169 };
1170
1171 bool supported = true;
1172 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1173 "Reference LogicalBinary: input 0 type not supported");
1174 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1175 "Reference LogicalBinary: input 1 type not supported");
1176
1177 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1178 "Reference LogicalBinary: input and output types do not match");
1179
1180 return supported;
1181}
1182
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001183bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1184 const TensorInfo& output,
1185 const LogSoftmaxDescriptor& descriptor,
1186 Optional<std::string&> reasonIfUnsupported) const
1187{
Jan Eilers8eb25602020-03-09 12:13:48 +00001188 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001189
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001190 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001191 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001192 DataType::BFloat16,
1193 DataType::Float32,
1194 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001195 };
1196
1197 bool supported = true;
1198 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1199 "Reference LogSoftmax: input type not supported");
1200
1201 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1202 "Reference LogSoftmax: output type not supported");
1203
1204 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1205 "Reference LogSoftmax: input and output types do not match");
1206
1207 return supported;
1208}
1209
arovir011c7c81b2018-10-08 11:34:28 +01001210bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1211 const TensorInfo& outputStateIn,
1212 const TensorInfo& cellStateIn,
1213 const TensorInfo& scratchBuffer,
1214 const TensorInfo& outputStateOut,
1215 const TensorInfo& cellStateOut,
1216 const TensorInfo& output,
1217 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001218 const LstmInputParamsInfo& paramsInfo,
1219 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001220{
Jan Eilers8eb25602020-03-09 12:13:48 +00001221 IgnoreUnused(descriptor);
1222 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001223
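    // This overload only validates data types: the state, scratch and output tensors and every
    // supplied weight or bias (including the optional CIFG, peephole, projection and
    // layer-normalisation parameter sets) must match the input's data type; shape consistency
    // is assumed to be checked elsewhere and is not verified here.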
1224 bool supported = true;
1225
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001226 std::array<DataType,3> supportedTypes = {
1227 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001228 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001229 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001230 };
1231
Jan Eilersd01a83c2019-07-03 18:20:40 +01001232 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001233 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1234 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001235 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1236 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001237 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1238 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001239 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1240 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001241 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1242 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001243 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1244 "Reference Lstm: input and cellStateOut types are mismatched");
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01001245
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001246 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1247 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001248 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001249 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001250 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001251 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001252 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001253 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001254 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001255 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001256 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001257 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001258 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001259 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001260 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001261 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001262 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001263 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001264 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001265 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001266 "Reference Lstm: input and OutputGateBias types are mismatched");
1267 if (!descriptor.m_CifgEnabled)
1268 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001269 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001270 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001271 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001272 reasonIfUnsupported,
1273 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001274 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001275 "Reference Lstm: input and InputGateBias types are mismatched");
1276 if (descriptor.m_PeepholeEnabled)
1277 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001278 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001279 reasonIfUnsupported,
1280 "Reference Lstm: input and CellToInputWeights types are mismatched");
1281 }
1282 }
1283 if (descriptor.m_PeepholeEnabled)
1284 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001285 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001286 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001287 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001288 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1289 }
1290 if (descriptor.m_ProjectionEnabled)
1291 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001292 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001293 "Reference Lstm: input and ProjectionWeights types are mismatched");
1294 if (paramsInfo.m_ProjectionBias != nullptr)
1295 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001296 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001297 "Reference Lstm: input and ProjectionBias types are mismatched");
1298 }
1299 }
1300 if (descriptor.m_LayerNormEnabled)
1301 {
1302 if (!descriptor.m_CifgEnabled)
1303 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001304 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001305 reasonIfUnsupported,
1306 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1307 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001308 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001309 reasonIfUnsupported,
1310 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001311 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001312 reasonIfUnsupported,
1313 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001314 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001315 reasonIfUnsupported,
1316 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1317 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001318
1319 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001320}
1321
saoste012df12b32018-11-28 16:57:20 +00001322bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1323 const TensorInfo& input1,
1324 const TensorInfo& output,
1325 Optional<std::string&> reasonIfUnsupported) const
1326{
Sadik Armagan2999a022019-04-09 14:20:12 +01001327 bool supported = true;
1328
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001329 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001330 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001331 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001332 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001333 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001334 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001335 DataType::QSymmS16,
1336 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001337 };
1338
1339 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1340 "Reference maximum: input 0 is not a supported type.");
1341
1342 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1343 "Reference maximum: input 1 is not a supported type.");
1344
1345 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1346 "Reference maximum: output is not a supported type.");
1347
1348 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1349 "Reference maximum: input 0 and input 1 types are mismatched");
1350
1351 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1352 "Reference maximum: input and output types are mismatched");
1353
1354 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1355 "Reference maximum: shapes are not suitable for implicit broadcast.");
1356
1357 return supported;
saoste012df12b32018-11-28 16:57:20 +00001358}
1359
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001360bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1361 const TensorInfo& output,
1362 const MeanDescriptor& descriptor,
1363 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001364{
James Conroy4d1ff582019-06-10 17:06:39 +01001365 bool supported = true;
1366 std::string meanLayerStr = "Mean";
1367 std::string outputTensorStr = "output";
1368
Sadik Armagan303980c2020-04-17 12:45:14 +01001369 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001370 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001371 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001372 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001373 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001374 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001375 DataType::QAsymmU8,
1376 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001377 };
1378
1379 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1380 "Reference Mean: input type not supported.");
1381
1382 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1383 "Reference Mean: input and output types are mismatched");
1384
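    // Expected output rank: with KeepDims the output keeps the input's rank; with an empty axis
    // list the result collapses to a single dimension; otherwise the rank is reduced by the
    // number of axes being averaged over (never below one).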
1385 if (descriptor.m_KeepDims)
1386 {
1387 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1388 reasonIfUnsupported,
1389 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1390 output.GetNumDimensions(),
1391 meanLayerStr, outputTensorStr).data());
1392 }
1393 else if (descriptor.m_Axis.empty())
1394 {
1395 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1396 reasonIfUnsupported,
1397 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1398 meanLayerStr, outputTensorStr).data());
1399 }
1400 else
1401 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001402 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001403
1404 if (outputDim > 0)
1405 {
1406 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1407 reasonIfUnsupported,
1408 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1409 meanLayerStr, outputTensorStr).data());
1410 }
1411 else
1412 {
1413 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1414 reasonIfUnsupported,
1415 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1416 meanLayerStr, outputTensorStr).data());
1417 }
1418 }
1419
1420 return supported;
narpra0132b90462018-09-13 11:07:48 +01001421}
1422
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001423bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001424 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001425 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001426 Optional<std::string&> reasonIfUnsupported) const
1427{
Jim Flynne242f2d2019-05-22 14:24:13 +01001428 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001429}
1430
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001431bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1432 const TensorInfo &output,
1433 Optional<std::string &> reasonIfUnsupported) const
1434{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001435 bool supported = true;
1436
Sadik Armagan303980c2020-04-17 12:45:14 +01001437 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001438 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001439 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001440 DataType::Float32,
1441 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001442 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001443 DataType::QAsymmU8,
1444 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001445 DataType::Boolean
1446 };
1447
1448 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1449 "Reference MemCopy: input type not supported");
1450
1451 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1452 "Reference MemCopy: output type not supported");
1453
1454 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1455 "Reference MemCopy: input and output types are mismatched");
1456
1457 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001458}
1459
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001460bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1461 const TensorInfo& input1,
1462 const TensorInfo& output,
1463 Optional<std::string&> reasonIfUnsupported) const
1464{
Sadik Armagan2999a022019-04-09 14:20:12 +01001465 bool supported = true;
1466
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001467 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001468 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001469 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001470 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001471 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001472 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001473 DataType::QSymmS16,
1474 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001475 };
1476
1477 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1478 "Reference minimum: input 0 is not a supported type.");
1479
1480 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1481 "Reference minimum: input 1 is not a supported type.");
1482
1483 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1484 "Reference minimum: output is not a supported type.");
1485
1486 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1487 "Reference minimum: input 0 and input 1 types are mismatched");
1488
1489 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1490 "Reference minimum: input and output types are mismatched");
1491
1492 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1493 "Reference minimum: shapes are not suitable for implicit broadcast.");
1494
1495 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001496}
1497
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001498bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1499 const TensorInfo& input1,
1500 const TensorInfo& output,
1501 Optional<std::string&> reasonIfUnsupported) const
1502{
Sadik Armagan2999a022019-04-09 14:20:12 +01001503 bool supported = true;
1504
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001505 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001506 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001507 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001508 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001509 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001510 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001511 DataType::QSymmS16,
1512 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001513 };
1514
1515 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1516 "Reference multiplication: input 0 is not a supported type.");
1517
1518 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1519 "Reference multiplication: input 1 is not a supported type.");
1520
1521 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1522 "Reference multiplication: output is not a supported type.");
1523
1524 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1525 "Reference multiplication: input 0 and input 1 types are mismatched");
1526
1527 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1528 "Reference multiplication: input and output types are mismatched");
1529
1530 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1531 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1532
1533 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001534}
1535
1536bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1537 const TensorInfo& output,
1538 const NormalizationDescriptor& descriptor,
1539 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001540{
Jan Eilers8eb25602020-03-09 12:13:48 +00001541 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001542
1543 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001544 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001545 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001546 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001547 DataType::Float16,
1548 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001549 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001550 DataType::QAsymmU8,
1551 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001552 };
1553
1554 bool supported = true;
1555
1556 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1557 "Reference normalization: input type not supported.");
1558
1559 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1560 "Reference normalization: output type not supported.");
1561
1562 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1563 "Reference normalization: input and output shapes have different "
1564 "num total elements.");
1565
1566 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001567}
1568
Derek Lamberti901ea112019-12-10 22:07:09 +00001569bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1570 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001571{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001572 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001573}
1574
1575bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1576 const TensorInfo& output,
1577 const PadDescriptor& descriptor,
1578 Optional<std::string&> reasonIfUnsupported) const
1579{
Jan Eilers8eb25602020-03-09 12:13:48 +00001580 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001581 bool supported = true;
1582
1583 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001584 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001585 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001586 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001587 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001588 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001589 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001590 DataType::QAsymmU8,
1591 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001592 };
1593
1594 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1595 "Reference pad: input is not a supported type.");
1596
1597 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1598 "Reference pad: output is not a supported type.");
1599
1600 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1601 "Reference pad: input and output types are mismatched.");
1602
1603 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001604}
1605
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001606bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1607 const TensorInfo& output,
1608 const PermuteDescriptor& descriptor,
1609 Optional<std::string&> reasonIfUnsupported) const
1610{
Jan Eilers8eb25602020-03-09 12:13:48 +00001611 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001612 bool supported = true;
1613
1614 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001615 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001616 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001617 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001618 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001619 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001620 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001621 DataType::QAsymmU8,
1622 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001623 };
1624
1625 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1626 "Reference permute: input is not a supported type.");
1627
1628 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1629 "Reference permute: output is not a supported type.");
1630
1631 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1632 "Reference permute: input and output types are mismatched.");
1633
1634 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001635}
1636
1637bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1638 const TensorInfo& output,
1639 const Pooling2dDescriptor& descriptor,
1640 Optional<std::string&> reasonIfUnsupported) const
1641{
Jan Eilers8eb25602020-03-09 12:13:48 +00001642 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001643 bool supported = true;
1644
1645 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001646 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001647 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001648 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001649 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001650 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001652 DataType::QAsymmU8,
1653 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001654 };
1655
1656 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1657 "Reference pooling2d: input is not a supported type.");
1658
1659 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1660 "Reference pooling2d: output is not a supported type.");
1661
1662 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1663 "Reference pooling2d: input and output types are mismatched.");
1664
1665 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001666}
1667
James Conroy4f1f8992020-04-29 20:01:10 +01001668bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1669 const TensorInfo& previousOutputIn,
1670 const TensorInfo& previousCellStateIn,
1671 const TensorInfo& outputStateOut,
1672 const TensorInfo& cellStateOut,
1673 const TensorInfo& output,
1674 const QLstmDescriptor& descriptor,
1675 const LstmInputParamsInfo& paramsInfo,
1676 Optional<std::string&> reasonIfUnsupported) const
1677{
1678 IgnoreUnused(input);
1679 IgnoreUnused(previousOutputIn);
1680 IgnoreUnused(previousCellStateIn);
1681 IgnoreUnused(outputStateOut);
1682 IgnoreUnused(cellStateOut);
1683 IgnoreUnused(output);
1684 IgnoreUnused(descriptor);
1685 IgnoreUnused(paramsInfo);
1686
1687 IgnoreUnused(reasonIfUnsupported);
1688
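    // The reference backend reports QLstm as supported unconditionally; every argument is
    // ignored here, so detailed type and shape validation is assumed to happen later (for
    // example when the workload itself is created).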
1689 return true;
1690}
1691
Derek Lamberti5f400d62019-03-25 15:41:58 +00001692bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1693 const TensorInfo& output,
1694 Optional<std::string&> reasonIfUnsupported) const
1695{
1696 bool supported = true;
1697
Finn Williamsfd271062019-12-04 14:27:27 +00001698 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001699 std::array<DataType,7> supportedInputTypes = {
1700 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001701 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001702 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001703 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001704 DataType::QAsymmU8,
1705 DataType::QSymmS8,
1706 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001707 };
1708
1709 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1710 "Reference quantize: input type not supported.");
1711
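    // The output of a Quantize layer must itself be quantized; because the input may be float
    // or already quantized, both float-to-quantized conversion and re-quantization between
    // quantized types are accepted by these rules.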
1712 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001713 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001714 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001715 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001716 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001717 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001718 };
1719 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1720 "Reference quantize: output type not supported.");
1721
1722 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1723 "Reference quantize: input and output shapes have different num total elements.");
1724
1725 return supported;
1726}
1727
Finn Williams2605b232020-06-10 15:53:46 +01001728bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1729 const TensorInfo& output,
1730 Optional<std::string&> reasonIfUnsupported) const
1731{
1732 IgnoreUnused(input);
1733 // Define supported output types.
1734 std::array<DataType,1> supportedOutputTypes =
1735 {
1736 DataType::Signed32,
1737 };
1738
1739 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1740 "Reference rank: output type not supported.");
1741}
1742
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001743bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
1744 const TensorInfo& output,
1745 const ReduceDescriptor& descriptor,
1746 Optional<std::string&> reasonIfUnsupported) const
1747{
1748 IgnoreUnused(descriptor);
1749 bool supported = true;
1750 std::array<DataType,7> supportedTypes =
1751 {
1752 DataType::BFloat16,
1753 DataType::Float32,
1754 DataType::Float16,
1755 DataType::QAsymmS8,
1756 DataType::QAsymmU8,
1757 DataType::QSymmS16,
1758 DataType::Signed32
1759 };
1760
1761 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1762 "Reference Reduce: input type not supported");
1763
1764 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1765 "Reference Reduce: output type not supported");
1766
1767 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1768 "Reference Reduce: input and output types not matching");
1769
1770 return supported;
1771}
1772
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001773bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001774 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001775 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001776 Optional<std::string&> reasonIfUnsupported) const
1777{
Jan Eilers8eb25602020-03-09 12:13:48 +00001778 IgnoreUnused(output);
1779 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001780 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001781 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001782 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001783 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001784 DataType::Float32,
1785 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001786 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001787 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001788 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001789 DataType::QSymmS16,
1790 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001791 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001792
Nina Drozd2f2778f2019-05-27 10:37:05 +01001793 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1794 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001795}
1796
1797bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001798 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001799 Optional<std::string&> reasonIfUnsupported) const
1800{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001801 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001802 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001803 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001804 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001805 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001806 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001807 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001808 DataType::QAsymmU8,
1809 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001810 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001811
1812 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1813 "Reference ResizeBilinear: input type not supported");
1814
1815 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1816 "Reference ResizeBilinear: output type not supported");
1817
1818 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1819 "Reference ResizeBilinear: input and output types not matching");
1820
1821 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001822}
1823
Teresa Charlin970f43b2019-07-01 13:51:07 +01001824bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1825 const TensorInfo& output,
1826 const ResizeDescriptor& descriptor,
1827 Optional<std::string&> reasonIfUnsupported) const
1828{
Jan Eilers8eb25602020-03-09 12:13:48 +00001829 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001830 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001831 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001832 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001833 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001834 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001835 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001836 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001837 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001838 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001839 };
1840
1841 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1842 "Reference Resize: input type not supported");
1843
1844 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1845 "Reference Resize: output type not supported");
1846
1847 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1848 "Reference Resize: input and output types not matching");
1849
1850 return supported;
1851}
1852
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001853bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1854 const TensorInfo& output,
1855 Optional<std::string&> reasonIfUnsupported) const
1856{
josh minor4a3c6102020-01-06 16:40:46 -06001857 return IsElementwiseUnarySupported(input,
1858 output,
1859 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1860 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001861}
1862
Keith Davis3ae3f972021-05-21 16:33:48 +01001863bool RefLayerSupport::IsShapeSupported(const TensorInfo& input,
1864 const TensorInfo& output,
1865 Optional<std::string&> reasonIfUnsupported) const
1866{
1867 IgnoreUnused(input);
1868 bool supported = true;
1869
1870 std::array<DataType, 1> supportedTypes =
1871 {
1872 DataType::Signed32
1873 };
1874
1875 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1876 "Reference Shape: output type not supported");
1877
1878 return supported;
1879}
1880
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001881bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1882 const TensorInfo& output,
1883 const SliceDescriptor& descriptor,
1884 Optional<std::string&> reasonIfUnsupported) const
1885{
Jan Eilers8eb25602020-03-09 12:13:48 +00001886 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001887 bool supported = true;
1888
Sadik Armagan303980c2020-04-17 12:45:14 +01001889 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001890 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001891 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001892 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001893 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001894 DataType::QAsymmU8,
1895 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001896 };
1897
1898 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1899 "Reference Slice: input type not supported");
1900
1901 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1902 "Reference Slice: output type not supported");
1903
1904 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1905 "Reference Slice: input and output types are mismatched");
1906
1907 return supported;
1908}
1909
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001910bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1911 const TensorInfo& output,
1912 const SoftmaxDescriptor& descriptor,
1913 Optional<std::string&> reasonIfUnsupported) const
1914{
Jan Eilers8eb25602020-03-09 12:13:48 +00001915 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001916 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001917 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001918 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001919 DataType::BFloat16,
1920 DataType::Float32,
1921 DataType::Float16,
1922 DataType::QSymmS8,
1923 DataType::QAsymmS8,
1924 DataType::QAsymmU8,
1925 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001926 };
1927
1928 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001929 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001930
1931 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001932 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001933
1934 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001935 "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001936
1937 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001938}
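// Illustrative sketch (not part of the original source): a caller holding a RefLayerSupport
// instance can query these predicates directly; the shape, data type and reason string below
// are arbitrary example values.
//
//     armnn::RefLayerSupport layerSupport;
//     armnn::TensorInfo info({ 2, 4 }, armnn::DataType::Float32);
//     std::string reason;
//     bool ok = layerSupport.IsSoftmaxSupported(info, info, armnn::SoftmaxDescriptor(), reason);
//     // ok is true for Float32; 'reason' is populated when a rule fails.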
1939
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001940bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1941 const TensorInfo& output,
1942 const SpaceToBatchNdDescriptor& descriptor,
1943 Optional<std::string&> reasonIfUnsupported) const
1944{
Jan Eilers8eb25602020-03-09 12:13:48 +00001945 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001946 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001947 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001948 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001949 DataType::BFloat16,
1950 DataType::Float32,
1951 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001952 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001953 DataType::QAsymmU8,
1954 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001955 };
1956
1957 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1958 "Reference SpaceToBatchNd: input type not supported");
1959
1960 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1961 "Reference SpaceToBatchNd: output type not supported");
1962
1963 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1964 "Reference SpaceToBatchNd: input and output types are mismatched");
1965
1966 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001967}
1968
Keith Davisa57eccb2019-06-14 17:33:22 +01001969bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001970 const TensorInfo& output,
1971 const SpaceToDepthDescriptor& descriptor,
1972 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001973{
1974
Jan Eilers8eb25602020-03-09 12:13:48 +00001975 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001976 bool supported = true;
1977
Sadik Armagan303980c2020-04-17 12:45:14 +01001978 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001979 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001980 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001981 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001982 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001983 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001984 DataType::QAsymmU8,
1985 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001986 };
1987
1988 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1989 "Reference SpaceToDepth: input type not supported");
1990
1991 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1992 "Reference SpaceToDepth: output type not supported");
1993
1994 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1995 "Reference SpaceToDepth: input and output types are mismatched");
1996
1997 return supported;
1998}
1999
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002000bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2001 const ViewsDescriptor& descriptor,
2002 Optional<std::string&> reasonIfUnsupported) const
2003{
Jan Eilers8eb25602020-03-09 12:13:48 +00002004 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002005 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002006 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002007 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002008 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002009 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002010 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002011 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002012 DataType::QAsymmU8,
2013 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002014 };
2015
2016 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2017 "Reference splitter: input type not supported");
2018
2019 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002020}
2021
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002022bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2023 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2024 const ViewsDescriptor& descriptor,
2025 Optional<std::string&> reasonIfUnsupported) const
2026{
Jan Eilers8eb25602020-03-09 12:13:48 +00002027 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002028 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002029 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002030 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002031 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002032 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002033 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002034 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002035 DataType::QAsymmU8,
2036 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002037 };
2038
2039 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2040 "Reference splitter: input type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002041 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002042 {
2043 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2044 "Reference splitter: output type not supported");
2045
2046 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2047 "Reference splitter: input and output types mismatched.");
2048 }
2049
2050 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002051}
2052
Matthew Jackson81e601c2019-07-11 12:07:09 +01002053bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2054 const TensorInfo& output,
2055 const StackDescriptor& descriptor,
2056 Optional<std::string&> reasonIfUnsupported) const
2057{
Jan Eilers8eb25602020-03-09 12:13:48 +00002058 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002059
2060 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002061 std::array<DataType,6> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002062 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002063 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002064 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002065 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002066 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002067 DataType::QAsymmU8,
2068 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01002069 };
2070
2071 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2072 "Reference stack: output type not supported");
2073 for (const TensorInfo* input : inputs)
2074 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002075 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002076 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2077 "Reference stack: input type not supported");
2078
2079 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2080 "Reference stack: input and output types mismatched.");
2081 }
2082
2083 return supported;
2084}
2085
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002086bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2087 const TensorInfo& output,
2088 const StridedSliceDescriptor& descriptor,
2089 Optional<std::string&> reasonIfUnsupported) const
2090{
Jan Eilers8eb25602020-03-09 12:13:48 +00002091 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002092 bool supported = true;
2093
Sadik Armagan303980c2020-04-17 12:45:14 +01002094 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002095 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002096 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002097 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002098 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002099 DataType::QAsymmU8,
2100 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002101 };
2102
2103 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2104 "Reference StridedSlice: input type not supported");
2105
2106 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2107 "Reference StridedSlice: output type not supported");
2108
2109 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2110 "Reference StridedSlice: input and output types are mismatched");
2111
2112 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002113}
2114
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002115bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2116 const TensorInfo& input1,
2117 const TensorInfo& output,
2118 Optional<std::string&> reasonIfUnsupported) const
2119{
Sadik Armagan2999a022019-04-09 14:20:12 +01002120 bool supported = true;
2121
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002122 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002123 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002124 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002125 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002126 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002127 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002128 DataType::QSymmS16,
2129 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002130 };
2131
2132 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2133 "Reference subtraction: input 0 is not a supported type.");
2134
2135 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2136 "Reference subtraction: input 1 is not a supported type.");
2137
2138 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2139 "Reference subtraction: output is not a supported type.");
2140
2141 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2142 "Reference subtraction: input 0 and input 1 types are mismatched");
2143
2144 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2145 "Reference subtraction: input and output types are mismatched");
2146
2147 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2148 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2149
2150 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002151}
2152
Matteo Martincighab9e5252019-06-13 17:27:46 +01002153bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2154 const TensorInfo& alpha,
2155 const TensorInfo& output,
2156 Optional<std::string&> reasonIfUnsupported) const
2157{
2158 bool supported = true;
2159
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002160 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002161 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002162 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002163 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002164 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002165 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002166 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002167 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002168 };
2169
2170 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2171 "PReLU: input is not a supported type.");
2172
2173 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2174 "PReLU: alpha is not a supported type.");
2175
2176 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2177 "PReLU: output is not a supported type.");
2178
2179 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2180 "PReLU: input, alpha and output types are mismatched");
2181
2182 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2183 "PReLU: shapes are not suitable for implicit broadcast");
2184
2185 return supported;
2186}
2187
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002188bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2189 const TensorInfo& output,
2190 const TransposeConvolution2dDescriptor& descriptor,
2191 const TensorInfo& weights,
2192 const Optional<TensorInfo>& biases,
2193 Optional<std::string&> reasonIfUnsupported) const
2194{
Jan Eilers8eb25602020-03-09 12:13:48 +00002195 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002196 bool supported = true;
2197
Sadik Armagan303980c2020-04-17 12:45:14 +01002198 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002199 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002200 DataType::BFloat16,
2201 DataType::Float32,
2202 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002203 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002204 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002205 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002206 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002207 };
2208
2209 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2210 "Reference TransposeConvolution2d: input is not a supported type.");
2211
2212 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2213 "Reference TransposeConvolution2d: output is not a supported type.");
2214
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002215 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2216 "Reference TransposeConvolution2d: input and output types mismatched.");
2217
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002218
2219 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002220 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002221 {
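        // For 8-bit quantized inputs the weights may be per-tensor quantized (QAsymmS8, QAsymmU8
        // or QSymmS8) or use the deprecated per-axis symmetric type listed below.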
Derek Lambertid466a542020-01-22 15:37:29 +00002222 ARMNN_NO_DEPRECATE_WARN_BEGIN
Sadik Armagan303980c2020-04-17 12:45:14 +01002223 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002224 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002225 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002226 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00002227 DataType::QSymmS8,
2228 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002229 };
Derek Lambertid466a542020-01-22 15:37:29 +00002230 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002231
2232 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2233 "Reference TransposeConvolution2d: weights type not supported for "
2234 "quantized input.");
2235 }
2236 else
2237 {
2238 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2239 "Reference TransposeConvolution2d: weights is not a supported type.");
2240
2241 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2242 "Reference TransposeConvolution2d: input and weights types mismatched.");
2243 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002244
2245 if (biases.has_value())
2246 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002247 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002248 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002249 DataType::BFloat16,
2250 DataType::Float32,
2251 DataType::Float16,
2252 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002253 };
2254 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2255 "Reference TransposeConvolution2d: biases is not a supported type.");
2256 }
2257
2258 return supported;
2259}
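
// Illustrative usage only: a hedged sketch, with hypothetical shapes, scales and zero points, of
// querying the transpose-convolution rule above for a QAsymmU8 network. For quantized inputs the
// weights are validated against the quantized weight list, while a bias, when given, must be
// Signed32 (or a float type for float networks).
//
//     armnn::TransposeConvolution2dDescriptor desc;
//     desc.m_StrideX     = 2;
//     desc.m_StrideY     = 2;
//     desc.m_BiasEnabled = true;
//     armnn::TensorInfo input({1, 8, 8, 16},   armnn::DataType::QAsymmU8, 0.5f,  0);
//     armnn::TensorInfo output({1, 16, 16, 8}, armnn::DataType::QAsymmU8, 0.5f,  0);
//     armnn::TensorInfo weights({8, 3, 3, 16}, armnn::DataType::QAsymmU8, 0.25f, 0);
//     armnn::TensorInfo biases({8}, armnn::DataType::Signed32, 0.125f, 0);
//     std::string reason;
//     bool ok = armnn::RefLayerSupport().IsTransposeConvolution2dSupported(
//         input, output, desc, weights, armnn::Optional<armnn::TensorInfo>(biases),
//         armnn::Optional<std::string&>(reason));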
2260
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002261bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2262 const TensorInfo& output,
2263 const TransposeDescriptor& descriptor,
2264 Optional<std::string&> reasonIfUnsupported) const
2265{
Jan Eilers8eb25602020-03-09 12:13:48 +00002266 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002267 bool supported = true;
2268
2269    // Define supported input and output types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002270 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002271 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002272 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002273 DataType::Float32,
2274 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002275 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002276 DataType::QAsymmU8,
2277 DataType::QSymmS16
2278 };
2279
2280 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2281 "Reference transpose: input is not a supported type.");
2282
2283 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2284 "Reference transpose: output is not a supported type.");
2285
2286 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2287 "Reference transpose: input and output types are mismatched.");
2288
2289 return supported;
2290}
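
// Illustrative usage only: a hedged sketch with hypothetical shapes. The rule above is purely a
// data-type check, so the permutation held by the descriptor does not influence the result.
//
//     armnn::TransposeDescriptor desc(armnn::PermutationVector({0, 3, 1, 2})); // e.g. NHWC -> NCHW
//     armnn::TensorInfo input({1, 16, 16, 8}, armnn::DataType::Float16);
//     armnn::TensorInfo output({1, 8, 16, 16}, armnn::DataType::Float16);
//     std::string reason;
//     bool ok = armnn::RefLayerSupport().IsTransposeSupported(input, output, desc,
//                                                             armnn::Optional<std::string&>(reason));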
2291
Narumol Prangnawarate5339e72021-07-28 17:33:28 +01002292bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
2293 const TensorInfo& input,
2294 const TensorInfo& outputStateIn,
2295 const TensorInfo& cellStateIn,
2296 const TensorInfo& output,
2297 const Optional<TensorInfo>& hiddenStateOutput,
2298 const Optional<TensorInfo>& cellStateOutput,
2299 const UnidirectionalSequenceLstmDescriptor& descriptor,
2300 const LstmInputParamsInfo& paramsInfo,
2301 Optional<std::string&> reasonIfUnsupported) const
2302{
2303 IgnoreUnused(descriptor);
2304 IgnoreUnused(paramsInfo);
2305 IgnoreUnused(outputStateIn);
2306 IgnoreUnused(cellStateIn);
2307 bool supported = true;
2308
    if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
    {
        // Producing the optional hidden/cell state outputs is not implemented yet, so report the
        // configuration as unsupported (and only touch the reason string if one was provided).
        if (reasonIfUnsupported.has_value())
        {
            reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
                                           "and cell state output are not supported at the moment.";
        }
        supported = false;
    }
2314
2315 std::array<DataType, 1> supportedTypes =
2316 {
2317 DataType::Float32
2318 };
2319
2320 std::array<DataType, 1> supportedWeightTypes =
2321 {
2322 DataType::Float32
2323 };
2324
2325 // check inputs and outputs
2326 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2327 "Reference UnidirectionalSequenceLstm: input is not a supported type.");
2328 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
2329 "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
2330 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
2331 "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
2332
2333 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2334 "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
2335 // check layer parameters
2336 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
2337 reasonIfUnsupported,
2338 "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
2339 "is not a supported type.");
2340 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
2341 reasonIfUnsupported,
2342 "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
2343 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
2344 reasonIfUnsupported,
2345 "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
2346 "is not a supported type.");
2347 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
2348 reasonIfUnsupported,
2349 "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
2350 "is not a supported type.");
2351 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
2352 reasonIfUnsupported,
2353 "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
2354 "is not a supported type.");
2355 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
2356 reasonIfUnsupported,
2357 "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
2358 "is not a supported type.");
2359 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
2360 "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
2361 "are mismatched");
2362 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
2363 "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
2364 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
2365 "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
2366 "are mismatched");
2367 if (!descriptor.m_CifgEnabled)
2368 {
2369 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
2370 reasonIfUnsupported,
2371 "Reference UnidirectionalSequenceLstm: InputToInputWeights "
2372 "is not a supported type.");
2373 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
2374 reasonIfUnsupported,
2375 "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
2376 "is not a supported type.");
2377 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
2378 "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
2379 "are mismatched");
2380 if (descriptor.m_PeepholeEnabled)
2381 {
2382 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
2383 reasonIfUnsupported,
2384 "Reference UnidirectionalSequenceLstm: CellToInputWeights "
2385 "is not a supported type.");
2386 }
2387 }
2388 if (descriptor.m_PeepholeEnabled)
2389 {
2390 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
2391 reasonIfUnsupported,
2392 "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
2393 "is not a supported type.");
2394 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
2395 reasonIfUnsupported,
2396 "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
2397 "is not a supported type.");
2398 }
2399 if (descriptor.m_ProjectionEnabled)
2400 {
2401 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
2402 reasonIfUnsupported,
2403 "Reference UnidirectionalSequenceLstm: ProjectionWeights "
2404 "is not a supported type.");
2405 if (paramsInfo.m_ProjectionBias != nullptr)
2406 {
2407 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
2408 "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
2409 "are mismatched");
2410 }
2411 }
2412 if (descriptor.m_LayerNormEnabled)
2413 {
2414 if (!descriptor.m_CifgEnabled)
2415 {
2416 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
2417 reasonIfUnsupported,
2418 "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
2419 "is not a supported type.");
2420 }
2421 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
2422 reasonIfUnsupported,
2423 "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
2424 "is not a supported type.");
2425 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
2426 reasonIfUnsupported,
2427 "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
2428 "is not a supported type.");
2429 supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
2430 reasonIfUnsupported,
2431 "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
2432 "is not a supported type.");
2433 }
2434
2435 return supported;
2436}
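
// Illustrative usage only: a hedged sketch (tensor and weight construction omitted for brevity).
// The reference backend currently accepts Float32 data and Float32 weights only, so a caller with,
// say, quantized inputs gets 'false' together with the corresponding reason from the rules above.
//
//     std::string reason;
//     armnn::UnidirectionalSequenceLstmDescriptor desc;
//     desc.m_CifgEnabled       = true;   // skips the input-gate weight/bias checks above
//     desc.m_PeepholeEnabled   = false;
//     desc.m_ProjectionEnabled = false;
//     desc.m_LayerNormEnabled  = false;
//     // ... fill in the TensorInfos and an armnn::LstmInputParamsInfo 'paramsInfo', then:
//     // bool ok = armnn::RefLayerSupport().IsUnidirectionalSequenceLstmSupported(
//     //     input, outputStateIn, cellStateIn, output, armnn::EmptyOptional(),
//     //     armnn::EmptyOptional(), desc, paramsInfo, armnn::Optional<std::string&>(reason));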
2437
arovir011c7c81b2018-10-08 11:34:28 +01002438} // namespace armnn