blob: e278e45f0150b26a6827a2016bfa55430fcfcef1 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
namespace
{

/// Builds the "wrong number of dimensions" message used by the reference layer-support checks.
///
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Layer name inserted after the "Reference " prefix.
/// @param tensorName Name of the offending tensor, quoted in the message.
/// @return The formatted message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // The strings are only read, so they are taken by const reference. This also lets callers
    // pass literals and temporaries, which a non-const lvalue reference would reject.
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) +
                           " dimensions but got " + std::to_string(actual) +
                           " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Sadik Armagan9199e582019-09-05 17:35:31 +010061bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
62 Optional<std::string&> reasonIfUnsupported) const
63{
josh minor4a3c6102020-01-06 16:40:46 -060064 return IsElementwiseUnarySupported(input,
65 output,
66 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
67 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010068}
69
arovir011c7c81b2018-10-08 11:34:28 +010070bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
71 const TensorInfo& output,
72 const ActivationDescriptor& descriptor,
73 Optional<std::string&> reasonIfUnsupported) const
74{
Derek Lamberti50db4e82019-03-13 14:16:15 +000075 bool supported = true;
76
77 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000078 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000079 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +000080 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010081 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000082 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000083 DataType::QAsymmU8,
84 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000085 };
86
87 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
88 "Reference activation: input type not supported.");
89
90 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
91 "Reference activation: output type not supported.");
92
93 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
94 "Reference activation: input and output types mismatched.");
95
96 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
97 "Reference activation: input and output shapes are of different rank.");
98
99
100 struct ActivationFunctionSupported : public Rule
101 {
102 ActivationFunctionSupported(const ActivationDescriptor& desc)
103 {
104 switch(desc.m_Function)
105 {
106 case ActivationFunction::Abs:
107 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000108 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000109 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000110 case ActivationFunction::LeakyReLu:
111 case ActivationFunction::Linear:
112 case ActivationFunction::ReLu:
113 case ActivationFunction::Sigmoid:
114 case ActivationFunction::SoftReLu:
115 case ActivationFunction::Sqrt:
116 case ActivationFunction::Square:
117 case ActivationFunction::TanH:
118 {
119 m_Res = true;
120 break;
121 }
122 default:
123 {
124 m_Res = false;
125 break;
126 }
127 }
128 }
129 };
130
131 // Function is supported
132 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
133 "Reference activation: function not supported.");
134
135 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100136}
137
138bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
139 const TensorInfo& input1,
140 const TensorInfo& output,
141 Optional<std::string&> reasonIfUnsupported) const
142{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000143 bool supported = true;
144
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100145 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000146 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000147 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100148 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000149 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000150 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100151 DataType::QSymmS16,
152 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000153 };
154
155 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
156 "Reference addition: input 0 is not a supported type.");
157
158 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
159 "Reference addition: input 1 is not a supported type.");
160
161 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
162 "Reference addition: output is not a supported type.");
163
164 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
165 "Reference addition: input 0 and Input 1 types are mismatched");
166
167 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
168 "Reference addition: input and output types are mismatched");
169
170 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
171 "Reference addition: shapes are not suitable for implicit broadcast.");
172
173 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100174}
175
Nikhil Raj68c2c902019-09-19 11:21:11 +0100176bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
177 const armnn::ArgMinMaxDescriptor &descriptor,
178 armnn::Optional<std::string &> reasonIfUnsupported) const
179{
Jan Eilers8eb25602020-03-09 12:13:48 +0000180 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100181
Mike Kelly1f140f72021-04-06 12:25:55 +0100182 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100183 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000184 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100185 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100186 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100187 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000188 DataType::QAsymmU8,
189 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100190 DataType::Signed32,
191 DataType::Signed64
192 };
193
194 std::array<DataType,2> supportedOutputTypes = {
195 DataType::Signed32,
196 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100197 };
198
199 bool supported = true;
200
Mike Kelly1f140f72021-04-06 12:25:55 +0100201 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100202 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100203 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100204 "Reference ArgMinMax: output type not supported");
205
206 return supported;
207}
208
arovir011c7c81b2018-10-08 11:34:28 +0100209bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
210 const TensorInfo& output,
211 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100212 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100213 const TensorInfo& beta,
214 const TensorInfo& gamma,
215 const BatchNormalizationDescriptor& descriptor,
216 Optional<std::string&> reasonIfUnsupported) const
217{
Jan Eilers8eb25602020-03-09 12:13:48 +0000218 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100219
Sadik Armagan303980c2020-04-17 12:45:14 +0100220 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100221 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000222 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100223 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100224 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100225 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000226 DataType::QAsymmU8,
227 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100228 };
229
230 bool supported = true;
231
232 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
233 "Reference batch normalization: input is not a supported type.");
234
235 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
236 "Reference batch normalization: output is not a supported type.");
237
238 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
239 "Reference batch normalization: input and output types are mismatched");
240
241 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
242 "Reference batch normalization: mean is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
245 "Reference batch normalization: variance is not a supported type.");
246
247 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
248 "Reference batch normalization: beta is not a supported type.");
249
250 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
251 "Reference batch normalization: gamma is not a supported type.");
252
253 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100254}
255
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000256bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
257 const TensorInfo& output,
258 const BatchToSpaceNdDescriptor& descriptor,
259 Optional<std::string&> reasonIfUnsupported) const
260{
Jan Eilers8eb25602020-03-09 12:13:48 +0000261 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100262
263 bool supported = true;
264
265 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
266 std::string inputTensorStr = "input";
267 std::string outputTensorStr = "output";
268
269 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100270 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100271 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000272 DataType::BFloat16,
273 DataType::Float32,
274 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100275 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000276 DataType::QAsymmU8,
277 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100278 };
279
280 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
281 "Reference BatchToSpaceNd: input type not supported.");
282
283 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
284 "Reference BatchToSpaceNd: output type not supported.");
285
286 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
287 "Reference BatchToSpaceNd: input and output types mismatched.");
288
289 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
290 reasonIfUnsupported,
291 CreateIncorrectDimensionsErrorMsg(4,
292 output.GetNumDimensions(),
293 batchToSpaceNdLayerStr,
294 outputTensorStr).data());
295
296 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
297 reasonIfUnsupported,
298 CreateIncorrectDimensionsErrorMsg(4,
299 input.GetNumDimensions(),
300 batchToSpaceNdLayerStr,
301 inputTensorStr).data());
302
303 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000304}
305
mathad01b392e982021-04-07 12:07:30 +0100306bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
307 const TensorInfo& output,
308 Optional<std::string&> reasonIfUnsupported) const
309{
310 std::array<DataType, 9> supportedInputTypes =
311 {
312 DataType::BFloat16,
313 DataType::Float32,
314 DataType::Float16,
315 DataType::QSymmS8,
316 DataType::QAsymmS8,
317 DataType::QAsymmU8,
318 DataType::QSymmS16,
319 DataType::Signed32
320 };
321
322 bool supported = true;
323 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
324 "Reference cast: input is not a supported type");
325
326
327 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
328 "Reference cast: output is not a supported type");
329
330 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
331 "Reference cast: input and output shapes have different number of total elements");
332
333 return supported;
334}
335
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100336bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
337 const TensorInfo& input1,
338 const TensorInfo& output,
339 const ComparisonDescriptor& descriptor,
340 Optional<std::string&> reasonIfUnsupported) const
341{
Jan Eilers8eb25602020-03-09 12:13:48 +0000342 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100343 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100344 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000345 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000346 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100347 DataType::Float32,
348 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100349 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000350 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000351 DataType::QSymmS16,
352 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100353 };
354
355 bool supported = true;
356 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
357 "Reference comparison: input 0 is not a supported type");
358
359 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
360 "Reference comparison: input 0 and Input 1 types are mismatched");
361
362 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
363 "Reference comparison: output is not of type Boolean");
364
365 return supported;
366}
367
Jim Flynn906f9462019-05-10 13:55:21 +0100368bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
369 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100370 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100371 Optional<std::string&> reasonIfUnsupported) const
372{
Jan Eilers8eb25602020-03-09 12:13:48 +0000373 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100374
375 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000376 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100377 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000378 DataType::BFloat16,
379 DataType::Float32,
380 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000381 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100382 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000383 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100384 };
385
386 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
387 "Reference concatenation: output type not supported");
388 for (const TensorInfo* input : inputs)
389 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100390 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100391 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
392 "Reference concatenation: input type not supported");
393
394 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
395 "Reference concatenation: input and output types mismatched.");
396 }
397
398 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100399}
400
arovir011c7c81b2018-10-08 11:34:28 +0100401bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
402 Optional<std::string&> reasonIfUnsupported) const
403{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100404 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100405 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000406 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100407 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100408 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000409 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100410 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000411 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100412 DataType::QSymmS16,
413 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100414 };
415
416 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
417 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100418}
419
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000420bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
421 const TensorInfo& output,
422 Optional<std::string&> reasonIfUnsupported) const
423{
424 bool supported = true;
425
426 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
427 "Reference for ConvertBf16ToFp32 layer: input type not supported");
428
429 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
430 "Reference for ConvertBf16ToFp32 layer: output type not supported");
431
432 return supported;
433}
434
arovir011c7c81b2018-10-08 11:34:28 +0100435bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
436 const TensorInfo& output,
437 Optional<std::string&> reasonIfUnsupported) const
438{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100439 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
440 input.GetDataType(),
441 &TrueFunc<>,
442 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000443 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000444 &FalseFuncI32<>,
445 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100446 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
447 output.GetDataType(),
448 &FalseOutputFuncF16<>,
449 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000450 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000451 &FalseFuncI32<>,
452 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100453}
454
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000455bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
456 const TensorInfo& output,
457 Optional<std::string&> reasonIfUnsupported) const
458{
459 bool supported = true;
460
461 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
462 "Reference for ConvertFp32ToBf16 layer: input type not supported");
463
464 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
465 "Reference for ConvertFp32ToBf16 layer: output type not supported");
466
467 return supported;
468}
469
arovir011c7c81b2018-10-08 11:34:28 +0100470bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
471 const TensorInfo& output,
472 Optional<std::string&> reasonIfUnsupported) const
473{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100474 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
475 input.GetDataType(),
476 &FalseInputFuncF16<>,
477 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000478 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000479 &FalseFuncI32<>,
480 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100481 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
482 output.GetDataType(),
483 &TrueFunc<>,
484 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000485 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000486 &FalseFuncI32<>,
487 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100488}
489
490bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
491 const TensorInfo& output,
492 const Convolution2dDescriptor& descriptor,
493 const TensorInfo& weights,
494 const Optional<TensorInfo>& biases,
495 Optional<std::string&> reasonIfUnsupported) const
496{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100497 bool supported = true;
498
499 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000500 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000501 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000502 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000503 DataType::Float32,
504 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000505 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100506 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000507 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000508 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100509 };
510
511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000512 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100513
514 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000515 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100516
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000517 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
518 if (input.GetDataType() == DataType::BFloat16)
519 {
520 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
521 {
522 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
523 supported = false;
524 }
525 }
526 else
527 {
528 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000529 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000530 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100531
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000532 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000533 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000534 {
Derek Lambertid466a542020-01-22 15:37:29 +0000535 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000536 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000537 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100538 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000539 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000540 DataType::QSymmS8,
541 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000542 };
Derek Lambertid466a542020-01-22 15:37:29 +0000543 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000544
545 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000546 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000547 }
548 else
549 {
550 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000551 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000552
553 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000554 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000555 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100556
557 if (biases.has_value())
558 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000559 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000560 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000561 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000562 DataType::Float32,
563 DataType::Float16,
564 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100565 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000566
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100567 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000568 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100569 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000570 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100571
572 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100573}
574
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000575bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
576 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000577 Optional<std::string&> reasonIfUnsupported) const
578{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100579 bool supported = true;
580
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000581 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100582 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000583 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000584 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100585 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000586 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100587 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000588 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000589 DataType::QSymmS16,
590 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100591 };
592
593 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000594 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100595
596 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000597 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100598
599 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000600 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100601
602 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000603}
604
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100605bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
606 const TensorInfo& output,
607 const DepthToSpaceDescriptor& descriptor,
608 Optional<std::string&> reasonIfUnsupported) const
609{
Jan Eilers8eb25602020-03-09 12:13:48 +0000610 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100611 bool supported = true;
612
Sadik Armagan303980c2020-04-17 12:45:14 +0100613 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100614 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000615 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100616 DataType::Float32,
617 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100618 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000619 DataType::QAsymmU8,
620 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100621 };
622
623 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
624 "Reference DepthToSpace: input type not supported");
625
626 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
627 "Reference DepthToSpace: output type not supported");
628
629 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
630 "Reference DepthToSpace: input and output types are mismatched");
631
632 return supported;
633}
634
arovir011c7c81b2018-10-08 11:34:28 +0100635bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
636 const TensorInfo& output,
637 const DepthwiseConvolution2dDescriptor& descriptor,
638 const TensorInfo& weights,
639 const Optional<TensorInfo>& biases,
640 Optional<std::string&> reasonIfUnsupported) const
641{
Sadik Armagan303980c2020-04-17 12:45:14 +0100642 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100643 bool supported = true;
644
645 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000646 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100647 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000648 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100649 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100650 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000652 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100653 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000654 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100655 };
656
657 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
658 "Reference DepthwiseConvolution2d: input is not a supported type.");
659
660 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
661 "Reference DepthwiseConvolution2d: output is not a supported type.");
662
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100663 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
664 "Reference DepthwiseConvolution2d: input and output types mismatched.");
665
Teresa Charlind8df0262019-11-11 12:28:15 +0000666 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000667 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000668 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100669 ARMNN_NO_DEPRECATE_WARN_BEGIN
670 std::array<DataType, 4> supportedWeightTypes =
671 {
672 DataType::QAsymmS8,
673 DataType::QAsymmU8,
674 DataType::QSymmS8,
675 DataType::QuantizedSymm8PerAxis // deprecated
676 };
677 ARMNN_NO_DEPRECATE_WARN_END
Teresa Charlind8df0262019-11-11 12:28:15 +0000678
679 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +0100680 "Reference DepthwiseConvolution2d: weights type not supported for "
681 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +0000682 }
683 else
684 {
685 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
686 "Reference DepthwiseConvolution2d: weights is not a supported type.");
687
688 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
689 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
690 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100691
692 if (biases.has_value())
693 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000694 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100695 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000696 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100697 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100698 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100699 DataType::Signed32
700 };
701 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
702 "Reference DepthwiseConvolution2d: biases is not a supported type.");
703 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100704
705 return supported;
706
arovir011c7c81b2018-10-08 11:34:28 +0100707}
708
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000709bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
710 const TensorInfo& output,
711 Optional<std::string&> reasonIfUnsupported) const
712{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100713 bool supported = true;
714
Ryan OShea9add1202020-02-07 10:06:33 +0000715 std::array<DataType,4> supportedInputTypes = {
716 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000717 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000718 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000719 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100720 };
721
722 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000723 "Reference for Dequantize layer: input type not supported.");
724
725 supported &= CheckSupportRule( TypeNotPerAxisQuantized(input), reasonIfUnsupported,
726 "Reference for Dequantize layer: per-axis quantized input not support .");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100727
Derek Lambertid466a542020-01-22 15:37:29 +0000728 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
729 "Reference dequantize: per-axis quantized input not support .");
730
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000731 std::array<DataType,3> supportedOutputTypes = {
732 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +0000733 DataType::Float32,
734 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100735 };
736
737 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000738 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100739
740 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000741 "Reference for Dequantize layer: input/output shapes have different num total "
742 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100743
744 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000745}
746
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000747bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
748 const TensorInfo& scores,
749 const TensorInfo& anchors,
750 const TensorInfo& detectionBoxes,
751 const TensorInfo& detectionClasses,
752 const TensorInfo& detectionScores,
753 const TensorInfo& numDetections,
754 const DetectionPostProcessDescriptor& descriptor,
755 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000756{
Jan Eilers8eb25602020-03-09 12:13:48 +0000757 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +0000758
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100759 bool supported = true;
760
Sadik Armaganaa41d5d2020-11-16 14:27:52 +0000761 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100762 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000763 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100764 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +0000765 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100766 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000767 DataType::QAsymmU8,
768 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100769 };
770
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000771 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100772 "Reference DetectionPostProcess: input 0 is not a supported type.");
773
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000774 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100775 "Reference DetectionPostProcess: input 1 is not a supported type.");
776
777 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000778}
779
Pablo Tellof0bd6832019-04-26 17:58:13 +0100780bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
781 const TensorInfo& output,
782 const DepthwiseConvolution2dDescriptor& descriptor,
783 const TensorInfo& weights,
784 const Optional<TensorInfo>& biases,
785 Optional<std::string&> reasonIfUnsupported) const
786{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100787 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100788}
789
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100790bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100791 const TensorInfo& input1,
792 const TensorInfo& output,
793 Optional<std::string&> reasonIfUnsupported) const
794{
Sadik Armagan2999a022019-04-09 14:20:12 +0100795 bool supported = true;
796
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100797 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000798 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100799 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100800 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100801 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000802 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100803 DataType::QSymmS16,
804 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +0100805 };
806
807 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
808 "Reference division: input 0 is not a supported type.");
809
810 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
811 "Reference division: input 1 is not a supported type.");
812
813 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
814 "Reference division: output is not a supported type.");
815
816 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
817 "Reference division: input 0 and Input 1 types are mismatched");
818
819 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
820 "Reference division: input and output types are mismatched");
821
822 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
823 "Reference division: shapes are not suitable for implicit broadcast.");
824
825 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100826}
827
josh minor4a3c6102020-01-06 16:40:46 -0600828bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
829 const TensorInfo& output,
830 const ElementwiseUnaryDescriptor& descriptor,
831 Optional<std::string&> reasonIfUnsupported) const
832{
Jan Eilers8eb25602020-03-09 12:13:48 +0000833 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -0600834
Sadik Armagan303980c2020-04-17 12:45:14 +0100835 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -0600836 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000837 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -0600838 DataType::Float32,
839 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100840 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -0600841 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +0000842 DataType::QSymmS16,
843 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -0600844 };
845
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +0000846 std::array<DataType, 1> logicalSupportedTypes =
847 {
848 DataType::Boolean
849 };
850
josh minor4a3c6102020-01-06 16:40:46 -0600851 bool supported = true;
852
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +0000853 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
854 {
855 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
856 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -0600857
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +0000858 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
859 "Reference elementwise unary: output type not supported");
860 }
861 else
862 {
863 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
864 "Reference elementwise unary: input type not supported");
865
866 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
867 "Reference elementwise unary: output type not supported");
868 }
josh minor4a3c6102020-01-06 16:40:46 -0600869
870 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
871 "Reference elementwise unary: input and output types not matching");
872
873 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
874 "Reference elementwise unary: input and output shapes"
875 "have different number of total elements");
876
877 return supported;
878}
879
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000880bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
881 const TensorInfo& input1,
882 const TensorInfo& output,
883 Optional<std::string&> reasonIfUnsupported) const
884{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100885 return IsComparisonSupported(input0,
886 input1,
887 output,
888 ComparisonDescriptor(ComparisonOperation::Equal),
889 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000890}
891
arovir011c7c81b2018-10-08 11:34:28 +0100892bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
893 const FakeQuantizationDescriptor& descriptor,
894 Optional<std::string&> reasonIfUnsupported) const
895{
Jan Eilers8eb25602020-03-09 12:13:48 +0000896 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100897 bool supported = true;
898
899 std::array<DataType,1> supportedTypes =
900 {
901 DataType::Float32
902 };
903
904 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
905 "Reference fake quantization: input type not supported.");
906
907 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100908}
909
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100910bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
911 const TensorInfo& output,
912 const FillDescriptor& descriptor,
913 Optional<std::string&> reasonIfUnsupported) const
914{
915 IgnoreUnused(descriptor);
916 IgnoreUnused(output);
917
918 bool supported = true;
919
Sadik Armagana792a052020-06-23 16:22:23 +0100920 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100921 {
922 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +0100923 DataType::Float16,
924 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100925 };
926
Teresa Charlin4b10fef2020-07-29 09:36:41 +0100927 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100928 "Reference Fill: input type not supported.");
929
Teresa Charlin44088502020-07-27 11:27:19 +0100930 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
931 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100932 return supported;
933}
934
arovir011c7c81b2018-10-08 11:34:28 +0100935bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
936 const TensorInfo& output,
937 Optional<std::string&> reasonIfUnsupported) const
938{
Jan Eilers8eb25602020-03-09 12:13:48 +0000939 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +0100940 bool supported = true;
941
Francis Murtaghe8ac1332020-07-30 18:03:40 +0100942 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100943 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000944 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +0100945 DataType::Float32,
Francis Murtaghe8ac1332020-07-30 18:03:40 +0100946 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +0100947 };
948
949 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
950 "Reference Floor: input type not supported.");
951
952 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
953 "Reference Floor: output type not supported.");
954
955 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100956}
957
958bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
959 const TensorInfo& output,
960 const TensorInfo& weights,
961 const TensorInfo& biases,
962 const FullyConnectedDescriptor& descriptor,
963 Optional<std::string&> reasonIfUnsupported) const
964{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100965 bool supported = true;
966
967 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000968 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100969 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000970 DataType::BFloat16,
971 DataType::Float32,
972 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000973 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100974 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000975 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100976 };
977
978 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
979 "Reference Fully Connected: input type not supported.");
980
981 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
982 "Reference Fully Connected: output type not supported.");
983
Francis Murtagh46c09d02019-05-28 08:15:28 +0100984 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
985 "Reference Fully Connected: weights type not supported.");
986
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000987 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
988 if (input.GetDataType() == DataType::BFloat16)
989 {
990 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
991 {
992 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
993 supported = false;
994 }
995 }
996 else
997 {
998 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
999 "Reference Fully Connected: input and output types mismatched.");
1000 }
1001
Jan Eilers1f45dc32020-06-15 11:43:03 +01001002 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1003 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001004
Jan Eilers1f45dc32020-06-15 11:43:03 +01001005 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1006 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001007
1008 if (descriptor.m_BiasEnabled)
1009 {
1010 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001011 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001012 supportedBiasTypes =
1013 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001014 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001015 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001016 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001017 DataType::Signed32,
1018 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001019 };
1020
1021 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1022 "Reference Fully Connected: bias type not supported.");
1023
1024 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1025 "Reference Fully Connected: bias and weight types mismatch.");
1026
1027 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1028 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1029
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001030 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1031 "Reference Fully Connected: bias must have 1 dimension.");
1032
Francis Murtagh46c09d02019-05-28 08:15:28 +01001033 }
1034
1035 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001036}
1037
narpra014951d842019-01-18 16:53:53 +00001038bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1039 const armnn::TensorInfo& input1,
1040 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001041 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001042 armnn::Optional<std::string&> reasonIfUnsupported) const
1043{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001044 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001045 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001046 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001047 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001048 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001049 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001050 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001051 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001052 DataType::QSymmS16,
1053 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001054 };
1055
Teresa Charlin52664732020-06-29 16:27:03 +01001056 if (descriptor.m_Axis != 0)
1057 {
1058 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1059 supported &= false;
1060 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001061 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1062 "Reference Gather: input type not supported");
1063
1064 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1065 "Reference Gather: output type not supported");
1066
1067 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1068 "Reference Gather: indices (input1) type not supported");
1069
1070 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1071 "Reference Gather: input and output types not matching");
1072
1073 return supported;
narpra014951d842019-01-18 16:53:53 +00001074}
1075
FrancisMurtagh878f0232018-12-19 10:56:15 +00001076bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
1077 const TensorInfo& input1,
1078 const TensorInfo& output,
1079 Optional<std::string&> reasonIfUnsupported) const
1080{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001081 return IsComparisonSupported(input0,
1082 input1,
1083 output,
1084 ComparisonDescriptor(ComparisonOperation::Greater),
1085 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +00001086}
1087
Derek Lamberti901ea112019-12-10 22:07:09 +00001088bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1089 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001090{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001091 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001092}
1093
Kevin May09ca49c2019-10-09 12:37:34 +01001094bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1095 const TensorInfo& output,
1096 const InstanceNormalizationDescriptor& descriptor,
1097 Optional<std::string&> reasonIfUnsupported) const
1098{
Jan Eilers8eb25602020-03-09 12:13:48 +00001099 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001100 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001101 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001102 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001103 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001104 DataType::Float32,
1105 DataType::Float16
1106 };
1107
1108 bool supported = true;
1109
1110 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1111 "Reference Instance Normalization: input type not supported.");
1112
1113 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1114 "Reference Instance Normalization: output type not supported.");
1115
1116 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1117 "Reference Instance Normalization: input and output types mismatched.");
1118
1119 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1120 "Reference Instance Normalization: input and output shapes have different "
1121 "num total elements.");
1122
1123 return supported;
1124}
1125
arovir011c7c81b2018-10-08 11:34:28 +01001126bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1127 const TensorInfo& output,
1128 const L2NormalizationDescriptor& descriptor,
1129 Optional<std::string&> reasonIfUnsupported) const
1130{
Jan Eilers8eb25602020-03-09 12:13:48 +00001131 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001132 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001133 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001134 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001135 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001136 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001137 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001138 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001139 DataType::QAsymmU8,
1140 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001141 };
1142
1143 bool supported = true;
1144
1145 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1146 "Reference L2normalization: input type not supported.");
1147
1148 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1149 "Reference L2normalization: output type not supported.");
1150
1151 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1152 "Reference L2normalization: input and output types mismatched.");
1153
1154 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1155 "Reference L2normalization: input and output shapes have different "
1156 "num total elements.");
1157
1158 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001159}
1160
James Conroyaba90cd2020-11-06 16:28:18 +00001161bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1162 const TensorInfo& input1,
1163 const TensorInfo& output,
1164 const LogicalBinaryDescriptor& descriptor,
1165 Optional<std::string&> reasonIfUnsupported) const
1166{
1167 IgnoreUnused(descriptor);
1168
1169 std::array<DataType, 1> supportedTypes =
1170 {
1171 DataType::Boolean
1172 };
1173
1174 bool supported = true;
1175 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1176 "Reference LogicalBinary: input 0 type not supported");
1177 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1178 "Reference LogicalBinary: input 1 type not supported");
1179
1180 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1181 "Reference LogicalBinary: input and output types do not match");
1182
1183 return supported;
1184}
1185
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001186bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1187 const TensorInfo& output,
1188 const LogSoftmaxDescriptor& descriptor,
1189 Optional<std::string&> reasonIfUnsupported) const
1190{
Jan Eilers8eb25602020-03-09 12:13:48 +00001191 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001192
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001193 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001194 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001195 DataType::BFloat16,
1196 DataType::Float32,
1197 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001198 };
1199
1200 bool supported = true;
1201 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1202 "Reference LogSoftmax: input type not supported");
1203
1204 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1205 "Reference LogSoftmax: output type not supported");
1206
1207 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1208 "Reference LogSoftmax: input and output types do not match");
1209
1210 return supported;
1211}
1212
arovir011c7c81b2018-10-08 11:34:28 +01001213bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1214 const TensorInfo& outputStateIn,
1215 const TensorInfo& cellStateIn,
1216 const TensorInfo& scratchBuffer,
1217 const TensorInfo& outputStateOut,
1218 const TensorInfo& cellStateOut,
1219 const TensorInfo& output,
1220 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001221 const LstmInputParamsInfo& paramsInfo,
1222 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001223{
Jan Eilers8eb25602020-03-09 12:13:48 +00001224 IgnoreUnused(descriptor);
1225 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001226
1227 bool supported = true;
1228
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001229 std::array<DataType,3> supportedTypes = {
1230 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001231 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001232 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001233 };
1234
Jan Eilersd01a83c2019-07-03 18:20:40 +01001235 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001236 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1237 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001238 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1239 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001240 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1241 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001242 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1243 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001244 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1245 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001246 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1247 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001248 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1249 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001250 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001251 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001252 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001253 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001254 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001255 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001256 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001257 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001258 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001259 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001260 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001261 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001262 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001263 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001264 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001265 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001266 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001267 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001268 "Reference Lstm: input and OutputGateBias types are mismatched");
1269 if (!descriptor.m_CifgEnabled)
1270 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001271 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001272 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001273 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001274 reasonIfUnsupported,
1275 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001276 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001277 "Reference Lstm: input and InputGateBias types are mismatched");
1278 if (descriptor.m_PeepholeEnabled)
1279 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001280 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001281 reasonIfUnsupported,
1282 "Reference Lstm: input and CellToInputWeights types are mismatched");
1283 }
1284 }
1285 if (descriptor.m_PeepholeEnabled)
1286 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001287 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001288 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001289 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001290 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1291 }
1292 if (descriptor.m_ProjectionEnabled)
1293 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001294 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001295 "Reference Lstm: input and mProjectionWeights types are mismatched");
1296 if (paramsInfo.m_ProjectionBias != nullptr)
1297 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001298 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001299 "Reference Lstm: input and ProjectionBias types are mismatched");
1300 }
1301 }
1302 if (descriptor.m_LayerNormEnabled)
1303 {
1304 if (!descriptor.m_CifgEnabled)
1305 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001306 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001307 reasonIfUnsupported,
1308 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1309 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001310 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001311 reasonIfUnsupported,
1312 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001313 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001314 reasonIfUnsupported,
1315 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001316 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001317 reasonIfUnsupported,
1318 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1319 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001320
1321 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001322}
1323
saoste012df12b32018-11-28 16:57:20 +00001324bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1325 const TensorInfo& input1,
1326 const TensorInfo& output,
1327 Optional<std::string&> reasonIfUnsupported) const
1328{
Sadik Armagan2999a022019-04-09 14:20:12 +01001329 bool supported = true;
1330
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001331 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001332 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001333 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001334 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001335 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001336 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001337 DataType::QSymmS16,
1338 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001339 };
1340
1341 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1342 "Reference maximum: input 0 is not a supported type.");
1343
1344 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1345 "Reference maximum: input 1 is not a supported type.");
1346
1347 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1348 "Reference maximum: output is not a supported type.");
1349
1350 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1351 "Reference maximum: input 0 and Input 1 types are mismatched");
1352
1353 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1354 "Reference maximum: input and output types are mismatched");
1355
1356 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1357 "Reference maximum: shapes are not suitable for implicit broadcast.");
1358
1359 return supported;
saoste012df12b32018-11-28 16:57:20 +00001360}
1361
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001362bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1363 const TensorInfo& output,
1364 const MeanDescriptor& descriptor,
1365 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001366{
James Conroy4d1ff582019-06-10 17:06:39 +01001367 bool supported = true;
1368 std::string meanLayerStr = "Mean";
1369 std::string outputTensorStr = "output";
1370
Sadik Armagan303980c2020-04-17 12:45:14 +01001371 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001372 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001373 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001374 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001375 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001376 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001377 DataType::QAsymmU8,
1378 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001379 };
1380
1381 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1382 "Reference Mean: input type not supported.");
1383
1384 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1385 "Reference Mean: input and output types are mismatched");
1386
1387 if (descriptor.m_KeepDims)
1388 {
1389 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1390 reasonIfUnsupported,
1391 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1392 output.GetNumDimensions(),
1393 meanLayerStr, outputTensorStr).data());
1394 }
1395 else if (descriptor.m_Axis.empty())
1396 {
1397 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1398 reasonIfUnsupported,
1399 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1400 meanLayerStr, outputTensorStr).data());
1401 }
1402 else
1403 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001404 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001405
1406 if (outputDim > 0)
1407 {
1408 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1409 reasonIfUnsupported,
1410 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1411 meanLayerStr, outputTensorStr).data());
1412 }
1413 else
1414 {
1415 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1416 reasonIfUnsupported,
1417 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1418 meanLayerStr, outputTensorStr).data());
1419 }
1420 }
1421
1422 return supported;
narpra0132b90462018-09-13 11:07:48 +01001423}
1424
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001425bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001426 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001427 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001428 Optional<std::string&> reasonIfUnsupported) const
1429{
Jim Flynne242f2d2019-05-22 14:24:13 +01001430 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001431}
1432
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001433bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1434 const TensorInfo &output,
1435 Optional<std::string &> reasonIfUnsupported) const
1436{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001437 bool supported = true;
1438
Sadik Armagan303980c2020-04-17 12:45:14 +01001439 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001440 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001441 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001442 DataType::Float32,
1443 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001444 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001445 DataType::QAsymmU8,
1446 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001447 DataType::Boolean
1448 };
1449
1450 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1451 "Reference MemCopy: input type not supported");
1452
1453 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1454 "Reference MemCopy: output type not supported");
1455
1456 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1457 "Reference MemCopy: input and output types are mismatched");
1458
1459 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001460}
1461
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001462bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1463 const TensorInfo& input1,
1464 const TensorInfo& output,
1465 Optional<std::string&> reasonIfUnsupported) const
1466{
Sadik Armagan2999a022019-04-09 14:20:12 +01001467 bool supported = true;
1468
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001469 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001470 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001471 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001472 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001473 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001474 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001475 DataType::QSymmS16,
1476 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001477 };
1478
1479 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1480 "Reference minimum: input 0 is not a supported type.");
1481
1482 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1483 "Reference minimum: input 1 is not a supported type.");
1484
1485 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1486 "Reference minimum: output is not a supported type.");
1487
1488 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1489 "Reference minimum: input 0 and Input 1 types are mismatched");
1490
1491 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1492 "Reference minimum: input and output types are mismatched");
1493
1494 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1495 "Reference minimum: shapes are not suitable for implicit broadcast.");
1496
1497 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001498}
1499
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001500bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1501 const TensorInfo& input1,
1502 const TensorInfo& output,
1503 Optional<std::string&> reasonIfUnsupported) const
1504{
Sadik Armagan2999a022019-04-09 14:20:12 +01001505 bool supported = true;
1506
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001507 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001508 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001509 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001510 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001511 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001512 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001513 DataType::QSymmS16,
1514 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001515 };
1516
1517 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1518 "Reference multiplication: input 0 is not a supported type.");
1519
1520 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1521 "Reference multiplication: input 1 is not a supported type.");
1522
1523 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1524 "Reference multiplication: output is not a supported type.");
1525
1526 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1527 "Reference multiplication: input 0 and Input 1 types are mismatched");
1528
1529 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1530 "Reference multiplication: input and output types are mismatched");
1531
1532 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1533 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1534
1535 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001536}
1537
1538bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1539 const TensorInfo& output,
1540 const NormalizationDescriptor& descriptor,
1541 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001542{
Jan Eilers8eb25602020-03-09 12:13:48 +00001543 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001544
1545 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001546 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001547 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001548 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001549 DataType::Float16,
1550 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001551 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001552 DataType::QAsymmU8,
1553 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001554 };
1555
1556 bool supported = true;
1557
1558 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1559 "Reference normalization: input type not supported.");
1560
1561 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1562 "Reference normalization: output type not supported.");
1563
1564 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1565 "Reference normalization: input and output shapes have different "
1566 "num total elements.");
1567
1568 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001569}
1570
Derek Lamberti901ea112019-12-10 22:07:09 +00001571bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1572 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001573{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001574 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001575}
1576
1577bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1578 const TensorInfo& output,
1579 const PadDescriptor& descriptor,
1580 Optional<std::string&> reasonIfUnsupported) const
1581{
Jan Eilers8eb25602020-03-09 12:13:48 +00001582 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001583 bool supported = true;
1584
1585 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001586 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001587 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001588 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001589 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001590 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001591 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001592 DataType::QAsymmU8,
1593 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001594 };
1595
1596 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1597 "Reference pad: input is not a supported type.");
1598
1599 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1600 "Reference pad: output is not a supported type.");
1601
1602 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1603 "Reference pad: input and output types are mismatched.");
1604
1605 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001606}
1607
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001608bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1609 const TensorInfo& output,
1610 const PermuteDescriptor& descriptor,
1611 Optional<std::string&> reasonIfUnsupported) const
1612{
Jan Eilers8eb25602020-03-09 12:13:48 +00001613 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001614 bool supported = true;
1615
1616 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001617 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001618 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001619 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001620 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001621 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001622 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001623 DataType::QAsymmU8,
1624 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001625 };
1626
1627 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1628 "Reference permute: input is not a supported type.");
1629
1630 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1631 "Reference permute: output is not a supported type.");
1632
1633 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1634 "Reference permute: input and output types are mismatched.");
1635
1636 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001637}
1638
1639bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1640 const TensorInfo& output,
1641 const Pooling2dDescriptor& descriptor,
1642 Optional<std::string&> reasonIfUnsupported) const
1643{
Jan Eilers8eb25602020-03-09 12:13:48 +00001644 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001645 bool supported = true;
1646
1647 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001648 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001649 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001650 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001651 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001652 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001653 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001654 DataType::QAsymmU8,
1655 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001656 };
1657
1658 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1659 "Reference poolind2d: input is not a supported type.");
1660
1661 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1662 "Reference poolind2d: output is not a supported type.");
1663
1664 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1665 "Reference poolind2d: input and output types are mismatched.");
1666
1667 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001668}
1669
James Conroy4f1f8992020-04-29 20:01:10 +01001670bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1671 const TensorInfo& previousOutputIn,
1672 const TensorInfo& previousCellStateIn,
1673 const TensorInfo& outputStateOut,
1674 const TensorInfo& cellStateOut,
1675 const TensorInfo& output,
1676 const QLstmDescriptor& descriptor,
1677 const LstmInputParamsInfo& paramsInfo,
1678 Optional<std::string&> reasonIfUnsupported) const
1679{
1680 IgnoreUnused(input);
1681 IgnoreUnused(previousOutputIn);
1682 IgnoreUnused(previousCellStateIn);
1683 IgnoreUnused(outputStateOut);
1684 IgnoreUnused(cellStateOut);
1685 IgnoreUnused(output);
1686 IgnoreUnused(descriptor);
1687 IgnoreUnused(paramsInfo);
1688
1689 IgnoreUnused(reasonIfUnsupported);
1690
1691 return true;
1692}
1693
Derek Lamberti5f400d62019-03-25 15:41:58 +00001694bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1695 const TensorInfo& output,
1696 Optional<std::string&> reasonIfUnsupported) const
1697{
1698 bool supported = true;
1699
Finn Williamsfd271062019-12-04 14:27:27 +00001700 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001701 std::array<DataType,7> supportedInputTypes = {
1702 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001703 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001704 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001705 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001706 DataType::QAsymmU8,
1707 DataType::QSymmS8,
1708 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001709 };
1710
1711 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1712 "Reference quantize: input type not supported.");
1713
1714 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001715 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001716 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001717 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001718 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001719 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001720 };
1721 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1722 "Reference quantize: output type not supported.");
1723
1724 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1725 "Reference quantize: input and output shapes have different num total elements.");
1726
1727 return supported;
1728}
1729
Finn Williams2605b232020-06-10 15:53:46 +01001730bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1731 const TensorInfo& output,
1732 Optional<std::string&> reasonIfUnsupported) const
1733{
1734 IgnoreUnused(input);
1735 // Define supported output types.
1736 std::array<DataType,1> supportedOutputTypes =
1737 {
1738 DataType::Signed32,
1739 };
1740
1741 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1742 "Reference rank: input type not supported.");
1743}
1744
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001745bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
1746 const TensorInfo& output,
1747 const ReduceDescriptor& descriptor,
1748 Optional<std::string&> reasonIfUnsupported) const
1749{
1750 IgnoreUnused(descriptor);
1751 bool supported = true;
1752 std::array<DataType,7> supportedTypes =
1753 {
1754 DataType::BFloat16,
1755 DataType::Float32,
1756 DataType::Float16,
1757 DataType::QAsymmS8,
1758 DataType::QAsymmU8,
1759 DataType::QSymmS16,
1760 DataType::Signed32
1761 };
1762
1763 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1764 "Reference Reduce: input type not supported");
1765
1766 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1767 "Reference Reduce: output type not supported");
1768
1769 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1770 "Reference Reduce: input and output types not matching");
1771
1772 return supported;
1773}
1774
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001775bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001776 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001777 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001778 Optional<std::string&> reasonIfUnsupported) const
1779{
Jan Eilers8eb25602020-03-09 12:13:48 +00001780 IgnoreUnused(output);
1781 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001782 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001783 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001784 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001785 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001786 DataType::Float32,
1787 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001788 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001789 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001790 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001791 DataType::QSymmS16,
1792 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001793 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001794
Nina Drozd2f2778f2019-05-27 10:37:05 +01001795 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1796 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001797}
1798
1799bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001800 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001801 Optional<std::string&> reasonIfUnsupported) const
1802{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001803 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001804 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001805 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001806 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001807 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001808 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001809 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001810 DataType::QAsymmU8,
1811 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001812 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001813
1814 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1815 "Reference ResizeBilinear: input type not supported");
1816
1817 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1818 "Reference ResizeBilinear: output type not supported");
1819
1820 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1821 "Reference ResizeBilinear: input and output types not matching");
1822
1823 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001824}
1825
Teresa Charlin970f43b2019-07-01 13:51:07 +01001826bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1827 const TensorInfo& output,
1828 const ResizeDescriptor& descriptor,
1829 Optional<std::string&> reasonIfUnsupported) const
1830{
Jan Eilers8eb25602020-03-09 12:13:48 +00001831 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001832 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001833 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001834 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001835 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001836 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001837 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001838 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001839 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001840 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001841 };
1842
1843 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1844 "Reference Resize: input type not supported");
1845
1846 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1847 "Reference Resize: output type not supported");
1848
1849 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1850 "Reference Resize: input and output types not matching");
1851
1852 return supported;
1853}
1854
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001855bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1856 const TensorInfo& output,
1857 Optional<std::string&> reasonIfUnsupported) const
1858{
josh minor4a3c6102020-01-06 16:40:46 -06001859 return IsElementwiseUnarySupported(input,
1860 output,
1861 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1862 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001863}
1864
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001865bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1866 const TensorInfo& output,
1867 const SliceDescriptor& descriptor,
1868 Optional<std::string&> reasonIfUnsupported) const
1869{
Jan Eilers8eb25602020-03-09 12:13:48 +00001870 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001871 bool supported = true;
1872
Sadik Armagan303980c2020-04-17 12:45:14 +01001873 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001874 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001875 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001876 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001877 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001878 DataType::QAsymmU8,
1879 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001880 };
1881
1882 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1883 "Reference Slice: input type not supported");
1884
1885 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1886 "Reference Slice: output type not supported");
1887
1888 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1889 "Reference Slice: input and output types are mismatched");
1890
1891 return supported;
1892}
1893
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001894bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1895 const TensorInfo& output,
1896 const SoftmaxDescriptor& descriptor,
1897 Optional<std::string&> reasonIfUnsupported) const
1898{
Jan Eilers8eb25602020-03-09 12:13:48 +00001899 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001900 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001901 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001902 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001903 DataType::BFloat16,
1904 DataType::Float32,
1905 DataType::Float16,
1906 DataType::QSymmS8,
1907 DataType::QAsymmS8,
1908 DataType::QAsymmU8,
1909 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001910 };
1911
1912 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001913 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001914
1915 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001916 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001917
1918 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001919 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001920
1921 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001922}
1923
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001924bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1925 const TensorInfo& output,
1926 const SpaceToBatchNdDescriptor& descriptor,
1927 Optional<std::string&> reasonIfUnsupported) const
1928{
Jan Eilers8eb25602020-03-09 12:13:48 +00001929 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001930 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001931 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001932 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001933 DataType::BFloat16,
1934 DataType::Float32,
1935 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001936 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001937 DataType::QAsymmU8,
1938 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001939 };
1940
1941 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1942 "Reference SpaceToBatchNd: input type not supported");
1943
1944 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1945 "Reference SpaceToBatchNd: output type not supported");
1946
1947 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1948 "Reference SpaceToBatchNd: input and output types are mismatched");
1949
1950 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001951}
1952
Keith Davisa57eccb2019-06-14 17:33:22 +01001953bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001954 const TensorInfo& output,
1955 const SpaceToDepthDescriptor& descriptor,
1956 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001957{
1958
Jan Eilers8eb25602020-03-09 12:13:48 +00001959 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001960 bool supported = true;
1961
Sadik Armagan303980c2020-04-17 12:45:14 +01001962 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001963 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001964 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001965 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001966 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001967 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001968 DataType::QAsymmU8,
1969 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001970 };
1971
1972 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1973 "Reference SpaceToDepth: input type not supported");
1974
1975 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1976 "Reference SpaceToDepth: output type not supported");
1977
1978 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1979 "Reference SpaceToDepth: input and output types are mismatched");
1980
1981 return supported;
1982}
1983
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001984bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1985 const ViewsDescriptor& descriptor,
1986 Optional<std::string&> reasonIfUnsupported) const
1987{
Jan Eilers8eb25602020-03-09 12:13:48 +00001988 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001989 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001990 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001991 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001992 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001993 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001994 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001995 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001996 DataType::QAsymmU8,
1997 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001998 };
1999
2000 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2001 "Reference splitter: input type not supported");
2002
2003 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002004}
2005
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002006bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2007 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2008 const ViewsDescriptor& descriptor,
2009 Optional<std::string&> reasonIfUnsupported) const
2010{
Jan Eilers8eb25602020-03-09 12:13:48 +00002011 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002012 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002013 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002014 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002015 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002016 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002017 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002018 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002019 DataType::QAsymmU8,
2020 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002021 };
2022
2023 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2024 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002025 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002026 {
2027 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2028 "Reference splitter: input type not supported");
2029
2030 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2031 "Reference splitter: input and output types mismatched.");
2032 }
2033
2034 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002035}
2036
Matthew Jackson81e601c2019-07-11 12:07:09 +01002037bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2038 const TensorInfo& output,
2039 const StackDescriptor& descriptor,
2040 Optional<std::string&> reasonIfUnsupported) const
2041{
Jan Eilers8eb25602020-03-09 12:13:48 +00002042 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002043
2044 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002045 std::array<DataType,6> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002046 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002047 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002048 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002049 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002050 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002051 DataType::QAsymmU8,
2052 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01002053 };
2054
2055 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2056 "Reference stack: output type not supported");
2057 for (const TensorInfo* input : inputs)
2058 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002059 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002060 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2061 "Reference stack: input type not supported");
2062
2063 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2064 "Reference stack: input and output types mismatched.");
2065 }
2066
2067 return supported;
2068}
2069
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002070bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2071 const TensorInfo& output,
2072 const StridedSliceDescriptor& descriptor,
2073 Optional<std::string&> reasonIfUnsupported) const
2074{
Jan Eilers8eb25602020-03-09 12:13:48 +00002075 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002076 bool supported = true;
2077
Sadik Armagan303980c2020-04-17 12:45:14 +01002078 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002079 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002080 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002081 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002082 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002083 DataType::QAsymmU8,
2084 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002085 };
2086
2087 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2088 "Reference StridedSlice: input type not supported");
2089
2090 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2091 "Reference StridedSlice: output type not supported");
2092
2093 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2094 "Reference StridedSlice: input and output types are mismatched");
2095
2096 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002097}
2098
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002099bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2100 const TensorInfo& input1,
2101 const TensorInfo& output,
2102 Optional<std::string&> reasonIfUnsupported) const
2103{
Sadik Armagan2999a022019-04-09 14:20:12 +01002104 bool supported = true;
2105
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002106 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002107 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002108 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002109 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002110 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002111 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002112 DataType::QSymmS16,
2113 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002114 };
2115
2116 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2117 "Reference subtraction: input 0 is not a supported type.");
2118
2119 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2120 "Reference subtraction: input 1 is not a supported type.");
2121
2122 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2123 "Reference subtraction: output is not a supported type.");
2124
2125 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2126 "Reference subtraction: input 0 and Input 1 types are mismatched");
2127
2128 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2129 "Reference subtraction: input and output types are mismatched");
2130
2131 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2132 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2133
2134 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002135}
2136
Matteo Martincighab9e5252019-06-13 17:27:46 +01002137bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2138 const TensorInfo& alpha,
2139 const TensorInfo& output,
2140 Optional<std::string&> reasonIfUnsupported) const
2141{
2142 bool supported = true;
2143
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002144 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002145 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002146 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002147 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002148 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002149 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002150 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002151 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002152 };
2153
2154 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2155 "PReLU: input is not a supported type.");
2156
2157 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2158 "PReLU: alpha is not a supported type.");
2159
2160 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2161 "PReLU: output is not a supported type.");
2162
2163 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2164 "PReLU: input, alpha and output types are mismatched");
2165
2166 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2167 "PReLU: shapes are not suitable for implicit broadcast");
2168
2169 return supported;
2170}
2171
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002172bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2173 const TensorInfo& output,
2174 const TransposeConvolution2dDescriptor& descriptor,
2175 const TensorInfo& weights,
2176 const Optional<TensorInfo>& biases,
2177 Optional<std::string&> reasonIfUnsupported) const
2178{
Jan Eilers8eb25602020-03-09 12:13:48 +00002179 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002180 bool supported = true;
2181
Sadik Armagan303980c2020-04-17 12:45:14 +01002182 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002183 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002184 DataType::BFloat16,
2185 DataType::Float32,
2186 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002187 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002188 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002189 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002190 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002191 };
2192
2193 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2194 "Reference TransposeConvolution2d: input is not a supported type.");
2195
2196 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2197 "Reference TransposeConvolution2d: output is not a supported type.");
2198
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002199 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2200 "Reference TransposeConvolution2d: input and output types mismatched.");
2201
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002202
2203 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002204 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002205 {
Derek Lambertid466a542020-01-22 15:37:29 +00002206 ARMNN_NO_DEPRECATE_WARN_BEGIN
Sadik Armagan303980c2020-04-17 12:45:14 +01002207 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002208 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002209 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002210 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00002211 DataType::QSymmS8,
2212 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002213 };
Derek Lambertid466a542020-01-22 15:37:29 +00002214 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002215
2216 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2217 "Reference TransposeConvolution2d: weights type not supported for "
2218 "quantized input.");
2219 }
2220 else
2221 {
2222 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2223 "Reference TransposeConvolution2d: weights is not a supported type.");
2224
2225 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2226 "Reference TransposeConvolution2d: input and weights types mismatched.");
2227 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002228
2229 if (biases.has_value())
2230 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002231 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002232 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002233 DataType::BFloat16,
2234 DataType::Float32,
2235 DataType::Float16,
2236 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002237 };
2238 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2239 "Reference TransposeConvolution2d: biases is not a supported type.");
2240 }
2241
2242 return supported;
2243}
2244
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002245bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2246 const TensorInfo& output,
2247 const TransposeDescriptor& descriptor,
2248 Optional<std::string&> reasonIfUnsupported) const
2249{
Jan Eilers8eb25602020-03-09 12:13:48 +00002250 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002251 bool supported = true;
2252
2253 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002254 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002255 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002256 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002257 DataType::Float32,
2258 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002259 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002260 DataType::QAsymmU8,
2261 DataType::QSymmS16
2262 };
2263
2264 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2265 "Reference transpose: input is not a supported type.");
2266
2267 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2268 "Reference transpose: output is not a supported type.");
2269
2270 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2271 "Reference transpose: input and output types are mismatched.");
2272
2273 return supported;
2274}
2275
arovir011c7c81b2018-10-08 11:34:28 +01002276} // namespace armnn