blob: 14a40f9d5d7179c07582b0a6bb93f5d624b79d15 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01007
Keith Davis0c2eeac2020-02-11 16:51:50 +00008#include <armnn/TypesUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000010#include <armnn/Descriptors.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighe011d202019-11-28 11:35:47 +000014#include <LayerSupportCommon.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010015#include <backendsCommon/LayerSupportRules.hpp>
Matteo Martincighe011d202019-11-28 11:35:47 +000016
Derek Lamberti50db4e82019-03-13 14:16:15 +000017#include <vector>
Derek Lamberti50db4e82019-03-13 14:16:15 +000018#include <array>
19
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010023namespace
24{
25
26template<typename Float32Func, typename Uint8Func, typename ... Params>
27bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
28 DataType dataType,
29 Float32Func floatFuncPtr,
30 Uint8Func uint8FuncPtr,
31 Params&&... params)
32{
33 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
34 dataType,
35 &FalseFunc<Params...>,
36 floatFuncPtr,
37 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000038 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000039 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010040 std::forward<Params>(params)...);
41}
42
43} // anonymous namespace
44
namespace
{

/// Builds the diagnostic reported when a tensor does not have the rank a reference layer needs.
///
/// @param expected    Number of dimensions the layer requires.
/// @param actual      Number of dimensions the tensor actually has.
/// @param layerStr    Layer name inserted after "Reference ".
/// @param tensorName  Tensor name, quoted in the message.
/// @return A message such as
///         "Reference batchToSpaceNd: Expected 4 dimensions but got 3 dimensions instead, for the 'input' tensor."
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // The strings are only read, so they are taken by const reference. This still binds the
    // non-const lvalues existing callers pass, and also accepts literals and temporaries.
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got " +
           std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000060
Sadik Armagan9199e582019-09-05 17:35:31 +010061bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
62 Optional<std::string&> reasonIfUnsupported) const
63{
josh minor4a3c6102020-01-06 16:40:46 -060064 return IsElementwiseUnarySupported(input,
65 output,
66 ElementwiseUnaryDescriptor(UnaryOperation::Abs),
67 reasonIfUnsupported);
Sadik Armagan9199e582019-09-05 17:35:31 +010068}
69
arovir011c7c81b2018-10-08 11:34:28 +010070bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
71 const TensorInfo& output,
72 const ActivationDescriptor& descriptor,
73 Optional<std::string&> reasonIfUnsupported) const
74{
Derek Lamberti50db4e82019-03-13 14:16:15 +000075 bool supported = true;
76
77 // Define supported types.
Keith Davis0c2eeac2020-02-11 16:51:50 +000078 std::array<DataType,6> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +000079 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +000080 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +010081 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +000082 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +000083 DataType::QAsymmU8,
84 DataType::QSymmS16
Derek Lamberti50db4e82019-03-13 14:16:15 +000085 };
86
87 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
88 "Reference activation: input type not supported.");
89
90 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
91 "Reference activation: output type not supported.");
92
93 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
94 "Reference activation: input and output types mismatched.");
95
96 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
97 "Reference activation: input and output shapes are of different rank.");
98
99
100 struct ActivationFunctionSupported : public Rule
101 {
102 ActivationFunctionSupported(const ActivationDescriptor& desc)
103 {
104 switch(desc.m_Function)
105 {
106 case ActivationFunction::Abs:
107 case ActivationFunction::BoundedReLu:
David Monahan3b3c3812020-02-25 09:03:29 +0000108 case ActivationFunction::Elu:
Colm Donelan03fbeaf2020-02-26 15:39:23 +0000109 case ActivationFunction::HardSwish:
Derek Lamberti50db4e82019-03-13 14:16:15 +0000110 case ActivationFunction::LeakyReLu:
111 case ActivationFunction::Linear:
112 case ActivationFunction::ReLu:
113 case ActivationFunction::Sigmoid:
114 case ActivationFunction::SoftReLu:
115 case ActivationFunction::Sqrt:
116 case ActivationFunction::Square:
117 case ActivationFunction::TanH:
118 {
119 m_Res = true;
120 break;
121 }
122 default:
123 {
124 m_Res = false;
125 break;
126 }
127 }
128 }
129 };
130
131 // Function is supported
132 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
133 "Reference activation: function not supported.");
134
135 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100136}
137
138bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
139 const TensorInfo& input1,
140 const TensorInfo& output,
141 Optional<std::string&> reasonIfUnsupported) const
142{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000143 bool supported = true;
144
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100145 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000146 DataType::BFloat16,
Derek Lamberti50db4e82019-03-13 14:16:15 +0000147 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100148 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000149 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000150 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100151 DataType::QSymmS16,
152 DataType::Signed32
Derek Lamberti50db4e82019-03-13 14:16:15 +0000153 };
154
155 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
156 "Reference addition: input 0 is not a supported type.");
157
158 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
159 "Reference addition: input 1 is not a supported type.");
160
161 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
162 "Reference addition: output is not a supported type.");
163
164 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
165 "Reference addition: input 0 and Input 1 types are mismatched");
166
167 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
168 "Reference addition: input and output types are mismatched");
169
170 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
171 "Reference addition: shapes are not suitable for implicit broadcast.");
172
173 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100174}
175
Nikhil Raj68c2c902019-09-19 11:21:11 +0100176bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
177 const armnn::ArgMinMaxDescriptor &descriptor,
178 armnn::Optional<std::string &> reasonIfUnsupported) const
179{
Jan Eilers8eb25602020-03-09 12:13:48 +0000180 IgnoreUnused(descriptor);
Nikhil Raj68c2c902019-09-19 11:21:11 +0100181
Mike Kelly1f140f72021-04-06 12:25:55 +0100182 std::array<DataType, 8> supportedInputTypes =
Nikhil Raj68c2c902019-09-19 11:21:11 +0100183 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000184 DataType::BFloat16,
Teresa Charline300b362020-05-25 10:01:03 +0100185 DataType::Float16,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100186 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +0100187 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000188 DataType::QAsymmU8,
189 DataType::QSymmS16,
Mike Kelly1f140f72021-04-06 12:25:55 +0100190 DataType::Signed32,
191 DataType::Signed64
192 };
193
194 std::array<DataType,2> supportedOutputTypes = {
195 DataType::Signed32,
196 DataType::Signed64
Nikhil Raj68c2c902019-09-19 11:21:11 +0100197 };
198
199 bool supported = true;
200
Mike Kelly1f140f72021-04-06 12:25:55 +0100201 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100202 "Reference ArgMinMax: input is not a supported type.");
Mike Kelly1f140f72021-04-06 12:25:55 +0100203 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Nikhil Raj68c2c902019-09-19 11:21:11 +0100204 "Reference ArgMinMax: output type not supported");
205
206 return supported;
207}
208
arovir011c7c81b2018-10-08 11:34:28 +0100209bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
210 const TensorInfo& output,
211 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100212 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100213 const TensorInfo& beta,
214 const TensorInfo& gamma,
215 const BatchNormalizationDescriptor& descriptor,
216 Optional<std::string&> reasonIfUnsupported) const
217{
Jan Eilers8eb25602020-03-09 12:13:48 +0000218 IgnoreUnused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100219
Sadik Armagan303980c2020-04-17 12:45:14 +0100220 std::array<DataType, 6> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100221 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000222 DataType::BFloat16,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100223 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100224 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100225 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000226 DataType::QAsymmU8,
227 DataType::QSymmS16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100228 };
229
230 bool supported = true;
231
232 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
233 "Reference batch normalization: input is not a supported type.");
234
235 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
236 "Reference batch normalization: output is not a supported type.");
237
238 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
239 "Reference batch normalization: input and output types are mismatched");
240
241 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
242 "Reference batch normalization: mean is not a supported type.");
243
244 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
245 "Reference batch normalization: variance is not a supported type.");
246
247 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
248 "Reference batch normalization: beta is not a supported type.");
249
250 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
251 "Reference batch normalization: gamma is not a supported type.");
252
253 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100254}
255
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000256bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
257 const TensorInfo& output,
258 const BatchToSpaceNdDescriptor& descriptor,
259 Optional<std::string&> reasonIfUnsupported) const
260{
Jan Eilers8eb25602020-03-09 12:13:48 +0000261 IgnoreUnused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100262
263 bool supported = true;
264
265 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
266 std::string inputTensorStr = "input";
267 std::string outputTensorStr = "output";
268
269 // Define supported types.
Sadik Armagan303980c2020-04-17 12:45:14 +0100270 std::array<DataType,6> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100271 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000272 DataType::BFloat16,
273 DataType::Float32,
274 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100275 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000276 DataType::QAsymmU8,
277 DataType::QSymmS16
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100278 };
279
280 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
281 "Reference BatchToSpaceNd: input type not supported.");
282
283 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
284 "Reference BatchToSpaceNd: output type not supported.");
285
286 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
287 "Reference BatchToSpaceNd: input and output types mismatched.");
288
289 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
290 reasonIfUnsupported,
291 CreateIncorrectDimensionsErrorMsg(4,
292 output.GetNumDimensions(),
293 batchToSpaceNdLayerStr,
294 outputTensorStr).data());
295
296 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
297 reasonIfUnsupported,
298 CreateIncorrectDimensionsErrorMsg(4,
299 input.GetNumDimensions(),
300 batchToSpaceNdLayerStr,
301 inputTensorStr).data());
302
303 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000304}
305
mathad01b392e982021-04-07 12:07:30 +0100306bool RefLayerSupport::IsCastSupported(const TensorInfo& input,
307 const TensorInfo& output,
308 Optional<std::string&> reasonIfUnsupported) const
309{
310 std::array<DataType, 9> supportedInputTypes =
311 {
312 DataType::BFloat16,
313 DataType::Float32,
314 DataType::Float16,
315 DataType::QSymmS8,
316 DataType::QAsymmS8,
317 DataType::QAsymmU8,
318 DataType::QSymmS16,
319 DataType::Signed32
320 };
321
322 bool supported = true;
323 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
324 "Reference cast: input is not a supported type");
325
326
327 supported &= CheckSupportRule(TypeAnyOf(output, supportedInputTypes), reasonIfUnsupported,
328 "Reference cast: output is not a supported type");
329
330 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
331 "Reference cast: input and output shapes have different number of total elements");
332
333 return supported;
334}
335
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100336bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
337 const TensorInfo& input1,
338 const TensorInfo& output,
339 const ComparisonDescriptor& descriptor,
340 Optional<std::string&> reasonIfUnsupported) const
341{
Jan Eilers8eb25602020-03-09 12:13:48 +0000342 IgnoreUnused(descriptor);
Sadik Armagan303980c2020-04-17 12:45:14 +0100343 std::array<DataType, 8> supportedInputTypes =
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100344 {
Sadik Armaganb60dd242020-03-19 13:53:16 +0000345 DataType::Boolean,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000346 DataType::BFloat16,
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100347 DataType::Float32,
348 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100349 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000350 DataType::QAsymmU8,
Sadik Armaganb60dd242020-03-19 13:53:16 +0000351 DataType::QSymmS16,
352 DataType::Signed32
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100353 };
354
355 bool supported = true;
356 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
357 "Reference comparison: input 0 is not a supported type");
358
359 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
360 "Reference comparison: input 0 and Input 1 types are mismatched");
361
362 supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
363 "Reference comparison: output is not of type Boolean");
364
365 return supported;
366}
367
Jim Flynn906f9462019-05-10 13:55:21 +0100368bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
369 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100370 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100371 Optional<std::string&> reasonIfUnsupported) const
372{
Jan Eilers8eb25602020-03-09 12:13:48 +0000373 IgnoreUnused(descriptor);
Jim Flynne242f2d2019-05-22 14:24:13 +0100374
375 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000376 std::array<DataType,6> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100377 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000378 DataType::BFloat16,
379 DataType::Float32,
380 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000381 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100382 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000383 DataType::QSymmS16
Jim Flynne242f2d2019-05-22 14:24:13 +0100384 };
385
386 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
387 "Reference concatenation: output type not supported");
388 for (const TensorInfo* input : inputs)
389 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100390 ARMNN_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100391 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
392 "Reference concatenation: input type not supported");
393
394 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
395 "Reference concatenation: input and output types mismatched.");
396 }
397
398 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100399}
400
arovir011c7c81b2018-10-08 11:34:28 +0100401bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
402 Optional<std::string&> reasonIfUnsupported) const
403{
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100404 std::array<DataType,8> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100405 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000406 DataType::BFloat16,
Teresa Charlin6fa8ce62020-05-25 16:16:44 +0100407 DataType::Float16,
Nina Drozd58ef2c62019-05-16 12:09:18 +0100408 DataType::Float32,
Keith Davis67e6c542020-02-19 10:08:33 +0000409 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100410 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000411 DataType::QSymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100412 DataType::QSymmS16,
413 DataType::Signed32
Nina Drozd58ef2c62019-05-16 12:09:18 +0100414 };
415
416 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
417 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100418}
419
Narumol Prangnawarat7ddbbae2020-03-13 10:26:05 +0000420bool RefLayerSupport::IsConvertBf16ToFp32Supported(const TensorInfo& input,
421 const TensorInfo& output,
422 Optional<std::string&> reasonIfUnsupported) const
423{
424 bool supported = true;
425
426 supported &= CheckSupportRule(TypeIs(input, DataType::BFloat16), reasonIfUnsupported,
427 "Reference for ConvertBf16ToFp32 layer: input type not supported");
428
429 supported &= CheckSupportRule(TypeIs(output, DataType::Float32), reasonIfUnsupported,
430 "Reference for ConvertBf16ToFp32 layer: output type not supported");
431
432 return supported;
433}
434
arovir011c7c81b2018-10-08 11:34:28 +0100435bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
436 const TensorInfo& output,
437 Optional<std::string&> reasonIfUnsupported) const
438{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100439 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
440 input.GetDataType(),
441 &TrueFunc<>,
442 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000443 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000444 &FalseFuncI32<>,
445 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100446 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
447 output.GetDataType(),
448 &FalseOutputFuncF16<>,
449 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000450 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000451 &FalseFuncI32<>,
452 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100453}
454
Narumol Prangnawaratea54a012020-03-16 16:36:10 +0000455bool RefLayerSupport::IsConvertFp32ToBf16Supported(const TensorInfo& input,
456 const TensorInfo& output,
457 Optional<std::string&> reasonIfUnsupported) const
458{
459 bool supported = true;
460
461 supported &= CheckSupportRule(TypeIs(input, DataType::Float32), reasonIfUnsupported,
462 "Reference for ConvertFp32ToBf16 layer: input type not supported");
463
464 supported &= CheckSupportRule(TypeIs(output, DataType::BFloat16), reasonIfUnsupported,
465 "Reference for ConvertFp32ToBf16 layer: output type not supported");
466
467 return supported;
468}
469
arovir011c7c81b2018-10-08 11:34:28 +0100470bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
471 const TensorInfo& output,
472 Optional<std::string&> reasonIfUnsupported) const
473{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100474 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
475 input.GetDataType(),
476 &FalseInputFuncF16<>,
477 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000478 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000479 &FalseFuncI32<>,
480 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100481 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
482 output.GetDataType(),
483 &TrueFunc<>,
484 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000485 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000486 &FalseFuncI32<>,
487 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100488}
489
490bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
491 const TensorInfo& output,
492 const Convolution2dDescriptor& descriptor,
493 const TensorInfo& weights,
494 const Optional<TensorInfo>& biases,
495 Optional<std::string&> reasonIfUnsupported) const
496{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100497 bool supported = true;
498
499 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000500 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000501 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000502 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000503 DataType::Float32,
504 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000505 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100506 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000507 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000508 DataType::QSymmS16
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100509 };
510
511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000512 "Reference Convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100513
514 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000515 "Reference Convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100516
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000517 // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
518 if (input.GetDataType() == DataType::BFloat16)
519 {
520 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
521 {
522 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
523 supported = false;
524 }
525 }
526 else
527 {
528 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000529 "Reference Convolution2d: input and output types mismatched.");
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000530 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100531
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000532 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000533 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000534 {
Derek Lambertid466a542020-01-22 15:37:29 +0000535 ARMNN_NO_DEPRECATE_WARN_BEGIN
Keith Davis0c2eeac2020-02-11 16:51:50 +0000536 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000537 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100538 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000539 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +0000540 DataType::QSymmS8,
541 DataType::QuantizedSymm8PerAxis // deprecated
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000542 };
Derek Lambertid466a542020-01-22 15:37:29 +0000543 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000544
545 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000546 "Reference Convolution2d: weights type not supported for quantized input.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000547 }
548 else
549 {
550 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000551 "Reference Convolution2d: weights is not a supported type.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000552
553 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000554 "Reference Convolution2d: input and weights types mismatched.");
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000555 }
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100556
557 if (biases.has_value())
558 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000559 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000560 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000561 DataType::BFloat16,
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000562 DataType::Float32,
563 DataType::Float16,
564 DataType::Signed32
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100565 };
Aron Virginas-Tar5edc8812019-11-05 18:00:21 +0000566
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100567 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000568 "Reference Convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100569 }
Jan Eilers8eb25602020-03-09 12:13:48 +0000570 IgnoreUnused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100571
572 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100573}
574
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000575bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
576 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000577 Optional<std::string&> reasonIfUnsupported) const
578{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100579 bool supported = true;
580
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000581 std::array<DataType, 8> supportedTypes =
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100582 {
Narumol Prangnawarat403a1852020-03-12 14:24:13 +0000583 DataType::BFloat16,
Aron Virginas-Tardb1a2832019-11-12 16:15:11 +0000584 DataType::Float16,
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100585 DataType::Float32,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000586 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100587 DataType::QAsymmU8,
Keith Davis5204aa82020-01-27 15:24:59 +0000588 DataType::QSymmS8,
Narumol Prangnawaratd2d917d2020-01-09 10:16:39 +0000589 DataType::QSymmS16,
590 DataType::Signed32
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100591 };
592
593 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000594 "Reference for Debug layer: input type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100595
596 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000597 "Reference for Debug layer: output type not supported");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100598
599 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000600 "Reference for Debug layer: input and output types are mismatched");
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100601
602 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000603}
604
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100605bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
606 const TensorInfo& output,
607 const DepthToSpaceDescriptor& descriptor,
608 Optional<std::string&> reasonIfUnsupported) const
609{
Jan Eilers8eb25602020-03-09 12:13:48 +0000610 IgnoreUnused(descriptor);
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100611 bool supported = true;
612
Sadik Armagan303980c2020-04-17 12:45:14 +0100613 std::array<DataType,6> supportedTypes =
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100614 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000615 DataType::BFloat16,
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100616 DataType::Float32,
617 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100618 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000619 DataType::QAsymmU8,
620 DataType::QSymmS16
Aron Virginas-Tar73f66422019-09-23 19:11:59 +0100621 };
622
623 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
624 "Reference DepthToSpace: input type not supported");
625
626 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
627 "Reference DepthToSpace: output type not supported");
628
629 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
630 "Reference DepthToSpace: input and output types are mismatched");
631
632 return supported;
633}
634
arovir011c7c81b2018-10-08 11:34:28 +0100635bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
636 const TensorInfo& output,
637 const DepthwiseConvolution2dDescriptor& descriptor,
638 const TensorInfo& weights,
639 const Optional<TensorInfo>& biases,
640 Optional<std::string&> reasonIfUnsupported) const
641{
Sadik Armagan303980c2020-04-17 12:45:14 +0100642 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100643 bool supported = true;
644
645 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000646 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100647 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000648 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100649 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100650 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +0000651 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000652 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100653 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000654 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100655 };
656
657 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
658 "Reference DepthwiseConvolution2d: input is not a supported type.");
659
660 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
661 "Reference DepthwiseConvolution2d: output is not a supported type.");
662
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100663 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
664 "Reference DepthwiseConvolution2d: input and output types mismatched.");
665
Teresa Charlind8df0262019-11-11 12:28:15 +0000666 const DataType inputType = input.GetDataType();
Keith Davis0c2eeac2020-02-11 16:51:50 +0000667 if (IsQuantized8BitType(inputType))
Teresa Charlind8df0262019-11-11 12:28:15 +0000668 {
Sadik Armagan303980c2020-04-17 12:45:14 +0100669 ARMNN_NO_DEPRECATE_WARN_BEGIN
670 std::array<DataType, 4> supportedWeightTypes =
671 {
672 DataType::QAsymmS8,
673 DataType::QAsymmU8,
674 DataType::QSymmS8,
675 DataType::QuantizedSymm8PerAxis // deprecated
676 };
677 ARMNN_NO_DEPRECATE_WARN_END
Teresa Charlind8df0262019-11-11 12:28:15 +0000678
679 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
Sadik Armagan303980c2020-04-17 12:45:14 +0100680 "Reference DepthwiseConvolution2d: weights type not supported for "
681 "quantized input.");
Teresa Charlind8df0262019-11-11 12:28:15 +0000682 }
683 else
684 {
685 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
686 "Reference DepthwiseConvolution2d: weights is not a supported type.");
687
688 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
689 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
690 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100691
692 if (biases.has_value())
693 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000694 std::array<DataType,4> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100695 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000696 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100697 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100698 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100699 DataType::Signed32
700 };
701 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
702 "Reference DepthwiseConvolution2d: biases is not a supported type.");
703 }
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100704
705 return supported;
706
arovir011c7c81b2018-10-08 11:34:28 +0100707}
708
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000709bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
710 const TensorInfo& output,
711 Optional<std::string&> reasonIfUnsupported) const
712{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100713 bool supported = true;
714
Ryan OShea9add1202020-02-07 10:06:33 +0000715 std::array<DataType,4> supportedInputTypes = {
716 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000717 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +0000718 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000719 DataType::QSymmS16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100720 };
721
722 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000723 "Reference for Dequantize layer: input type not supported.");
724
Derek Lambertid466a542020-01-22 15:37:29 +0000725 supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
Teresa Charlin1b1950d2021-06-02 20:23:21 +0100726 "Reference for Dequantize layer: per-axis quantized input not supported.");
Derek Lambertid466a542020-01-22 15:37:29 +0000727
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000728 std::array<DataType,3> supportedOutputTypes = {
729 DataType::BFloat16,
Jan Eilersf7107932019-11-01 11:09:36 +0000730 DataType::Float32,
731 DataType::Float16
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100732 };
733
734 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000735 "Reference for Dequantize layer: output type not supported.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100736
737 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
Keith Davis5204aa82020-01-27 15:24:59 +0000738 "Reference for Dequantize layer: input/output shapes have different num total "
739 "elements.");
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100740
741 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000742}
743
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000744bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
745 const TensorInfo& scores,
746 const TensorInfo& anchors,
747 const TensorInfo& detectionBoxes,
748 const TensorInfo& detectionClasses,
749 const TensorInfo& detectionScores,
750 const TensorInfo& numDetections,
751 const DetectionPostProcessDescriptor& descriptor,
752 Optional<std::string&> reasonIfUnsupported) const
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000753{
Jan Eilers8eb25602020-03-09 12:13:48 +0000754 IgnoreUnused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);
Derek Lamberti901ea112019-12-10 22:07:09 +0000755
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100756 bool supported = true;
757
Sadik Armaganaa41d5d2020-11-16 14:27:52 +0000758 std::array<DataType,6> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100759 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000760 DataType::BFloat16,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100761 DataType::Float32,
Sadik Armaganaa41d5d2020-11-16 14:27:52 +0000762 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100763 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000764 DataType::QAsymmU8,
765 DataType::QSymmS16
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100766 };
767
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000768 supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100769 "Reference DetectionPostProcess: input 0 is not a supported type.");
770
Derek Lamberti6a5e5e82019-12-05 14:41:20 +0000771 supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100772 "Reference DetectionPostProcess: input 1 is not a supported type.");
773
774 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000775}
776
Pablo Tellof0bd6832019-04-26 17:58:13 +0100777bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
778 const TensorInfo& output,
779 const DepthwiseConvolution2dDescriptor& descriptor,
780 const TensorInfo& weights,
781 const Optional<TensorInfo>& biases,
782 Optional<std::string&> reasonIfUnsupported) const
783{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100784 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100785}
786
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100787bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100788 const TensorInfo& input1,
789 const TensorInfo& output,
790 Optional<std::string&> reasonIfUnsupported) const
791{
Sadik Armagan2999a022019-04-09 14:20:12 +0100792 bool supported = true;
793
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100794 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000795 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100796 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100797 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100798 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000799 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +0100800 DataType::QSymmS16,
801 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +0100802 };
803
804 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
805 "Reference division: input 0 is not a supported type.");
806
807 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
808 "Reference division: input 1 is not a supported type.");
809
810 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
811 "Reference division: output is not a supported type.");
812
813 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
814 "Reference division: input 0 and Input 1 types are mismatched");
815
816 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
817 "Reference division: input and output types are mismatched");
818
819 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
820 "Reference division: shapes are not suitable for implicit broadcast.");
821
822 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100823}
824
josh minor4a3c6102020-01-06 16:40:46 -0600825bool RefLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
826 const TensorInfo& output,
827 const ElementwiseUnaryDescriptor& descriptor,
828 Optional<std::string&> reasonIfUnsupported) const
829{
Jan Eilers8eb25602020-03-09 12:13:48 +0000830 IgnoreUnused(descriptor);
josh minor4a3c6102020-01-06 16:40:46 -0600831
Sadik Armagan303980c2020-04-17 12:45:14 +0100832 std::array<DataType, 7> supportedTypes =
josh minor4a3c6102020-01-06 16:40:46 -0600833 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000834 DataType::BFloat16,
josh minor4a3c6102020-01-06 16:40:46 -0600835 DataType::Float32,
836 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +0100837 DataType::QAsymmS8,
josh minor4a3c6102020-01-06 16:40:46 -0600838 DataType::QAsymmU8,
Sadik Armaganac472102020-03-24 09:54:36 +0000839 DataType::QSymmS16,
840 DataType::Signed32
josh minor4a3c6102020-01-06 16:40:46 -0600841 };
842
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +0000843 std::array<DataType, 1> logicalSupportedTypes =
844 {
845 DataType::Boolean
846 };
847
josh minor4a3c6102020-01-06 16:40:46 -0600848 bool supported = true;
849
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +0000850 if (descriptor.m_Operation == UnaryOperation::LogicalNot)
851 {
852 supported &= CheckSupportRule(TypeAnyOf(input, logicalSupportedTypes), reasonIfUnsupported,
853 "Reference elementwise unary: input type not supported");
josh minor4a3c6102020-01-06 16:40:46 -0600854
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +0000855 supported &= CheckSupportRule(TypeAnyOf(output, logicalSupportedTypes), reasonIfUnsupported,
856 "Reference elementwise unary: output type not supported");
857 }
858 else
859 {
860 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
861 "Reference elementwise unary: input type not supported");
862
863 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
864 "Reference elementwise unary: output type not supported");
865 }
josh minor4a3c6102020-01-06 16:40:46 -0600866
867 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
868 "Reference elementwise unary: input and output types not matching");
869
870 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
871 "Reference elementwise unary: input and output shapes"
872 "have different number of total elements");
873
874 return supported;
875}
876
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000877bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
878 const TensorInfo& input1,
879 const TensorInfo& output,
880 Optional<std::string&> reasonIfUnsupported) const
881{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +0100882 return IsComparisonSupported(input0,
883 input1,
884 output,
885 ComparisonDescriptor(ComparisonOperation::Equal),
886 reasonIfUnsupported);
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000887}
888
arovir011c7c81b2018-10-08 11:34:28 +0100889bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
890 const FakeQuantizationDescriptor& descriptor,
891 Optional<std::string&> reasonIfUnsupported) const
892{
Jan Eilers8eb25602020-03-09 12:13:48 +0000893 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100894 bool supported = true;
895
896 std::array<DataType,1> supportedTypes =
897 {
898 DataType::Float32
899 };
900
901 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
902 "Reference fake quantization: input type not supported.");
903
904 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100905}
906
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100907bool RefLayerSupport::IsFillSupported(const TensorInfo& input,
908 const TensorInfo& output,
909 const FillDescriptor& descriptor,
910 Optional<std::string&> reasonIfUnsupported) const
911{
912 IgnoreUnused(descriptor);
913 IgnoreUnused(output);
914
915 bool supported = true;
916
Sadik Armagana792a052020-06-23 16:22:23 +0100917 std::array<DataType,3> supportedTypes =
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100918 {
919 DataType::Float32,
Sadik Armagana792a052020-06-23 16:22:23 +0100920 DataType::Float16,
921 DataType::Signed32
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100922 };
923
Teresa Charlin4b10fef2020-07-29 09:36:41 +0100924 supported &= CheckSupportRule(TypeIs(input, DataType::Signed32), reasonIfUnsupported,
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100925 "Reference Fill: input type not supported.");
926
Teresa Charlin44088502020-07-27 11:27:19 +0100927 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
928 "Reference Fill: output type not supported.");
Ryan OSheaf4bfa6a2020-06-10 11:33:37 +0100929 return supported;
930}
931
arovir011c7c81b2018-10-08 11:34:28 +0100932bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
933 const TensorInfo& output,
934 Optional<std::string&> reasonIfUnsupported) const
935{
Jan Eilers8eb25602020-03-09 12:13:48 +0000936 IgnoreUnused(output);
James Conroy83735b12019-05-30 16:36:59 +0100937 bool supported = true;
938
Francis Murtaghe8ac1332020-07-30 18:03:40 +0100939 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100940 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000941 DataType::BFloat16,
James Conroyb40d7102019-06-04 12:32:09 +0100942 DataType::Float32,
Francis Murtaghe8ac1332020-07-30 18:03:40 +0100943 DataType::Float16
James Conroy83735b12019-05-30 16:36:59 +0100944 };
945
946 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
947 "Reference Floor: input type not supported.");
948
949 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
950 "Reference Floor: output type not supported.");
951
952 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100953}
954
955bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
956 const TensorInfo& output,
957 const TensorInfo& weights,
958 const TensorInfo& biases,
959 const FullyConnectedDescriptor& descriptor,
960 Optional<std::string&> reasonIfUnsupported) const
961{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100962 bool supported = true;
963
964 // Define supported types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000965 std::array<DataType,6> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100966 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000967 DataType::BFloat16,
968 DataType::Float32,
969 DataType::Float16,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000970 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +0100971 DataType::QAsymmU8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +0000972 DataType::QSymmS16
Francis Murtagh46c09d02019-05-28 08:15:28 +0100973 };
974
975 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
976 "Reference Fully Connected: input type not supported.");
977
978 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
979 "Reference Fully Connected: output type not supported.");
980
Francis Murtagh46c09d02019-05-28 08:15:28 +0100981 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
982 "Reference Fully Connected: weights type not supported.");
983
Narumol Prangnawarat57ef0082020-03-26 09:20:43 +0000984 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
985 if (input.GetDataType() == DataType::BFloat16)
986 {
987 if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
988 {
989 reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
990 supported = false;
991 }
992 }
993 else
994 {
995 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
996 "Reference Fully Connected: input and output types mismatched.");
997 }
998
Jan Eilers1f45dc32020-06-15 11:43:03 +0100999 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1000 "Reference Fully Connected: weights is not a supported type.");
Francis Murtaghddb1d062020-03-10 13:51:45 +00001001
Jan Eilers1f45dc32020-06-15 11:43:03 +01001002 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1003 "Reference Fully Connected: input and weights types mismatched.");
Francis Murtagh46c09d02019-05-28 08:15:28 +01001004
1005 if (descriptor.m_BiasEnabled)
1006 {
1007 // Defined supported types for bias
Sadik Armagandb73c982020-04-01 17:35:30 +01001008 std::array<DataType, 5>
Francis Murtagh46c09d02019-05-28 08:15:28 +01001009 supportedBiasTypes =
1010 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001011 DataType::BFloat16,
Francis Murtagh46c09d02019-05-28 08:15:28 +01001012 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001013 DataType::Float16,
Sadik Armagandb73c982020-04-01 17:35:30 +01001014 DataType::Signed32,
1015 DataType::QAsymmS8
Francis Murtagh46c09d02019-05-28 08:15:28 +01001016 };
1017
1018 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
1019 "Reference Fully Connected: bias type not supported.");
1020
1021 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
1022 "Reference Fully Connected: bias and weight types mismatch.");
1023
1024 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
1025 "Reference Fully Connected: bias type inferred from weights is incompatible.");
1026
Narumol Prangnawarat366d7232020-04-29 12:58:17 +01001027 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(biases, 1U), reasonIfUnsupported,
1028 "Reference Fully Connected: bias must have 1 dimension.");
1029
Francis Murtagh46c09d02019-05-28 08:15:28 +01001030 }
1031
1032 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001033}
1034
narpra014951d842019-01-18 16:53:53 +00001035bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
1036 const armnn::TensorInfo& input1,
1037 const armnn::TensorInfo& output,
Teresa Charlin52664732020-06-29 16:27:03 +01001038 const GatherDescriptor& descriptor,
narpra014951d842019-01-18 16:53:53 +00001039 armnn::Optional<std::string&> reasonIfUnsupported) const
1040{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001041 bool supported = true;
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001042 std::array<DataType,7> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001043 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001044 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001045 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001046 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001047 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001048 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01001049 DataType::QSymmS16,
1050 DataType::Signed32
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001051 };
1052
Teresa Charlin52664732020-06-29 16:27:03 +01001053 if (descriptor.m_Axis != 0)
1054 {
1055 reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
1056 supported &= false;
1057 }
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +01001058 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1059 "Reference Gather: input type not supported");
1060
1061 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1062 "Reference Gather: output type not supported");
1063
1064 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
1065 "Reference Gather: indices (input1) type not supported");
1066
1067 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1068 "Reference Gather: input and output types not matching");
1069
1070 return supported;
narpra014951d842019-01-18 16:53:53 +00001071}
1072
FrancisMurtagh878f0232018-12-19 10:56:15 +00001073bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
1074 const TensorInfo& input1,
1075 const TensorInfo& output,
1076 Optional<std::string&> reasonIfUnsupported) const
1077{
Aron Virginas-Tar77bfb5e2019-10-16 17:45:38 +01001078 return IsComparisonSupported(input0,
1079 input1,
1080 output,
1081 ComparisonDescriptor(ComparisonOperation::Greater),
1082 reasonIfUnsupported);
FrancisMurtagh878f0232018-12-19 10:56:15 +00001083}
1084
Derek Lamberti901ea112019-12-10 22:07:09 +00001085bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
1086 Optional<std::string&> /*reasonIfUnsupported*/) const
arovir011c7c81b2018-10-08 11:34:28 +01001087{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001088 return true;
arovir011c7c81b2018-10-08 11:34:28 +01001089}
1090
Kevin May09ca49c2019-10-09 12:37:34 +01001091bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1092 const TensorInfo& output,
1093 const InstanceNormalizationDescriptor& descriptor,
1094 Optional<std::string&> reasonIfUnsupported) const
1095{
Jan Eilers8eb25602020-03-09 12:13:48 +00001096 IgnoreUnused(descriptor);
Kevin May09ca49c2019-10-09 12:37:34 +01001097 // Define supported types
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001098 std::array<DataType, 3> supportedTypes =
Kevin May09ca49c2019-10-09 12:37:34 +01001099 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001100 DataType::BFloat16,
Kevin May09ca49c2019-10-09 12:37:34 +01001101 DataType::Float32,
1102 DataType::Float16
1103 };
1104
1105 bool supported = true;
1106
1107 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1108 "Reference Instance Normalization: input type not supported.");
1109
1110 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1111 "Reference Instance Normalization: output type not supported.");
1112
1113 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1114 "Reference Instance Normalization: input and output types mismatched.");
1115
1116 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1117 "Reference Instance Normalization: input and output shapes have different "
1118 "num total elements.");
1119
1120 return supported;
1121}
1122
arovir011c7c81b2018-10-08 11:34:28 +01001123bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1124 const TensorInfo& output,
1125 const L2NormalizationDescriptor& descriptor,
1126 Optional<std::string&> reasonIfUnsupported) const
1127{
Jan Eilers8eb25602020-03-09 12:13:48 +00001128 IgnoreUnused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001129 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001130 std::array<DataType, 6> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001131 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001132 DataType::BFloat16,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001133 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001134 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001135 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001136 DataType::QAsymmU8,
1137 DataType::QSymmS16
Ferran Balaguerd73d14f2019-06-10 10:29:54 +01001138 };
1139
1140 bool supported = true;
1141
1142 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1143 "Reference L2normalization: input type not supported.");
1144
1145 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1146 "Reference L2normalization: output type not supported.");
1147
1148 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1149 "Reference L2normalization: input and output types mismatched.");
1150
1151 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1152 "Reference L2normalization: input and output shapes have different "
1153 "num total elements.");
1154
1155 return supported;
arovir011c7c81b2018-10-08 11:34:28 +01001156}
1157
James Conroyaba90cd2020-11-06 16:28:18 +00001158bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1159 const TensorInfo& input1,
1160 const TensorInfo& output,
1161 const LogicalBinaryDescriptor& descriptor,
1162 Optional<std::string&> reasonIfUnsupported) const
1163{
1164 IgnoreUnused(descriptor);
1165
1166 std::array<DataType, 1> supportedTypes =
1167 {
1168 DataType::Boolean
1169 };
1170
1171 bool supported = true;
1172 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1173 "Reference LogicalBinary: input 0 type not supported");
1174 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1175 "Reference LogicalBinary: input 1 type not supported");
1176
1177 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1178 "Reference LogicalBinary: input and output types do not match");
1179
1180 return supported;
1181}
1182
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001183bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1184 const TensorInfo& output,
1185 const LogSoftmaxDescriptor& descriptor,
1186 Optional<std::string&> reasonIfUnsupported) const
1187{
Jan Eilers8eb25602020-03-09 12:13:48 +00001188 IgnoreUnused(descriptor);
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001189
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001190 std::array<DataType, 3> supportedTypes =
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001191 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001192 DataType::BFloat16,
1193 DataType::Float32,
1194 DataType::Float16
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001195 };
1196
1197 bool supported = true;
1198 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1199 "Reference LogSoftmax: input type not supported");
1200
1201 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1202 "Reference LogSoftmax: output type not supported");
1203
1204 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1205 "Reference LogSoftmax: input and output types do not match");
1206
1207 return supported;
1208}
1209
arovir011c7c81b2018-10-08 11:34:28 +01001210bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
1211 const TensorInfo& outputStateIn,
1212 const TensorInfo& cellStateIn,
1213 const TensorInfo& scratchBuffer,
1214 const TensorInfo& outputStateOut,
1215 const TensorInfo& cellStateOut,
1216 const TensorInfo& output,
1217 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001218 const LstmInputParamsInfo& paramsInfo,
1219 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +01001220{
Jan Eilers8eb25602020-03-09 12:13:48 +00001221 IgnoreUnused(descriptor);
1222 IgnoreUnused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001223
1224 bool supported = true;
1225
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001226 std::array<DataType,3> supportedTypes = {
1227 DataType::BFloat16,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001228 DataType::Float32,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001229 DataType::QSymmS16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001230 };
1231
Jan Eilersd01a83c2019-07-03 18:20:40 +01001232 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001233 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1234 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001235 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
1236 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001237 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
1238 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001239 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
1240 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001241 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
1242 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001243 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
1244 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001245 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1246 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001247 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001248 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001249 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001250 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001251 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001252 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001253 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001254 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001255 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001256 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001257 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001258 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001259 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001260 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001261 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001262 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001263 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001264 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001265 "Reference Lstm: input and OutputGateBias types are mismatched");
1266 if (!descriptor.m_CifgEnabled)
1267 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001268 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001269 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001270 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001271 reasonIfUnsupported,
1272 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001273 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001274 "Reference Lstm: input and InputGateBias types are mismatched");
1275 if (descriptor.m_PeepholeEnabled)
1276 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001277 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001278 reasonIfUnsupported,
1279 "Reference Lstm: input and CellToInputWeights types are mismatched");
1280 }
1281 }
1282 if (descriptor.m_PeepholeEnabled)
1283 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001284 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001285 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001286 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001287 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1288 }
1289 if (descriptor.m_ProjectionEnabled)
1290 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001291 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001292 "Reference Lstm: input and mProjectionWeights types are mismatched");
1293 if (paramsInfo.m_ProjectionBias != nullptr)
1294 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001295 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001296 "Reference Lstm: input and ProjectionBias types are mismatched");
1297 }
1298 }
1299 if (descriptor.m_LayerNormEnabled)
1300 {
1301 if (!descriptor.m_CifgEnabled)
1302 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001303 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001304 reasonIfUnsupported,
1305 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1306 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001307 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001308 reasonIfUnsupported,
1309 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001310 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001311 reasonIfUnsupported,
1312 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001313 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001314 reasonIfUnsupported,
1315 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1316 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001317
1318 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001319}
1320
saoste012df12b32018-11-28 16:57:20 +00001321bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1322 const TensorInfo& input1,
1323 const TensorInfo& output,
1324 Optional<std::string&> reasonIfUnsupported) const
1325{
Sadik Armagan2999a022019-04-09 14:20:12 +01001326 bool supported = true;
1327
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001328 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001329 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001330 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001331 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001332 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001333 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001334 DataType::QSymmS16,
1335 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001336 };
1337
1338 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1339 "Reference maximum: input 0 is not a supported type.");
1340
1341 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1342 "Reference maximum: input 1 is not a supported type.");
1343
1344 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1345 "Reference maximum: output is not a supported type.");
1346
1347 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1348 "Reference maximum: input 0 and Input 1 types are mismatched");
1349
1350 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1351 "Reference maximum: input and output types are mismatched");
1352
1353 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1354 "Reference maximum: shapes are not suitable for implicit broadcast.");
1355
1356 return supported;
saoste012df12b32018-11-28 16:57:20 +00001357}
1358
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001359bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1360 const TensorInfo& output,
1361 const MeanDescriptor& descriptor,
1362 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001363{
James Conroy4d1ff582019-06-10 17:06:39 +01001364 bool supported = true;
1365 std::string meanLayerStr = "Mean";
1366 std::string outputTensorStr = "output";
1367
Sadik Armagan303980c2020-04-17 12:45:14 +01001368 std::array<DataType,6> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001369 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001370 DataType::BFloat16,
James Conroy4d1ff582019-06-10 17:06:39 +01001371 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001372 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001373 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001374 DataType::QAsymmU8,
1375 DataType::QSymmS16
James Conroy4d1ff582019-06-10 17:06:39 +01001376 };
1377
1378 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1379 "Reference Mean: input type not supported.");
1380
1381 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1382 "Reference Mean: input and output types are mismatched");
1383
1384 if (descriptor.m_KeepDims)
1385 {
1386 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1387 reasonIfUnsupported,
1388 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1389 output.GetNumDimensions(),
1390 meanLayerStr, outputTensorStr).data());
1391 }
1392 else if (descriptor.m_Axis.empty())
1393 {
1394 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1395 reasonIfUnsupported,
1396 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1397 meanLayerStr, outputTensorStr).data());
1398 }
1399 else
1400 {
Matthew Sloyan171214c2020-09-09 09:07:37 +01001401 auto outputDim = input.GetNumDimensions() - armnn::numeric_cast<unsigned int>(descriptor.m_Axis.size());
James Conroy4d1ff582019-06-10 17:06:39 +01001402
1403 if (outputDim > 0)
1404 {
1405 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1406 reasonIfUnsupported,
1407 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1408 meanLayerStr, outputTensorStr).data());
1409 }
1410 else
1411 {
1412 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1413 reasonIfUnsupported,
1414 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1415 meanLayerStr, outputTensorStr).data());
1416 }
1417 }
1418
1419 return supported;
narpra0132b90462018-09-13 11:07:48 +01001420}
1421
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001422bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001423 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001424 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001425 Optional<std::string&> reasonIfUnsupported) const
1426{
Jim Flynne242f2d2019-05-22 14:24:13 +01001427 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001428}
1429
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001430bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1431 const TensorInfo &output,
1432 Optional<std::string &> reasonIfUnsupported) const
1433{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001434 bool supported = true;
1435
Sadik Armagan303980c2020-04-17 12:45:14 +01001436 std::array<DataType,7> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001437 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001438 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001439 DataType::Float32,
1440 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001441 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001442 DataType::QAsymmU8,
1443 DataType::QSymmS16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001444 DataType::Boolean
1445 };
1446
1447 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1448 "Reference MemCopy: input type not supported");
1449
1450 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1451 "Reference MemCopy: output type not supported");
1452
1453 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1454 "Reference MemCopy: input and output types are mismatched");
1455
1456 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001457}
1458
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001459bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1460 const TensorInfo& input1,
1461 const TensorInfo& output,
1462 Optional<std::string&> reasonIfUnsupported) const
1463{
Sadik Armagan2999a022019-04-09 14:20:12 +01001464 bool supported = true;
1465
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001466 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001467 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001468 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001469 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001470 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001471 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001472 DataType::QSymmS16,
1473 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001474 };
1475
1476 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1477 "Reference minimum: input 0 is not a supported type.");
1478
1479 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1480 "Reference minimum: input 1 is not a supported type.");
1481
1482 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1483 "Reference minimum: output is not a supported type.");
1484
1485 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1486 "Reference minimum: input 0 and Input 1 types are mismatched");
1487
1488 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1489 "Reference minimum: input and output types are mismatched");
1490
1491 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1492 "Reference minimum: shapes are not suitable for implicit broadcast.");
1493
1494 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001495}
1496
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001497bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1498 const TensorInfo& input1,
1499 const TensorInfo& output,
1500 Optional<std::string&> reasonIfUnsupported) const
1501{
Sadik Armagan2999a022019-04-09 14:20:12 +01001502 bool supported = true;
1503
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001504 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001505 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001506 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001507 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001508 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001509 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01001510 DataType::QSymmS16,
1511 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01001512 };
1513
1514 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1515 "Reference multiplication: input 0 is not a supported type.");
1516
1517 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1518 "Reference multiplication: input 1 is not a supported type.");
1519
1520 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1521 "Reference multiplication: output is not a supported type.");
1522
1523 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1524 "Reference multiplication: input 0 and Input 1 types are mismatched");
1525
1526 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1527 "Reference multiplication: input and output types are mismatched");
1528
1529 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1530 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1531
1532 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001533}
1534
1535bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1536 const TensorInfo& output,
1537 const NormalizationDescriptor& descriptor,
1538 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001539{
Jan Eilers8eb25602020-03-09 12:13:48 +00001540 IgnoreUnused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001541
1542 // Define supported types
Sadik Armagan303980c2020-04-17 12:45:14 +01001543 std::array<DataType, 6> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001544 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001545 DataType::BFloat16,
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001546 DataType::Float16,
1547 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001548 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001549 DataType::QAsymmU8,
1550 DataType::QSymmS16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001551 };
1552
1553 bool supported = true;
1554
1555 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1556 "Reference normalization: input type not supported.");
1557
1558 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1559 "Reference normalization: output type not supported.");
1560
1561 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1562 "Reference normalization: input and output shapes have different "
1563 "num total elements.");
1564
1565 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001566}
1567
Derek Lamberti901ea112019-12-10 22:07:09 +00001568bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1569 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001570{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001571 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001572}
1573
1574bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1575 const TensorInfo& output,
1576 const PadDescriptor& descriptor,
1577 Optional<std::string&> reasonIfUnsupported) const
1578{
Jan Eilers8eb25602020-03-09 12:13:48 +00001579 IgnoreUnused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001580 bool supported = true;
1581
1582 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001583 std::array<DataType,6> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001584 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001585 DataType::BFloat16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001586 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001587 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001588 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001589 DataType::QAsymmU8,
1590 DataType::QSymmS16
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001591 };
1592
1593 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1594 "Reference pad: input is not a supported type.");
1595
1596 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1597 "Reference pad: output is not a supported type.");
1598
1599 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1600 "Reference pad: input and output types are mismatched.");
1601
1602 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001603}
1604
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001605bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1606 const TensorInfo& output,
1607 const PermuteDescriptor& descriptor,
1608 Optional<std::string&> reasonIfUnsupported) const
1609{
Jan Eilers8eb25602020-03-09 12:13:48 +00001610 IgnoreUnused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001611 bool supported = true;
1612
1613 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01001614 std::array<DataType, 6> supportedTypes =
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001615 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001616 DataType::BFloat16,
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001617 DataType::Float32,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00001618 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001619 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001620 DataType::QAsymmU8,
1621 DataType::QSymmS16
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001622 };
1623
1624 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1625 "Reference permute: input is not a supported type.");
1626
1627 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1628 "Reference permute: output is not a supported type.");
1629
1630 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1631 "Reference permute: input and output types are mismatched.");
1632
1633 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001634}
1635
1636bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1637 const TensorInfo& output,
1638 const Pooling2dDescriptor& descriptor,
1639 Optional<std::string&> reasonIfUnsupported) const
1640{
Jan Eilers8eb25602020-03-09 12:13:48 +00001641 IgnoreUnused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001642 bool supported = true;
1643
1644 // Define supported output and inputs types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001645 std::array<DataType,6> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001646 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001647 DataType::BFloat16,
Teresa Charlina3b20472019-06-06 11:12:32 +01001648 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001649 DataType::Float16,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001650 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001651 DataType::QAsymmU8,
1652 DataType::QSymmS16
Teresa Charlina3b20472019-06-06 11:12:32 +01001653 };
1654
1655 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1656 "Reference poolind2d: input is not a supported type.");
1657
1658 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1659 "Reference poolind2d: output is not a supported type.");
1660
1661 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1662 "Reference poolind2d: input and output types are mismatched.");
1663
1664 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001665}
1666
James Conroy4f1f8992020-04-29 20:01:10 +01001667bool RefLayerSupport::IsQLstmSupported(const TensorInfo& input,
1668 const TensorInfo& previousOutputIn,
1669 const TensorInfo& previousCellStateIn,
1670 const TensorInfo& outputStateOut,
1671 const TensorInfo& cellStateOut,
1672 const TensorInfo& output,
1673 const QLstmDescriptor& descriptor,
1674 const LstmInputParamsInfo& paramsInfo,
1675 Optional<std::string&> reasonIfUnsupported) const
1676{
1677 IgnoreUnused(input);
1678 IgnoreUnused(previousOutputIn);
1679 IgnoreUnused(previousCellStateIn);
1680 IgnoreUnused(outputStateOut);
1681 IgnoreUnused(cellStateOut);
1682 IgnoreUnused(output);
1683 IgnoreUnused(descriptor);
1684 IgnoreUnused(paramsInfo);
1685
1686 IgnoreUnused(reasonIfUnsupported);
1687
1688 return true;
1689}
1690
Derek Lamberti5f400d62019-03-25 15:41:58 +00001691bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1692 const TensorInfo& output,
1693 Optional<std::string&> reasonIfUnsupported) const
1694{
1695 bool supported = true;
1696
Finn Williamsfd271062019-12-04 14:27:27 +00001697 // Define supported input types.
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001698 std::array<DataType,7> supportedInputTypes = {
1699 DataType::BFloat16,
Keith Davis5e51cd82020-01-29 16:52:59 +00001700 DataType::Float32,
Keith Davis3d8bc972020-02-04 09:31:47 +00001701 DataType::Float16,
Ryan OShea9add1202020-02-07 10:06:33 +00001702 DataType::QAsymmS8,
Keith Davis5e51cd82020-01-29 16:52:59 +00001703 DataType::QAsymmU8,
1704 DataType::QSymmS8,
1705 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001706 };
1707
1708 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1709 "Reference quantize: input type not supported.");
1710
1711 // Define supported output types.
Ryan OShea9add1202020-02-07 10:06:33 +00001712 std::array<DataType,4> supportedOutputTypes = {
Ryan OShea9add1202020-02-07 10:06:33 +00001713 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001714 DataType::QAsymmU8,
Finn Williamsfd271062019-12-04 14:27:27 +00001715 DataType::QSymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001716 DataType::QSymmS16
Derek Lamberti5f400d62019-03-25 15:41:58 +00001717 };
1718 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1719 "Reference quantize: output type not supported.");
1720
1721 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1722 "Reference quantize: input and output shapes have different num total elements.");
1723
1724 return supported;
1725}
1726
Finn Williams2605b232020-06-10 15:53:46 +01001727bool RefLayerSupport::IsRankSupported(const TensorInfo& input,
1728 const TensorInfo& output,
1729 Optional<std::string&> reasonIfUnsupported) const
1730{
1731 IgnoreUnused(input);
1732 // Define supported output types.
1733 std::array<DataType,1> supportedOutputTypes =
1734 {
1735 DataType::Signed32,
1736 };
1737
1738 return CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1739 "Reference rank: input type not supported.");
1740}
1741
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +00001742bool RefLayerSupport::IsReduceSupported(const TensorInfo& input,
1743 const TensorInfo& output,
1744 const ReduceDescriptor& descriptor,
1745 Optional<std::string&> reasonIfUnsupported) const
1746{
1747 IgnoreUnused(descriptor);
1748 bool supported = true;
1749 std::array<DataType,7> supportedTypes =
1750 {
1751 DataType::BFloat16,
1752 DataType::Float32,
1753 DataType::Float16,
1754 DataType::QAsymmS8,
1755 DataType::QAsymmU8,
1756 DataType::QSymmS16,
1757 DataType::Signed32
1758 };
1759
1760 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1761 "Reference Reduce: input type not supported");
1762
1763 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1764 "Reference Reduce: output type not supported");
1765
1766 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1767 "Reference Reduce: input and output types not matching");
1768
1769 return supported;
1770}
1771
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001772bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Kevin Maya023c402019-12-12 17:28:05 +00001773 const TensorInfo& output,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001774 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001775 Optional<std::string&> reasonIfUnsupported) const
1776{
Jan Eilers8eb25602020-03-09 12:13:48 +00001777 IgnoreUnused(output);
1778 IgnoreUnused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001779 // Define supported output types.
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001780 std::array<DataType,8> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001781 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001782 DataType::BFloat16,
Nina Drozd2f2778f2019-05-27 10:37:05 +01001783 DataType::Float32,
1784 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001785 DataType::Signed32,
Keith Davis0c2eeac2020-02-11 16:51:50 +00001786 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001787 DataType::QAsymmU8,
Narumol Prangnawarat0c95f4c2020-11-18 16:52:07 +00001788 DataType::QSymmS16,
1789 DataType::Boolean
Nina Drozd2f2778f2019-05-27 10:37:05 +01001790 };
Keith Davis0c2eeac2020-02-11 16:51:50 +00001791
Nina Drozd2f2778f2019-05-27 10:37:05 +01001792 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1793 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001794}
1795
1796bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001797 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001798 Optional<std::string&> reasonIfUnsupported) const
1799{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001800 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001801 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001802 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001803 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001804 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001805 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001806 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001807 DataType::QAsymmU8,
1808 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001809 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001810
1811 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1812 "Reference ResizeBilinear: input type not supported");
1813
1814 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1815 "Reference ResizeBilinear: output type not supported");
1816
1817 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1818 "Reference ResizeBilinear: input and output types not matching");
1819
1820 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001821}
1822
Teresa Charlin970f43b2019-07-01 13:51:07 +01001823bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1824 const TensorInfo& output,
1825 const ResizeDescriptor& descriptor,
1826 Optional<std::string&> reasonIfUnsupported) const
1827{
Jan Eilers8eb25602020-03-09 12:13:48 +00001828 IgnoreUnused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001829 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001830 std::array<DataType,6> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001831 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001832 DataType::BFloat16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001833 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001834 DataType::Float16,
Keith Davis67e6c542020-02-19 10:08:33 +00001835 DataType::QAsymmS8,
Sadik Armagan303980c2020-04-17 12:45:14 +01001836 DataType::QAsymmU8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001837 DataType::QSymmS16
Teresa Charlin970f43b2019-07-01 13:51:07 +01001838 };
1839
1840 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1841 "Reference Resize: input type not supported");
1842
1843 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1844 "Reference Resize: output type not supported");
1845
1846 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1847 "Reference Resize: input and output types not matching");
1848
1849 return supported;
1850}
1851
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001852bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1853 const TensorInfo& output,
1854 Optional<std::string&> reasonIfUnsupported) const
1855{
josh minor4a3c6102020-01-06 16:40:46 -06001856 return IsElementwiseUnarySupported(input,
1857 output,
1858 ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt),
1859 reasonIfUnsupported);
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001860}
1861
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001862bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1863 const TensorInfo& output,
1864 const SliceDescriptor& descriptor,
1865 Optional<std::string&> reasonIfUnsupported) const
1866{
Jan Eilers8eb25602020-03-09 12:13:48 +00001867 IgnoreUnused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001868 bool supported = true;
1869
Sadik Armagan303980c2020-04-17 12:45:14 +01001870 std::array<DataType, 5> supportedTypes =
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001871 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001872 DataType::BFloat16,
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001873 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01001874 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001875 DataType::QAsymmU8,
1876 DataType::QSymmS16
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001877 };
1878
1879 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1880 "Reference Slice: input type not supported");
1881
1882 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1883 "Reference Slice: output type not supported");
1884
1885 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1886 "Reference Slice: input and output types are mismatched");
1887
1888 return supported;
1889}
1890
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001891bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1892 const TensorInfo& output,
1893 const SoftmaxDescriptor& descriptor,
1894 Optional<std::string&> reasonIfUnsupported) const
1895{
Jan Eilers8eb25602020-03-09 12:13:48 +00001896 IgnoreUnused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001897 bool supported = true;
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001898 std::array<DataType,7> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001899 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001900 DataType::BFloat16,
1901 DataType::Float32,
1902 DataType::Float16,
1903 DataType::QSymmS8,
1904 DataType::QAsymmS8,
1905 DataType::QAsymmU8,
1906 DataType::QSymmS16
nikraj01248683f2019-05-29 16:46:50 +01001907 };
1908
1909 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001910 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001911
1912 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001913 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001914
1915 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001916 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001917
1918 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001919}
1920
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001921bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1922 const TensorInfo& output,
1923 const SpaceToBatchNdDescriptor& descriptor,
1924 Optional<std::string&> reasonIfUnsupported) const
1925{
Jan Eilers8eb25602020-03-09 12:13:48 +00001926 IgnoreUnused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001927 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001928 std::array<DataType,6> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001929 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001930 DataType::BFloat16,
1931 DataType::Float32,
1932 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001933 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001934 DataType::QAsymmU8,
1935 DataType::QSymmS16
nikraj01120522a2019-05-31 11:33:07 +01001936 };
1937
1938 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1939 "Reference SpaceToBatchNd: input type not supported");
1940
1941 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1942 "Reference SpaceToBatchNd: output type not supported");
1943
1944 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1945 "Reference SpaceToBatchNd: input and output types are mismatched");
1946
1947 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001948}
1949
Keith Davisa57eccb2019-06-14 17:33:22 +01001950bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001951 const TensorInfo& output,
1952 const SpaceToDepthDescriptor& descriptor,
1953 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001954{
1955
Jan Eilers8eb25602020-03-09 12:13:48 +00001956 IgnoreUnused(descriptor);
Keith Davisa57eccb2019-06-14 17:33:22 +01001957 bool supported = true;
1958
Sadik Armagan303980c2020-04-17 12:45:14 +01001959 std::array<DataType,6> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001960 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001961 DataType::BFloat16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001962 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001963 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001964 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001965 DataType::QAsymmU8,
1966 DataType::QSymmS16
Keith Davisa57eccb2019-06-14 17:33:22 +01001967 };
1968
1969 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1970 "Reference SpaceToDepth: input type not supported");
1971
1972 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1973 "Reference SpaceToDepth: output type not supported");
1974
1975 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1976 "Reference SpaceToDepth: input and output types are mismatched");
1977
1978 return supported;
1979}
1980
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001981bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1982 const ViewsDescriptor& descriptor,
1983 Optional<std::string&> reasonIfUnsupported) const
1984{
Jan Eilers8eb25602020-03-09 12:13:48 +00001985 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001986 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01001987 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001988 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00001989 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001990 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001991 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01001992 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00001993 DataType::QAsymmU8,
1994 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001995 };
1996
1997 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1998 "Reference splitter: input type not supported");
1999
2000 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002001}
2002
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002003bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
2004 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
2005 const ViewsDescriptor& descriptor,
2006 Optional<std::string&> reasonIfUnsupported) const
2007{
Jan Eilers8eb25602020-03-09 12:13:48 +00002008 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002009 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002010 std::array<DataType,6> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002011 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002012 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002013 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002014 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002015 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002016 DataType::QAsymmU8,
2017 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002018 };
2019
2020 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2021 "Reference splitter: output type not supported");
Derek Lambertieac4adb2020-08-25 13:05:59 +01002022 for (const TensorInfo& output : outputs)
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002023 {
2024 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2025 "Reference splitter: input type not supported");
2026
2027 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2028 "Reference splitter: input and output types mismatched.");
2029 }
2030
2031 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01002032}
2033
Matthew Jackson81e601c2019-07-11 12:07:09 +01002034bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
2035 const TensorInfo& output,
2036 const StackDescriptor& descriptor,
2037 Optional<std::string&> reasonIfUnsupported) const
2038{
Jan Eilers8eb25602020-03-09 12:13:48 +00002039 IgnoreUnused(descriptor);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002040
2041 bool supported = true;
Sadik Armagan303980c2020-04-17 12:45:14 +01002042 std::array<DataType,6> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01002043 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002044 DataType::BFloat16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01002045 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01002046 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002047 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002048 DataType::QAsymmU8,
2049 DataType::QSymmS16
Matthew Jackson81e601c2019-07-11 12:07:09 +01002050 };
2051
2052 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2053 "Reference stack: output type not supported");
2054 for (const TensorInfo* input : inputs)
2055 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002056 ARMNN_ASSERT(input != nullptr);
Matthew Jackson81e601c2019-07-11 12:07:09 +01002057 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
2058 "Reference stack: input type not supported");
2059
2060 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
2061 "Reference stack: input and output types mismatched.");
2062 }
2063
2064 return supported;
2065}
2066
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002067bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
2068 const TensorInfo& output,
2069 const StridedSliceDescriptor& descriptor,
2070 Optional<std::string&> reasonIfUnsupported) const
2071{
Jan Eilers8eb25602020-03-09 12:13:48 +00002072 IgnoreUnused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002073 bool supported = true;
2074
Sadik Armagan303980c2020-04-17 12:45:14 +01002075 std::array<DataType,5> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002076 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002077 DataType::BFloat16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002078 DataType::Float32,
Sadik Armagan303980c2020-04-17 12:45:14 +01002079 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002080 DataType::QAsymmU8,
2081 DataType::QSymmS16
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01002082 };
2083
2084 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2085 "Reference StridedSlice: input type not supported");
2086
2087 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2088 "Reference StridedSlice: output type not supported");
2089
2090 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2091 "Reference StridedSlice: input and output types are mismatched");
2092
2093 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00002094}
2095
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002096bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
2097 const TensorInfo& input1,
2098 const TensorInfo& output,
2099 Optional<std::string&> reasonIfUnsupported) const
2100{
Sadik Armagan2999a022019-04-09 14:20:12 +01002101 bool supported = true;
2102
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002103 std::array<DataType,7> supportedTypes = {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002104 DataType::BFloat16,
Sadik Armagan2999a022019-04-09 14:20:12 +01002105 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002106 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002107 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002108 DataType::QAsymmU8,
Teresa Charlinecb6b8e2020-05-22 18:08:23 +01002109 DataType::QSymmS16,
2110 DataType::Signed32
Sadik Armagan2999a022019-04-09 14:20:12 +01002111 };
2112
2113 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
2114 "Reference subtraction: input 0 is not a supported type.");
2115
2116 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
2117 "Reference subtraction: input 1 is not a supported type.");
2118
2119 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2120 "Reference subtraction: output is not a supported type.");
2121
2122 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
2123 "Reference subtraction: input 0 and Input 1 types are mismatched");
2124
2125 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
2126 "Reference subtraction: input and output types are mismatched");
2127
2128 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
2129 "Reference subtraction: shapes are not suitable for implicit broadcast.");
2130
2131 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01002132}
2133
Matteo Martincighab9e5252019-06-13 17:27:46 +01002134bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
2135 const TensorInfo& alpha,
2136 const TensorInfo& output,
2137 Optional<std::string&> reasonIfUnsupported) const
2138{
2139 bool supported = true;
2140
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002141 std::array<DataType, 6> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01002142 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002143 DataType::BFloat16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01002144 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01002145 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002146 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002147 DataType::QAsymmU8,
Teresa Charlin3940d8b2020-05-29 16:47:23 +01002148 DataType::QSymmS16
Matteo Martincighab9e5252019-06-13 17:27:46 +01002149 };
2150
2151 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2152 "PReLU: input is not a supported type.");
2153
2154 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
2155 "PReLU: alpha is not a supported type.");
2156
2157 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2158 "PReLU: output is not a supported type.");
2159
2160 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
2161 "PReLU: input, alpha and output types are mismatched");
2162
2163 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
2164 "PReLU: shapes are not suitable for implicit broadcast");
2165
2166 return supported;
2167}
2168
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002169bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
2170 const TensorInfo& output,
2171 const TransposeConvolution2dDescriptor& descriptor,
2172 const TensorInfo& weights,
2173 const Optional<TensorInfo>& biases,
2174 Optional<std::string&> reasonIfUnsupported) const
2175{
Jan Eilers8eb25602020-03-09 12:13:48 +00002176 IgnoreUnused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002177 bool supported = true;
2178
Sadik Armagan303980c2020-04-17 12:45:14 +01002179 std::array<DataType,7> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002180 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002181 DataType::BFloat16,
2182 DataType::Float32,
2183 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002184 DataType::QAsymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002185 DataType::QAsymmU8,
Sadik Armagan303980c2020-04-17 12:45:14 +01002186 DataType::QSymmS8,
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002187 DataType::QSymmS16
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002188 };
2189
2190 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2191 "Reference TransposeConvolution2d: input is not a supported type.");
2192
2193 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2194 "Reference TransposeConvolution2d: output is not a supported type.");
2195
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002196 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2197 "Reference TransposeConvolution2d: input and output types mismatched.");
2198
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002199
2200 const DataType inputType = input.GetDataType();
Sadik Armagan303980c2020-04-17 12:45:14 +01002201 if (IsQuantized8BitType(inputType))
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002202 {
Derek Lambertid466a542020-01-22 15:37:29 +00002203 ARMNN_NO_DEPRECATE_WARN_BEGIN
Sadik Armagan303980c2020-04-17 12:45:14 +01002204 std::array<DataType, 4> supportedWeightTypes =
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002205 {
Sadik Armagan303980c2020-04-17 12:45:14 +01002206 DataType::QAsymmS8,
Derek Lambertif90c56d2020-01-10 17:14:08 +00002207 DataType::QAsymmU8,
Derek Lambertid466a542020-01-22 15:37:29 +00002208 DataType::QSymmS8,
2209 DataType::QuantizedSymm8PerAxis //Deprecated
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002210 };
Derek Lambertid466a542020-01-22 15:37:29 +00002211 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00002212
2213 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
2214 "Reference TransposeConvolution2d: weights type not supported for "
2215 "quantized input.");
2216 }
2217 else
2218 {
2219 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
2220 "Reference TransposeConvolution2d: weights is not a supported type.");
2221
2222 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
2223 "Reference TransposeConvolution2d: input and weights types mismatched.");
2224 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002225
2226 if (biases.has_value())
2227 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002228 std::array<DataType,4> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01002229 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002230 DataType::BFloat16,
2231 DataType::Float32,
2232 DataType::Float16,
2233 DataType::Signed32
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01002234 };
2235 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
2236 "Reference TransposeConvolution2d: biases is not a supported type.");
2237 }
2238
2239 return supported;
2240}
2241
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002242bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
2243 const TensorInfo& output,
2244 const TransposeDescriptor& descriptor,
2245 Optional<std::string&> reasonIfUnsupported) const
2246{
Jan Eilers8eb25602020-03-09 12:13:48 +00002247 IgnoreUnused(descriptor);
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002248 bool supported = true;
2249
2250 // Define supported output and inputs types.
Sadik Armagan303980c2020-04-17 12:45:14 +01002251 std::array<DataType, 6> supportedTypes =
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002252 {
Narumol Prangnawarat44179c32020-03-11 14:51:27 +00002253 DataType::BFloat16,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002254 DataType::Float32,
2255 DataType::Float16,
Sadik Armagan303980c2020-04-17 12:45:14 +01002256 DataType::QAsymmS8,
Mike Kellyc9ea45a2020-02-28 18:11:58 +00002257 DataType::QAsymmU8,
2258 DataType::QSymmS16
2259 };
2260
2261 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
2262 "Reference transpose: input is not a supported type.");
2263
2264 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
2265 "Reference transpose: output is not a supported type.");
2266
2267 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
2268 "Reference transpose: input and output types are mismatched.");
2269
2270 return supported;
2271}
2272
arovir011c7c81b2018-10-08 11:34:28 +01002273} // namespace armnn