blob: 03c8633dce86372f4c7b1766f7e9157d74741b5c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Adapts IsSupportedForDataTypeGeneric for the reference backend: only the
// Float32 and Uint8 handlers are caller-supplied; every other data-type slot
// is hard-wired to FalseFunc (i.e. unsupported). Judging by the argument
// positions used elsewhere in this file, the slot order is
// Float16, Float32, Uint8, Int32, Boolean — confirm against
// IsSupportedForDataTypeGeneric's declaration.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
Derek Lamberti50db4e82019-03-13 14:16:15 +000050
51namespace
52{
53template<typename F>
54bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
55{
56 bool supported = rule();
57 if (!supported && reason)
58 {
59 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
60 }
61 return supported;
62}
63
64struct Rule
65{
66 bool operator()() const
67 {
68 return m_Res;
69 }
70
71 bool m_Res = true;
72};
73
Derek Lamberti2a434a82019-03-20 13:07:57 +000074template<typename T>
75bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000076{
77 return true;
78}
79
80template<typename T, typename... Rest>
81bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
82{
83 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
84
Derek Lamberti2a434a82019-03-20 13:07:57 +000085 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000086}
87
88struct TypesAreEqual : public Rule
89{
90 template<typename ... Ts>
91 TypesAreEqual(const Ts&... ts)
92 {
93 m_Res = AllTypesAreEqualImpl(ts...);
94 }
95};
96
97struct QuantizationParametersAreEqual : public Rule
98{
99 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
100 {
101 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
102 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
103 }
104};
105
106struct TypeAnyOf : public Rule
107{
108 template<typename Container>
109 TypeAnyOf(const TensorInfo& info, const Container& c)
110 {
111 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100112 {
113 return dt == info.GetDataType();
114 });
115 }
116};
117
118struct BiasAndWeightsTypesMatch : public Rule
119{
120 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121 {
122 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123 }
124};
125
126struct BiasAndWeightsTypesCompatible : public Rule
127{
128 template<typename Container>
129 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130 {
131 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132 {
133 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
134 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000135 }
136};
137
138struct ShapesAreSameRank : public Rule
139{
140 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141 {
142 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143 }
144};
145
Derek Lamberti5f400d62019-03-25 15:41:58 +0000146struct ShapesAreSameTotalSize : public Rule
147{
148 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149 {
150 m_Res = info0.GetNumElements() == info1.GetNumElements();
151 }
152};
153
Derek Lamberti50db4e82019-03-13 14:16:15 +0000154struct ShapesAreBroadcastCompatible : public Rule
155{
156 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157 {
158 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160 return sizeIn;
161 }
162
163 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164 {
165 const TensorShape& shape0 = in0.GetShape();
166 const TensorShape& shape1 = in1.GetShape();
167 const TensorShape& outShape = out.GetShape();
168
169 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170 {
171 unsigned int sizeOut = outShape[i];
172 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174
175 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177 }
178 }
179};
180} // namespace
181
182
// Checks whether the reference backend can execute an Activation layer:
// input/output must share a supported type and rank, and the requested
// activation function must be one the reference implementation provides.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
   bool supported = true;

   // Define supported types.
   std::array<DataType,3> supportedTypes = {
       DataType::Float32,
       DataType::QuantisedAsymm8,
       DataType::QuantisedSymm16
   };

   supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                 "Reference activation: input type not supported.");

   supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                 "Reference activation: output type not supported.");

   supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                 "Reference activation: input and output types mismatched.");

   supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                 "Reference activation: input and output shapes are of different rank.");


   // Local rule: true only for the activation functions the reference backend implements.
   struct ActivationFunctionSupported : public Rule
   {
       ActivationFunctionSupported(const ActivationDescriptor& desc)
       {
           switch(desc.m_Function)
           {
               case ActivationFunction::Abs:
               case ActivationFunction::BoundedReLu:
               case ActivationFunction::LeakyReLu:
               case ActivationFunction::Linear:
               case ActivationFunction::ReLu:
               case ActivationFunction::Sigmoid:
               case ActivationFunction::SoftReLu:
               case ActivationFunction::Sqrt:
               case ActivationFunction::Square:
               case ActivationFunction::TanH:
               {
                   m_Res = true;
                   break;
               }
               default:
               {
                   // Any function not listed above is unsupported.
                   m_Res = false;
                   break;
               }
           }
       }
   };

   // Function is supported
   supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                 "Reference activation: function not supported.");

   return supported;
}
245
246bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
247 const TensorInfo& input1,
248 const TensorInfo& output,
249 Optional<std::string&> reasonIfUnsupported) const
250{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000251 bool supported = true;
252
Sadik Armagan2999a022019-04-09 14:20:12 +0100253 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000254 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100255 DataType::QuantisedAsymm8,
256 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000257 };
258
259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
260 "Reference addition: input 0 is not a supported type.");
261
262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
263 "Reference addition: input 1 is not a supported type.");
264
265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266 "Reference addition: output is not a supported type.");
267
268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
269 "Reference addition: input 0 and Input 1 types are mismatched");
270
271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
272 "Reference addition: input and output types are mismatched");
273
274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
275 "Reference addition: shapes are not suitable for implicit broadcast.");
276
277 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
// Checks whether the reference backend supports BatchNormalization: the data
// tensors and all four parameter tensors (mean, variance, beta, gamma) must
// each use a supported type, and input/output types must match.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    // Descriptor fields (e.g. epsilon) are not validated here.
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
323
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000324bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
325 const TensorInfo& output,
326 const BatchToSpaceNdDescriptor& descriptor,
327 Optional<std::string&> reasonIfUnsupported) const
328{
329 ignore_unused(descriptor);
330 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
331 input.GetDataType(),
332 &TrueFunc<>,
333 &TrueFunc<>) &&
334 IsSupportedForDataTypeRef(reasonIfUnsupported,
335 output.GetDataType(),
336 &TrueFunc<>,
337 &TrueFunc<>));
338}
339
Jim Flynn906f9462019-05-10 13:55:21 +0100340bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
341 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100342 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100343 Optional<std::string&> reasonIfUnsupported) const
344{
Jim Flynne242f2d2019-05-22 14:24:13 +0100345 ignore_unused(descriptor);
346
347 bool supported = true;
348 std::array<DataType,3> supportedTypes =
349 {
350 DataType::Float32,
351 DataType::QuantisedAsymm8,
352 DataType::QuantisedSymm16
353 };
354
355 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
356 "Reference concatenation: output type not supported");
357 for (const TensorInfo* input : inputs)
358 {
359 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
360 "Reference concatenation: input type not supported");
361
362 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
363 "Reference concatenation: input and output types mismatched.");
364 }
365
366 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100367}
368
arovir011c7c81b2018-10-08 11:34:28 +0100369bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
370 Optional<std::string&> reasonIfUnsupported) const
371{
Jim Flynne242f2d2019-05-22 14:24:13 +0100372 std::array<DataType,4> supportedTypes =
373 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100374 DataType::Float32,
375 DataType::Signed32,
376 DataType::QuantisedAsymm8,
377 DataType::QuantisedSymm16
378 };
379
380 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100382}
383
// Checks ConvertFp16ToFp32 support: the input must be Float16 and the output
// Float32; the FalseInput/FalseOutput helpers emit a type-specific reason when
// the wrong float width is supplied. The function-pointer slot order appears
// to be Float16, Float32, Uint8, Int32, Boolean — confirm against
// IsSupportedForDataTypeGeneric's declaration.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
403
// Checks ConvertFp32ToFp16 support: mirror image of the Fp16->Fp32 case —
// input must be Float32, output Float16, with type-specific failure reasons
// for the wrong float width. Slot order appears to be Float16, Float32,
// Uint8, Int32, Boolean — confirm against IsSupportedForDataTypeGeneric.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
423
424bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
425 const TensorInfo& output,
426 const Convolution2dDescriptor& descriptor,
427 const TensorInfo& weights,
428 const Optional<TensorInfo>& biases,
429 Optional<std::string&> reasonIfUnsupported) const
430{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100431 bool supported = true;
432
433 // Define supported types.
434 std::array<DataType,3> supportedTypes = {
435 DataType::Float32,
436 DataType::QuantisedAsymm8,
437 DataType::QuantisedSymm16
438 };
439
440 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100441 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100442
443 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100444 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100445
446 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100447 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100448
449 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100450 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100451
452 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100453 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100454
455 if (biases.has_value())
456 {
457 std::array<DataType,3> biasesSupportedTypes = {
458 DataType::Float32,
459 DataType::Signed32
460 };
461 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100462 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100463 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100464 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100465
466 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100467}
468
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000469bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
470 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000471 Optional<std::string&> reasonIfUnsupported) const
472{
473 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000474 return IsSupportedForDataTypeRef(reasonIfUnsupported,
475 input.GetDataType(),
476 &TrueFunc<>,
477 &TrueFunc<>);
478}
479
arovir011c7c81b2018-10-08 11:34:28 +0100480bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
481 const TensorInfo& output,
482 const DepthwiseConvolution2dDescriptor& descriptor,
483 const TensorInfo& weights,
484 const Optional<TensorInfo>& biases,
485 Optional<std::string&> reasonIfUnsupported) const
486{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100487 ignore_unused(output);
488 ignore_unused(descriptor);
489 ignore_unused(weights);
490 ignore_unused(biases);
491 return IsSupportedForDataTypeRef(reasonIfUnsupported,
492 input.GetDataType(),
493 &TrueFunc<>,
494 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100495}
496
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000497bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
498 const TensorInfo& output,
499 Optional<std::string&> reasonIfUnsupported) const
500{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100501 bool supported = true;
502
503 std::array<DataType,2> supportedInputTypes = {
504 DataType::QuantisedAsymm8,
505 DataType::QuantisedSymm16
506 };
507
508 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
509 "Reference dequantize: input type not supported.");
510
511 std::array<DataType,2> supportedOutputTypes = {
512 DataType::Float32,
513 };
514
515 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
516 "Reference dequantize: output type not supported.");
517
518 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
519 "Reference dequantize: input and output shapes have different num total elements.");
520
521 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000522}
523
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000524bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
525 const armnn::TensorInfo& input1,
526 const armnn::DetectionPostProcessDescriptor& descriptor,
527 armnn::Optional<std::string&> reasonIfUnsupported) const
528{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100529 bool supported = true;
530
531 std::vector<DataType> supportedInputTypes =
532 {
533 DataType::Float32,
534 DataType::QuantisedAsymm8,
535 DataType::QuantisedSymm16
536 };
537
538 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
539 "Reference DetectionPostProcess: input 0 is not a supported type.");
540
541 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
542 "Reference DetectionPostProcess: input 1 is not a supported type.");
543
544 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000545}
546
Pablo Tellof0bd6832019-04-26 17:58:13 +0100547bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
548 const TensorInfo& output,
549 const DepthwiseConvolution2dDescriptor& descriptor,
550 const TensorInfo& weights,
551 const Optional<TensorInfo>& biases,
552 Optional<std::string&> reasonIfUnsupported) const
553{
554 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
555 {
556 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
557 }
558 else
559 {
560 if (reasonIfUnsupported)
561 {
562 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
563 }
564 return false;
565 }
566}
567
568
569 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100570 const TensorInfo& input1,
571 const TensorInfo& output,
572 Optional<std::string&> reasonIfUnsupported) const
573{
Sadik Armagan2999a022019-04-09 14:20:12 +0100574 bool supported = true;
575
576 std::array<DataType,3> supportedTypes = {
577 DataType::Float32,
578 DataType::QuantisedAsymm8,
579 DataType::QuantisedSymm16
580 };
581
582 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
583 "Reference division: input 0 is not a supported type.");
584
585 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
586 "Reference division: input 1 is not a supported type.");
587
588 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
589 "Reference division: output is not a supported type.");
590
591 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
592 "Reference division: input 0 and Input 1 types are mismatched");
593
594 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
595 "Reference division: input and output types are mismatched");
596
597 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
598 "Reference division: shapes are not suitable for implicit broadcast.");
599
600 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100601}
602
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000603bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
604 const TensorInfo& input1,
605 const TensorInfo& output,
606 Optional<std::string&> reasonIfUnsupported) const
607{
608 ignore_unused(input0);
609 ignore_unused(input1);
610 ignore_unused(output);
611 ignore_unused(reasonIfUnsupported);
612 return IsSupportedForDataTypeRef(reasonIfUnsupported,
613 input0.GetDataType(),
614 &TrueFunc<>,
615 &TrueFunc<>);
616}
617
arovir011c7c81b2018-10-08 11:34:28 +0100618bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
619 const FakeQuantizationDescriptor& descriptor,
620 Optional<std::string&> reasonIfUnsupported) const
621{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100622 ignore_unused(descriptor);
623 return IsSupportedForDataTypeRef(reasonIfUnsupported,
624 input.GetDataType(),
625 &TrueFunc<>,
626 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100627}
628
629bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
630 const TensorInfo& output,
631 Optional<std::string&> reasonIfUnsupported) const
632{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100633 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100634 bool supported = true;
635
James Conroyb40d7102019-06-04 12:32:09 +0100636 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100637 {
James Conroyb40d7102019-06-04 12:32:09 +0100638 DataType::Float32,
639 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100640 };
641
642 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
643 "Reference Floor: input type not supported.");
644
645 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
646 "Reference Floor: output type not supported.");
647
648 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100649}
650
// Checks whether the reference backend supports FullyConnected: input, output
// and weights must share a supported type, and when the descriptor enables a
// bias, the bias type must both be in the supported set and agree with the
// type inferred from the weights.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    // The bias tensor is only validated when the layer actually uses one.
    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 2>
        supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}
706
narpra014951d842019-01-18 16:53:53 +0000707bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
708 const armnn::TensorInfo& input1,
709 const armnn::TensorInfo& output,
710 armnn::Optional<std::string&> reasonIfUnsupported) const
711{
712 ignore_unused(input1);
713 ignore_unused(output);
714 return IsSupportedForDataTypeRef(reasonIfUnsupported,
715 input0.GetDataType(),
716 &TrueFunc<>,
717 &TrueFunc<>);
718}
719
FrancisMurtagh878f0232018-12-19 10:56:15 +0000720bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
721 const TensorInfo& input1,
722 const TensorInfo& output,
723 Optional<std::string&> reasonIfUnsupported) const
724{
725 ignore_unused(input0);
726 ignore_unused(input1);
727 ignore_unused(output);
728 ignore_unused(reasonIfUnsupported);
729 return IsSupportedForDataTypeRef(reasonIfUnsupported,
730 input0.GetDataType(),
731 &TrueFunc<>,
732 &TrueFunc<>);
733}
734
// Graph inputs are unconditionally supported by the reference backend.
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}
740
// Checks whether the reference backend supports L2Normalization: input and
// output must share a supported type and the same total element count.
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    // Descriptor fields (e.g. data layout) are not validated here.
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}
772
// Checks whether the reference backend supports an LSTM layer. Only the data
// type of the main input and its consistency with the state/scratch/output
// tensors are validated; the weight and bias tensors are not checked here.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // The descriptor and all weight/bias tensors are deliberately ignored:
    // support is decided from the data tensors alone.
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    // Every state, scratch and output tensor must use the same type as the input.
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
849
saoste012df12b32018-11-28 16:57:20 +0000850bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
851 const TensorInfo& input1,
852 const TensorInfo& output,
853 Optional<std::string&> reasonIfUnsupported) const
854{
Sadik Armagan2999a022019-04-09 14:20:12 +0100855 bool supported = true;
856
857 std::array<DataType,3> supportedTypes = {
858 DataType::Float32,
859 DataType::QuantisedAsymm8,
860 DataType::QuantisedSymm16
861 };
862
863 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
864 "Reference maximum: input 0 is not a supported type.");
865
866 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
867 "Reference maximum: input 1 is not a supported type.");
868
869 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
870 "Reference maximum: output is not a supported type.");
871
872 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
873 "Reference maximum: input 0 and Input 1 types are mismatched");
874
875 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
876 "Reference maximum: input and output types are mismatched");
877
878 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
879 "Reference maximum: shapes are not suitable for implicit broadcast.");
880
881 return supported;
saoste012df12b32018-11-28 16:57:20 +0000882}
883
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100884bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
885 const TensorInfo& output,
886 const MeanDescriptor& descriptor,
887 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100888{
narpra011e4c31d2018-09-28 11:07:51 +0100889 ignore_unused(output);
890 ignore_unused(descriptor);
891 return IsSupportedForDataTypeRef(reasonIfUnsupported,
892 input.GetDataType(),
893 &TrueFunc<>,
894 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100895}
896
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Merger is the older name for Concat; delegate to the concat support check
    // so the two queries can never disagree.
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
904
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    // Only the input's data type determines MemCopy support.
    ignore_unused(output);
    // The function pointers are positional, one per data type; the single
    // FalseFuncI32 slot rejects 32-bit integer tensors while every other
    // type is accepted (slot order defined by IsSupportedForDataTypeGeneric).
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
918
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000919bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
920 const TensorInfo& input1,
921 const TensorInfo& output,
922 Optional<std::string&> reasonIfUnsupported) const
923{
Sadik Armagan2999a022019-04-09 14:20:12 +0100924 bool supported = true;
925
926 std::array<DataType,3> supportedTypes = {
927 DataType::Float32,
928 DataType::QuantisedAsymm8,
929 DataType::QuantisedSymm16
930 };
931
932 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
933 "Reference minimum: input 0 is not a supported type.");
934
935 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
936 "Reference minimum: input 1 is not a supported type.");
937
938 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
939 "Reference minimum: output is not a supported type.");
940
941 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
942 "Reference minimum: input 0 and Input 1 types are mismatched");
943
944 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
945 "Reference minimum: input and output types are mismatched");
946
947 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
948 "Reference minimum: shapes are not suitable for implicit broadcast.");
949
950 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000951}
952
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100953bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
954 const TensorInfo& input1,
955 const TensorInfo& output,
956 Optional<std::string&> reasonIfUnsupported) const
957{
Sadik Armagan2999a022019-04-09 14:20:12 +0100958 bool supported = true;
959
960 std::array<DataType,3> supportedTypes = {
961 DataType::Float32,
962 DataType::QuantisedAsymm8,
963 DataType::QuantisedSymm16
964 };
965
966 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
967 "Reference multiplication: input 0 is not a supported type.");
968
969 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
970 "Reference multiplication: input 1 is not a supported type.");
971
972 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
973 "Reference multiplication: output is not a supported type.");
974
975 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
976 "Reference multiplication: input 0 and Input 1 types are mismatched");
977
978 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
979 "Reference multiplication: input and output types are mismatched");
980
981 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
982 "Reference multiplication: shapes are not suitable for implicit broadcast.");
983
984 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100985}
986
987bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
988 const TensorInfo& output,
989 const NormalizationDescriptor& descriptor,
990 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100991{
Nina Drozd661dfa72018-10-02 11:14:17 +0100992 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100993
994 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100995 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100996 {
997 DataType::Float16,
998 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100999 DataType::QuantisedAsymm8,
1000 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001001 };
1002
1003 bool supported = true;
1004
1005 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1006 "Reference normalization: input type not supported.");
1007
1008 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1009 "Reference normalization: output type not supported.");
1010
1011 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1012 "Reference normalization: input and output shapes have different "
1013 "num total elements.");
1014
1015 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001016}
1017
1018bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1019 Optional<std::string&> reasonIfUnsupported) const
1020{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001021 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001022}
1023
1024bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1025 const TensorInfo& output,
1026 const PadDescriptor& descriptor,
1027 Optional<std::string&> reasonIfUnsupported) const
1028{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001029 ignore_unused(output);
1030 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +00001031 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1032 input.GetDataType(),
1033 &TrueFunc<>,
1034 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +01001035}
1036
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001037bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1038 const TensorInfo& output,
1039 const PermuteDescriptor& descriptor,
1040 Optional<std::string&> reasonIfUnsupported) const
1041{
1042 ignore_unused(output);
1043 ignore_unused(descriptor);
1044 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1045 input.GetDataType(),
1046 &TrueFunc<>,
1047 &TrueFunc<>);
1048}
1049
1050bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1051 const TensorInfo& output,
1052 const Pooling2dDescriptor& descriptor,
1053 Optional<std::string&> reasonIfUnsupported) const
1054{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001055 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001056 bool supported = true;
1057
1058 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001059 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001060 {
1061 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001062 DataType::QuantisedAsymm8,
1063 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001064 };
1065
1066 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1067 "Reference poolind2d: input is not a supported type.");
1068
1069 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1070 "Reference poolind2d: output is not a supported type.");
1071
1072 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1073 "Reference poolind2d: input and output types are mismatched.");
1074
1075 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001076}
1077
Derek Lamberti5f400d62019-03-25 15:41:58 +00001078bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1079 const TensorInfo& output,
1080 Optional<std::string&> reasonIfUnsupported) const
1081{
1082 bool supported = true;
1083
1084 // Define supported output types.
1085 std::array<DataType,2> supportedInputTypes = {
1086 DataType::Float32,
1087 };
1088
1089 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1090 "Reference quantize: input type not supported.");
1091
1092 // Define supported output types.
1093 std::array<DataType,2> supportedOutputTypes = {
1094 DataType::QuantisedAsymm8,
1095 DataType::QuantisedSymm16
1096 };
1097 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1098 "Reference quantize: output type not supported.");
1099
1100 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1101 "Reference quantize: input and output shapes have different num total elements.");
1102
1103 return supported;
1104}
1105
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001106bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001107 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001108 Optional<std::string&> reasonIfUnsupported) const
1109{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001110 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001111 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001112 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001113 {
1114 DataType::Float32,
1115 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001116 DataType::QuantisedAsymm8,
1117 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001118 };
1119 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1120 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001121}
1122
1123bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001124 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001125 Optional<std::string&> reasonIfUnsupported) const
1126{
Sadik Armaganc625f002018-12-17 11:32:16 +00001127 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001128 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1129 input.GetDataType(),
1130 &TrueFunc<>,
1131 &TrueFunc<>);
1132}
1133
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001134bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1135 const TensorInfo& output,
1136 Optional<std::string&> reasonIfUnsupported) const
1137{
nikraj010421e7f2019-06-14 09:40:34 +01001138 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001139 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001140 {
1141 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001142 DataType::QuantisedAsymm8,
1143 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001144 };
1145
1146 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1147 "Reference rsqrt: input type not supported");
1148
1149 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1150 "Reference rsqrt: output type not supported");
1151
1152 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1153 "Reference rsqrt: input and output types not matching");
1154
1155 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1156 "Reference Rsqrt: input and output shapes have different number of total elements");
1157
1158 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001159}
1160
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001161bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1162 const TensorInfo& output,
1163 const SoftmaxDescriptor& descriptor,
1164 Optional<std::string&> reasonIfUnsupported) const
1165{
1166 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001167 bool supported = true;
1168 std::array<DataType,3> supportedTypes =
1169 {
1170 DataType::Float32,
1171 DataType::QuantisedAsymm8,
1172 DataType::QuantisedSymm16
1173 };
1174
1175 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1176 "Reference concatenation: output type not supported");
1177
1178 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1179 "Reference concatenation: input type not supported");
1180
1181 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1182 "Reference concatenation: input type not supported");
1183
1184 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001185}
1186
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001187bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1188 const TensorInfo& output,
1189 const SpaceToBatchNdDescriptor& descriptor,
1190 Optional<std::string&> reasonIfUnsupported) const
1191{
1192 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001193 bool supported = true;
1194 std::array<DataType,3> supportedTypes =
1195 {
1196 DataType::Float32,
1197 DataType::QuantisedAsymm8,
1198 DataType::QuantisedSymm16
1199 };
1200
1201 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1202 "Reference SpaceToBatchNd: input type not supported");
1203
1204 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1205 "Reference SpaceToBatchNd: output type not supported");
1206
1207 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1208 "Reference SpaceToBatchNd: input and output types are mismatched");
1209
1210 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001211}
1212
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001213bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1214 const ViewsDescriptor& descriptor,
1215 Optional<std::string&> reasonIfUnsupported) const
1216{
1217 ignore_unused(descriptor);
1218 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1219 input.GetDataType(),
1220 &TrueFunc<>,
1221 &TrueFunc<>);
1222}
1223
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001224bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1225 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1226 const ViewsDescriptor& descriptor,
1227 Optional<std::string&> reasonIfUnsupported) const
1228{
1229 ignore_unused(descriptor);
1230 ignore_unused(outputs);
1231 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1232 input.GetDataType(),
1233 &TrueFunc<>,
1234 &TrueFunc<>);
1235}
1236
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001237bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1238 const TensorInfo& output,
1239 const StridedSliceDescriptor& descriptor,
1240 Optional<std::string&> reasonIfUnsupported) const
1241{
1242 ignore_unused(output);
1243 ignore_unused(descriptor);
1244 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1245 input.GetDataType(),
1246 &TrueFunc<>,
1247 &TrueFunc<>);
1248}
1249
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001250bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1251 const TensorInfo& input1,
1252 const TensorInfo& output,
1253 Optional<std::string&> reasonIfUnsupported) const
1254{
Sadik Armagan2999a022019-04-09 14:20:12 +01001255 bool supported = true;
1256
1257 std::array<DataType,3> supportedTypes = {
1258 DataType::Float32,
1259 DataType::QuantisedAsymm8,
1260 DataType::QuantisedSymm16
1261 };
1262
1263 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1264 "Reference subtraction: input 0 is not a supported type.");
1265
1266 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1267 "Reference subtraction: input 1 is not a supported type.");
1268
1269 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1270 "Reference subtraction: output is not a supported type.");
1271
1272 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1273 "Reference subtraction: input 0 and Input 1 types are mismatched");
1274
1275 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1276 "Reference subtraction: input and output types are mismatched");
1277
1278 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1279 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1280
1281 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001282}
1283
arovir011c7c81b2018-10-08 11:34:28 +01001284} // namespace armnn