blob: 077aa1ce3a40a6321e10a4e6815db9f2d83d2cc9 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010028namespace
29{
30
31template<typename Float32Func, typename Uint8Func, typename ... Params>
32bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
33 DataType dataType,
34 Float32Func floatFuncPtr,
35 Uint8Func uint8FuncPtr,
36 Params&&... params)
37{
38 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
39 dataType,
40 &FalseFunc<Params...>,
41 floatFuncPtr,
42 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000043 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000044 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010045 std::forward<Params>(params)...);
46}
47
48} // anonymous namespace
49
James Conroy4d1ff582019-06-10 17:06:39 +010050namespace
51{
52
// Builds the standard "wrong number of dimensions" error string used by the
// reference layer-support checks.
//
// FIX: layerStr/tensorName were taken by non-const reference although they are
// only read; const references also allow temporaries and literals.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
63
64} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
66namespace
67{
68template<typename F>
69bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
70{
71 bool supported = rule();
72 if (!supported && reason)
73 {
74 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
75 }
76 return supported;
77}
78
// Base class for all layer-support rules: a rule is constructed with its
// verdict already computed into m_Res, and operator() simply reports it.
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    bool m_Res = true; // Verdict; defaults to "supported".
};
88
Derek Lamberti2a434a82019-03-20 13:07:57 +000089template<typename T>
90bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000091{
92 return true;
93}
94
95template<typename T, typename... Rest>
96bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
97{
98 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
99
Derek Lamberti2a434a82019-03-20 13:07:57 +0000100 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +0000101}
102
103struct TypesAreEqual : public Rule
104{
105 template<typename ... Ts>
106 TypesAreEqual(const Ts&... ts)
107 {
108 m_Res = AllTypesAreEqualImpl(ts...);
109 }
110};
111
112struct QuantizationParametersAreEqual : public Rule
113{
114 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
115 {
116 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
117 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
118 }
119};
120
121struct TypeAnyOf : public Rule
122{
123 template<typename Container>
124 TypeAnyOf(const TensorInfo& info, const Container& c)
125 {
126 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100127 {
128 return dt == info.GetDataType();
129 });
130 }
131};
132
133struct BiasAndWeightsTypesMatch : public Rule
134{
135 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
136 {
137 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
138 }
139};
140
141struct BiasAndWeightsTypesCompatible : public Rule
142{
143 template<typename Container>
144 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
145 {
146 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
147 {
148 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
149 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000150 }
151};
152
153struct ShapesAreSameRank : public Rule
154{
155 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
156 {
157 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
158 }
159};
160
Derek Lamberti5f400d62019-03-25 15:41:58 +0000161struct ShapesAreSameTotalSize : public Rule
162{
163 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
164 {
165 m_Res = info0.GetNumElements() == info1.GetNumElements();
166 }
167};
168
Derek Lamberti50db4e82019-03-13 14:16:15 +0000169struct ShapesAreBroadcastCompatible : public Rule
170{
171 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
172 {
173 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
174 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
175 return sizeIn;
176 }
177
178 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
179 {
180 const TensorShape& shape0 = in0.GetShape();
181 const TensorShape& shape1 = in1.GetShape();
182 const TensorShape& outShape = out.GetShape();
183
184 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
185 {
186 unsigned int sizeOut = outShape[i];
187 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
188 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
189
190 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
191 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
192 }
193 }
194};
James Conroy4d1ff582019-06-10 17:06:39 +0100195
196struct TensorNumDimensionsAreCorrect : public Rule
197{
198 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
199 {
200 m_Res = info.GetNumDimensions() == expectedNumDimensions;
201 }
202};
203
Derek Lamberti50db4e82019-03-13 14:16:15 +0000204} // namespace
205
206
arovir011c7c81b2018-10-08 11:34:28 +0100207bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
208 const TensorInfo& output,
209 const ActivationDescriptor& descriptor,
210 Optional<std::string&> reasonIfUnsupported) const
211{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000212 bool supported = true;
213
214 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100215 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000216 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100217 DataType::QuantisedAsymm8,
218 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000219 };
220
221 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
222 "Reference activation: input type not supported.");
223
224 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
225 "Reference activation: output type not supported.");
226
227 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
228 "Reference activation: input and output types mismatched.");
229
230 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
231 "Reference activation: input and output shapes are of different rank.");
232
233
234 struct ActivationFunctionSupported : public Rule
235 {
236 ActivationFunctionSupported(const ActivationDescriptor& desc)
237 {
238 switch(desc.m_Function)
239 {
240 case ActivationFunction::Abs:
241 case ActivationFunction::BoundedReLu:
242 case ActivationFunction::LeakyReLu:
243 case ActivationFunction::Linear:
244 case ActivationFunction::ReLu:
245 case ActivationFunction::Sigmoid:
246 case ActivationFunction::SoftReLu:
247 case ActivationFunction::Sqrt:
248 case ActivationFunction::Square:
249 case ActivationFunction::TanH:
250 {
251 m_Res = true;
252 break;
253 }
254 default:
255 {
256 m_Res = false;
257 break;
258 }
259 }
260 }
261 };
262
263 // Function is supported
264 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
265 "Reference activation: function not supported.");
266
267 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100268}
269
270bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
271 const TensorInfo& input1,
272 const TensorInfo& output,
273 Optional<std::string&> reasonIfUnsupported) const
274{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000275 bool supported = true;
276
Sadik Armagan2999a022019-04-09 14:20:12 +0100277 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000278 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100279 DataType::QuantisedAsymm8,
280 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000281 };
282
283 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
284 "Reference addition: input 0 is not a supported type.");
285
286 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
287 "Reference addition: input 1 is not a supported type.");
288
289 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
290 "Reference addition: output is not a supported type.");
291
292 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
293 "Reference addition: input 0 and Input 1 types are mismatched");
294
295 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
296 "Reference addition: input and output types are mismatched");
297
298 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
299 "Reference addition: shapes are not suitable for implicit broadcast.");
300
301 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100302}
303
304bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
305 const TensorInfo& output,
306 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100307 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100308 const TensorInfo& beta,
309 const TensorInfo& gamma,
310 const BatchNormalizationDescriptor& descriptor,
311 Optional<std::string&> reasonIfUnsupported) const
312{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100313 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100314
Matteo Martincighf5507132019-06-04 10:59:47 +0100315 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100316 {
317 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100318 DataType::QuantisedAsymm8,
319 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100320 };
321
322 bool supported = true;
323
324 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
325 "Reference batch normalization: input is not a supported type.");
326
327 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
328 "Reference batch normalization: output is not a supported type.");
329
330 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
331 "Reference batch normalization: input and output types are mismatched");
332
333 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
334 "Reference batch normalization: mean is not a supported type.");
335
336 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
337 "Reference batch normalization: variance is not a supported type.");
338
339 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
340 "Reference batch normalization: beta is not a supported type.");
341
342 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
343 "Reference batch normalization: gamma is not a supported type.");
344
345 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100346}
347
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000348bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
349 const TensorInfo& output,
350 const BatchToSpaceNdDescriptor& descriptor,
351 Optional<std::string&> reasonIfUnsupported) const
352{
353 ignore_unused(descriptor);
354 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
355 input.GetDataType(),
356 &TrueFunc<>,
357 &TrueFunc<>) &&
358 IsSupportedForDataTypeRef(reasonIfUnsupported,
359 output.GetDataType(),
360 &TrueFunc<>,
361 &TrueFunc<>));
362}
363
Jim Flynn906f9462019-05-10 13:55:21 +0100364bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
365 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100366 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100367 Optional<std::string&> reasonIfUnsupported) const
368{
Jim Flynne242f2d2019-05-22 14:24:13 +0100369 ignore_unused(descriptor);
370
371 bool supported = true;
372 std::array<DataType,3> supportedTypes =
373 {
374 DataType::Float32,
375 DataType::QuantisedAsymm8,
376 DataType::QuantisedSymm16
377 };
378
379 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
380 "Reference concatenation: output type not supported");
381 for (const TensorInfo* input : inputs)
382 {
383 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
384 "Reference concatenation: input type not supported");
385
386 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
387 "Reference concatenation: input and output types mismatched.");
388 }
389
390 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100391}
392
arovir011c7c81b2018-10-08 11:34:28 +0100393bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
394 Optional<std::string&> reasonIfUnsupported) const
395{
Jim Flynne242f2d2019-05-22 14:24:13 +0100396 std::array<DataType,4> supportedTypes =
397 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100398 DataType::Float32,
399 DataType::Signed32,
400 DataType::QuantisedAsymm8,
401 DataType::QuantisedSymm16
402 };
403
404 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
405 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100406}
407
408bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
409 const TensorInfo& output,
410 Optional<std::string&> reasonIfUnsupported) const
411{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100412 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
413 input.GetDataType(),
414 &TrueFunc<>,
415 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000416 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000417 &FalseFuncI32<>,
418 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100419 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
420 output.GetDataType(),
421 &FalseOutputFuncF16<>,
422 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000423 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000424 &FalseFuncI32<>,
425 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100426}
427
428bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
429 const TensorInfo& output,
430 Optional<std::string&> reasonIfUnsupported) const
431{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100432 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
433 input.GetDataType(),
434 &FalseInputFuncF16<>,
435 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000436 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000437 &FalseFuncI32<>,
438 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100439 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
440 output.GetDataType(),
441 &TrueFunc<>,
442 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000443 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000444 &FalseFuncI32<>,
445 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100446}
447
448bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
449 const TensorInfo& output,
450 const Convolution2dDescriptor& descriptor,
451 const TensorInfo& weights,
452 const Optional<TensorInfo>& biases,
453 Optional<std::string&> reasonIfUnsupported) const
454{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100455 bool supported = true;
456
457 // Define supported types.
458 std::array<DataType,3> supportedTypes = {
459 DataType::Float32,
460 DataType::QuantisedAsymm8,
461 DataType::QuantisedSymm16
462 };
463
464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100465 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100466
467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100468 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100469
470 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100471 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100472
473 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100474 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100475
476 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100477 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100478
479 if (biases.has_value())
480 {
481 std::array<DataType,3> biasesSupportedTypes = {
482 DataType::Float32,
483 DataType::Signed32
484 };
485 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100486 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100487 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100488 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100489
490 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100491}
492
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000493bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
494 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000495 Optional<std::string&> reasonIfUnsupported) const
496{
497 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000498 return IsSupportedForDataTypeRef(reasonIfUnsupported,
499 input.GetDataType(),
500 &TrueFunc<>,
501 &TrueFunc<>);
502}
503
arovir011c7c81b2018-10-08 11:34:28 +0100504bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
505 const TensorInfo& output,
506 const DepthwiseConvolution2dDescriptor& descriptor,
507 const TensorInfo& weights,
508 const Optional<TensorInfo>& biases,
509 Optional<std::string&> reasonIfUnsupported) const
510{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100511 ignore_unused(output);
512 ignore_unused(descriptor);
513 ignore_unused(weights);
514 ignore_unused(biases);
515 return IsSupportedForDataTypeRef(reasonIfUnsupported,
516 input.GetDataType(),
517 &TrueFunc<>,
518 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100519}
520
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000521bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
522 const TensorInfo& output,
523 Optional<std::string&> reasonIfUnsupported) const
524{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100525 bool supported = true;
526
527 std::array<DataType,2> supportedInputTypes = {
528 DataType::QuantisedAsymm8,
529 DataType::QuantisedSymm16
530 };
531
532 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
533 "Reference dequantize: input type not supported.");
534
535 std::array<DataType,2> supportedOutputTypes = {
536 DataType::Float32,
537 };
538
539 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
540 "Reference dequantize: output type not supported.");
541
542 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
543 "Reference dequantize: input and output shapes have different num total elements.");
544
545 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000546}
547
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000548bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
549 const armnn::TensorInfo& input1,
550 const armnn::DetectionPostProcessDescriptor& descriptor,
551 armnn::Optional<std::string&> reasonIfUnsupported) const
552{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100553 bool supported = true;
554
555 std::vector<DataType> supportedInputTypes =
556 {
557 DataType::Float32,
558 DataType::QuantisedAsymm8,
559 DataType::QuantisedSymm16
560 };
561
562 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
563 "Reference DetectionPostProcess: input 0 is not a supported type.");
564
565 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
566 "Reference DetectionPostProcess: input 1 is not a supported type.");
567
568 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000569}
570
Pablo Tellof0bd6832019-04-26 17:58:13 +0100571bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
572 const TensorInfo& output,
573 const DepthwiseConvolution2dDescriptor& descriptor,
574 const TensorInfo& weights,
575 const Optional<TensorInfo>& biases,
576 Optional<std::string&> reasonIfUnsupported) const
577{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100578 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100579}
580
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100581bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100582 const TensorInfo& input1,
583 const TensorInfo& output,
584 Optional<std::string&> reasonIfUnsupported) const
585{
Sadik Armagan2999a022019-04-09 14:20:12 +0100586 bool supported = true;
587
588 std::array<DataType,3> supportedTypes = {
589 DataType::Float32,
590 DataType::QuantisedAsymm8,
591 DataType::QuantisedSymm16
592 };
593
594 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
595 "Reference division: input 0 is not a supported type.");
596
597 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
598 "Reference division: input 1 is not a supported type.");
599
600 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
601 "Reference division: output is not a supported type.");
602
603 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
604 "Reference division: input 0 and Input 1 types are mismatched");
605
606 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
607 "Reference division: input and output types are mismatched");
608
609 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
610 "Reference division: shapes are not suitable for implicit broadcast.");
611
612 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100613}
614
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000615bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
616 const TensorInfo& input1,
617 const TensorInfo& output,
618 Optional<std::string&> reasonIfUnsupported) const
619{
620 ignore_unused(input0);
621 ignore_unused(input1);
622 ignore_unused(output);
623 ignore_unused(reasonIfUnsupported);
624 return IsSupportedForDataTypeRef(reasonIfUnsupported,
625 input0.GetDataType(),
626 &TrueFunc<>,
627 &TrueFunc<>);
628}
629
arovir011c7c81b2018-10-08 11:34:28 +0100630bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
631 const FakeQuantizationDescriptor& descriptor,
632 Optional<std::string&> reasonIfUnsupported) const
633{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100634 ignore_unused(descriptor);
635 return IsSupportedForDataTypeRef(reasonIfUnsupported,
636 input.GetDataType(),
637 &TrueFunc<>,
638 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100639}
640
641bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
642 const TensorInfo& output,
643 Optional<std::string&> reasonIfUnsupported) const
644{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100645 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100646 bool supported = true;
647
James Conroyb40d7102019-06-04 12:32:09 +0100648 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100649 {
James Conroyb40d7102019-06-04 12:32:09 +0100650 DataType::Float32,
651 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100652 };
653
654 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
655 "Reference Floor: input type not supported.");
656
657 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
658 "Reference Floor: output type not supported.");
659
660 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100661}
662
663bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
664 const TensorInfo& output,
665 const TensorInfo& weights,
666 const TensorInfo& biases,
667 const FullyConnectedDescriptor& descriptor,
668 Optional<std::string&> reasonIfUnsupported) const
669{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100670 bool supported = true;
671
672 // Define supported types.
673 std::array<DataType,3> supportedTypes =
674 {
675 DataType::Float32,
676 DataType::QuantisedAsymm8,
677 DataType::QuantisedSymm16
678 };
679
680 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
681 "Reference Fully Connected: input type not supported.");
682
683 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
684 "Reference Fully Connected: output type not supported.");
685
686 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
687 "Reference Fully Connected: input and output types mismatched.");
688
689 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
690 "Reference Fully Connected: weights type not supported.");
691
692 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
693 "Reference Fully Connected: input and weight types mismatched.");
694
695 if (descriptor.m_BiasEnabled)
696 {
697 // Defined supported types for bias
698 std::array<DataType, 2>
699 supportedBiasTypes =
700 {
701 DataType::Float32,
702 DataType::Signed32
703 };
704
705 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
706 "Reference Fully Connected: bias type not supported.");
707
708 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
709 "Reference Fully Connected: bias and weight types mismatch.");
710
711 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
712 "Reference Fully Connected: bias type inferred from weights is incompatible.");
713
714 }
715
716 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100717}
718
narpra014951d842019-01-18 16:53:53 +0000719bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
720 const armnn::TensorInfo& input1,
721 const armnn::TensorInfo& output,
722 armnn::Optional<std::string&> reasonIfUnsupported) const
723{
724 ignore_unused(input1);
725 ignore_unused(output);
726 return IsSupportedForDataTypeRef(reasonIfUnsupported,
727 input0.GetDataType(),
728 &TrueFunc<>,
729 &TrueFunc<>);
730}
731
FrancisMurtagh878f0232018-12-19 10:56:15 +0000732bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
733 const TensorInfo& input1,
734 const TensorInfo& output,
735 Optional<std::string&> reasonIfUnsupported) const
736{
737 ignore_unused(input0);
738 ignore_unused(input1);
739 ignore_unused(output);
740 ignore_unused(reasonIfUnsupported);
741 return IsSupportedForDataTypeRef(reasonIfUnsupported,
742 input0.GetDataType(),
743 &TrueFunc<>,
744 &TrueFunc<>);
745}
746
arovir011c7c81b2018-10-08 11:34:28 +0100747bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
748 Optional<std::string&> reasonIfUnsupported) const
749{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100750 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100751}
752
753bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
754 const TensorInfo& output,
755 const L2NormalizationDescriptor& descriptor,
756 Optional<std::string&> reasonIfUnsupported) const
757{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100758 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100759 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100760 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100761 {
762 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100763 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100764 DataType::QuantisedSymm16
765 };
766
767 bool supported = true;
768
769 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
770 "Reference L2normalization: input type not supported.");
771
772 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
773 "Reference L2normalization: output type not supported.");
774
775 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
776 "Reference L2normalization: input and output types mismatched.");
777
778 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
779 "Reference L2normalization: input and output shapes have different "
780 "num total elements.");
781
782 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100783}
784
785bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
786 const TensorInfo& outputStateIn,
787 const TensorInfo& cellStateIn,
788 const TensorInfo& scratchBuffer,
789 const TensorInfo& outputStateOut,
790 const TensorInfo& cellStateOut,
791 const TensorInfo& output,
792 const LstmDescriptor& descriptor,
793 const TensorInfo& inputToForgetWeights,
794 const TensorInfo& inputToCellWeights,
795 const TensorInfo& inputToOutputWeights,
796 const TensorInfo& recurrentToForgetWeights,
797 const TensorInfo& recurrentToCellWeights,
798 const TensorInfo& recurrentToOutputWeights,
799 const TensorInfo& forgetGateBias,
800 const TensorInfo& cellBias,
801 const TensorInfo& outputGateBias,
802 const TensorInfo* inputToInputWeights,
803 const TensorInfo* recurrentToInputWeights,
804 const TensorInfo* cellToInputWeights,
805 const TensorInfo* inputGateBias,
806 const TensorInfo* projectionWeights,
807 const TensorInfo* projectionBias,
808 const TensorInfo* cellToForgetWeights,
809 const TensorInfo* cellToOutputWeights,
810 Optional<std::string&> reasonIfUnsupported) const
811{
telsoa01c577f2c2018-08-31 09:22:23 +0100812 ignore_unused(descriptor);
813 ignore_unused(inputToForgetWeights);
814 ignore_unused(inputToCellWeights);
815 ignore_unused(inputToOutputWeights);
816 ignore_unused(recurrentToForgetWeights);
817 ignore_unused(recurrentToCellWeights);
818 ignore_unused(recurrentToOutputWeights);
819 ignore_unused(forgetGateBias);
820 ignore_unused(cellBias);
821 ignore_unused(outputGateBias);
822 ignore_unused(inputToInputWeights);
823 ignore_unused(recurrentToInputWeights);
824 ignore_unused(cellToInputWeights);
825 ignore_unused(inputGateBias);
826 ignore_unused(projectionWeights);
827 ignore_unused(projectionBias);
828 ignore_unused(cellToForgetWeights);
829 ignore_unused(cellToOutputWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100830
831 bool supported = true;
832
833 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100834 DataType::Float32,
835 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100836 };
837
838 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
839 "Reference Lstm: input is not a supported type.");
840
841 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
842 "Reference Lstm: input and outputStateIn types are mismatched");
843
844 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
845 "Reference Lstm: input and cellStateIn types are mismatched");
846
847 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
848 "Reference Lstm: input and scratchBuffer types are mismatched");
849
850 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
851 "Reference Lstm: input and outputStateOut types are mismatched");
852
853 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
854 "Reference Lstm: input and cellStateOut types are mismatched");
855
856 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
857 "Reference Lstm: input and output types are mismatched");
858
859 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100860}
861
saoste012df12b32018-11-28 16:57:20 +0000862bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
863 const TensorInfo& input1,
864 const TensorInfo& output,
865 Optional<std::string&> reasonIfUnsupported) const
866{
Sadik Armagan2999a022019-04-09 14:20:12 +0100867 bool supported = true;
868
869 std::array<DataType,3> supportedTypes = {
870 DataType::Float32,
871 DataType::QuantisedAsymm8,
872 DataType::QuantisedSymm16
873 };
874
875 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
876 "Reference maximum: input 0 is not a supported type.");
877
878 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
879 "Reference maximum: input 1 is not a supported type.");
880
881 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
882 "Reference maximum: output is not a supported type.");
883
884 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
885 "Reference maximum: input 0 and Input 1 types are mismatched");
886
887 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
888 "Reference maximum: input and output types are mismatched");
889
890 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
891 "Reference maximum: shapes are not suitable for implicit broadcast.");
892
893 return supported;
saoste012df12b32018-11-28 16:57:20 +0000894}
895
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100896bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
897 const TensorInfo& output,
898 const MeanDescriptor& descriptor,
899 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100900{
James Conroy4d1ff582019-06-10 17:06:39 +0100901 bool supported = true;
902 std::string meanLayerStr = "Mean";
903 std::string outputTensorStr = "output";
904
James Conroyb80775f2019-06-11 11:25:30 +0100905 std::array<DataType,3> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +0100906 {
907 DataType::Float32,
James Conroyb80775f2019-06-11 11:25:30 +0100908 DataType::QuantisedAsymm8,
909 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +0100910 };
911
912 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
913 "Reference Mean: input type not supported.");
914
915 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
916 "Reference Mean: input and output types are mismatched");
917
918 if (descriptor.m_KeepDims)
919 {
920 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
921 reasonIfUnsupported,
922 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
923 output.GetNumDimensions(),
924 meanLayerStr, outputTensorStr).data());
925 }
926 else if (descriptor.m_Axis.empty())
927 {
928 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
929 reasonIfUnsupported,
930 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
931 meanLayerStr, outputTensorStr).data());
932 }
933 else
934 {
935 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
936
937 if (outputDim > 0)
938 {
939 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
940 reasonIfUnsupported,
941 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
942 meanLayerStr, outputTensorStr).data());
943 }
944 else
945 {
946 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
947 reasonIfUnsupported,
948 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
949 meanLayerStr, outputTensorStr).data());
950 }
951 }
952
953 return supported;
narpra0132b90462018-09-13 11:07:48 +0100954}
955
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Merger support is identical to Concat support in this backend, so the
    // query is forwarded unchanged to the Concat check.
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
963
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    // Only the input's data type is consulted; the output tensor is ignored.
    ignore_unused(output);
    // Per-data-type verdicts in the generic dispatcher's slot order. The
    // TrueFunc/FalseFuncI32 names indicate every slot is accepted except the
    // Signed32 one. NOTE(review): slot order assumed to match the dispatcher's
    // declaration elsewhere in this file - confirm if slots are ever reordered.
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
977
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000978bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
979 const TensorInfo& input1,
980 const TensorInfo& output,
981 Optional<std::string&> reasonIfUnsupported) const
982{
Sadik Armagan2999a022019-04-09 14:20:12 +0100983 bool supported = true;
984
985 std::array<DataType,3> supportedTypes = {
986 DataType::Float32,
987 DataType::QuantisedAsymm8,
988 DataType::QuantisedSymm16
989 };
990
991 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
992 "Reference minimum: input 0 is not a supported type.");
993
994 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
995 "Reference minimum: input 1 is not a supported type.");
996
997 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
998 "Reference minimum: output is not a supported type.");
999
1000 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1001 "Reference minimum: input 0 and Input 1 types are mismatched");
1002
1003 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1004 "Reference minimum: input and output types are mismatched");
1005
1006 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1007 "Reference minimum: shapes are not suitable for implicit broadcast.");
1008
1009 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001010}
1011
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001012bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1013 const TensorInfo& input1,
1014 const TensorInfo& output,
1015 Optional<std::string&> reasonIfUnsupported) const
1016{
Sadik Armagan2999a022019-04-09 14:20:12 +01001017 bool supported = true;
1018
1019 std::array<DataType,3> supportedTypes = {
1020 DataType::Float32,
1021 DataType::QuantisedAsymm8,
1022 DataType::QuantisedSymm16
1023 };
1024
1025 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1026 "Reference multiplication: input 0 is not a supported type.");
1027
1028 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1029 "Reference multiplication: input 1 is not a supported type.");
1030
1031 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1032 "Reference multiplication: output is not a supported type.");
1033
1034 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1035 "Reference multiplication: input 0 and Input 1 types are mismatched");
1036
1037 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1038 "Reference multiplication: input and output types are mismatched");
1039
1040 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1041 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1042
1043 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001044}
1045
1046bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1047 const TensorInfo& output,
1048 const NormalizationDescriptor& descriptor,
1049 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001050{
Nina Drozd661dfa72018-10-02 11:14:17 +01001051 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001052
1053 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001054 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001055 {
1056 DataType::Float16,
1057 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001058 DataType::QuantisedAsymm8,
1059 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001060 };
1061
1062 bool supported = true;
1063
1064 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1065 "Reference normalization: input type not supported.");
1066
1067 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1068 "Reference normalization: output type not supported.");
1069
1070 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1071 "Reference normalization: input and output shapes have different "
1072 "num total elements.");
1073
1074 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001075}
1076
1077bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1078 Optional<std::string&> reasonIfUnsupported) const
1079{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001080 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001081}
1082
bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    // Support depends only on the input's data type; the output tensor and the
    // pad descriptor are not inspected.
    ignore_unused(output);
    ignore_unused(descriptor);
    // The two TrueFunc slots correspond to the dispatcher's float32 and uint8
    // function pointers, so those two types are accepted and all others rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1095
bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    // Support depends only on the input's data type; the output tensor and the
    // permutation vector are not inspected.
    ignore_unused(output);
    ignore_unused(descriptor);
    // The two TrueFunc slots correspond to the dispatcher's float32 and uint8
    // function pointers, so those two types are accepted and all others rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1108
1109bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1110 const TensorInfo& output,
1111 const Pooling2dDescriptor& descriptor,
1112 Optional<std::string&> reasonIfUnsupported) const
1113{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001114 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001115 bool supported = true;
1116
1117 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001118 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001119 {
1120 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001121 DataType::QuantisedAsymm8,
1122 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001123 };
1124
1125 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1126 "Reference poolind2d: input is not a supported type.");
1127
1128 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1129 "Reference poolind2d: output is not a supported type.");
1130
1131 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1132 "Reference poolind2d: input and output types are mismatched.");
1133
1134 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001135}
1136
Derek Lamberti5f400d62019-03-25 15:41:58 +00001137bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1138 const TensorInfo& output,
1139 Optional<std::string&> reasonIfUnsupported) const
1140{
1141 bool supported = true;
1142
1143 // Define supported output types.
1144 std::array<DataType,2> supportedInputTypes = {
1145 DataType::Float32,
1146 };
1147
1148 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1149 "Reference quantize: input type not supported.");
1150
1151 // Define supported output types.
1152 std::array<DataType,2> supportedOutputTypes = {
1153 DataType::QuantisedAsymm8,
1154 DataType::QuantisedSymm16
1155 };
1156 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1157 "Reference quantize: output type not supported.");
1158
1159 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1160 "Reference quantize: input and output shapes have different num total elements.");
1161
1162 return supported;
1163}
1164
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001165bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001166 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001167 Optional<std::string&> reasonIfUnsupported) const
1168{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001169 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001170 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001171 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001172 {
1173 DataType::Float32,
1174 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001175 DataType::QuantisedAsymm8,
1176 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001177 };
1178 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1179 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001180}
1181
1182bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001183 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001184 Optional<std::string&> reasonIfUnsupported) const
1185{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001186 bool supported = true;
1187 std::array<DataType,3> supportedTypes =
1188 {
1189 DataType::Float32,
1190 DataType::QuantisedAsymm8,
1191 DataType::QuantisedSymm16
1192 };
1193
1194 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1195 "Reference ResizeBilinear: input type not supported");
1196
1197 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1198 "Reference ResizeBilinear: output type not supported");
1199
1200 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1201 "Reference ResizeBilinear: input and output types not matching");
1202
1203 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001204}
1205
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001206bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1207 const TensorInfo& output,
1208 Optional<std::string&> reasonIfUnsupported) const
1209{
nikraj010421e7f2019-06-14 09:40:34 +01001210 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001211 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001212 {
1213 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001214 DataType::QuantisedAsymm8,
1215 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001216 };
1217
1218 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1219 "Reference rsqrt: input type not supported");
1220
1221 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1222 "Reference rsqrt: output type not supported");
1223
1224 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1225 "Reference rsqrt: input and output types not matching");
1226
1227 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1228 "Reference Rsqrt: input and output shapes have different number of total elements");
1229
1230 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001231}
1232
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001233bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1234 const TensorInfo& output,
1235 const SoftmaxDescriptor& descriptor,
1236 Optional<std::string&> reasonIfUnsupported) const
1237{
1238 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001239 bool supported = true;
1240 std::array<DataType,3> supportedTypes =
1241 {
1242 DataType::Float32,
1243 DataType::QuantisedAsymm8,
1244 DataType::QuantisedSymm16
1245 };
1246
1247 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1248 "Reference concatenation: output type not supported");
1249
1250 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1251 "Reference concatenation: input type not supported");
1252
1253 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1254 "Reference concatenation: input type not supported");
1255
1256 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001257}
1258
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001259bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1260 const TensorInfo& output,
1261 const SpaceToBatchNdDescriptor& descriptor,
1262 Optional<std::string&> reasonIfUnsupported) const
1263{
1264 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001265 bool supported = true;
1266 std::array<DataType,3> supportedTypes =
1267 {
1268 DataType::Float32,
1269 DataType::QuantisedAsymm8,
1270 DataType::QuantisedSymm16
1271 };
1272
1273 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1274 "Reference SpaceToBatchNd: input type not supported");
1275
1276 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1277 "Reference SpaceToBatchNd: output type not supported");
1278
1279 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1280 "Reference SpaceToBatchNd: input and output types are mismatched");
1281
1282 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001283}
1284
bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    // Support depends only on the input's data type; the view layout in the
    // descriptor is not inspected.
    ignore_unused(descriptor);
    // The two TrueFunc slots correspond to the dispatcher's float32 and uint8
    // function pointers, so those two types are accepted and all others rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1295
bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    // Overload taking explicit output infos: support still depends only on the
    // input's data type; the outputs and the view layout are not inspected.
    ignore_unused(descriptor);
    ignore_unused(outputs);
    // The two TrueFunc slots correspond to the dispatcher's float32 and uint8
    // function pointers, so those two types are accepted and all others rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1308
bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const StridedSliceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    // Support depends only on the input's data type; the output tensor and the
    // slice parameters are not inspected.
    ignore_unused(output);
    ignore_unused(descriptor);
    // The two TrueFunc slots correspond to the dispatcher's float32 and uint8
    // function pointers, so those two types are accepted and all others rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1321
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001322bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1323 const TensorInfo& input1,
1324 const TensorInfo& output,
1325 Optional<std::string&> reasonIfUnsupported) const
1326{
Sadik Armagan2999a022019-04-09 14:20:12 +01001327 bool supported = true;
1328
1329 std::array<DataType,3> supportedTypes = {
1330 DataType::Float32,
1331 DataType::QuantisedAsymm8,
1332 DataType::QuantisedSymm16
1333 };
1334
1335 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1336 "Reference subtraction: input 0 is not a supported type.");
1337
1338 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1339 "Reference subtraction: input 1 is not a supported type.");
1340
1341 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1342 "Reference subtraction: output is not a supported type.");
1343
1344 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1345 "Reference subtraction: input 0 and Input 1 types are mismatched");
1346
1347 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1348 "Reference subtraction: input and output types are mismatched");
1349
1350 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1351 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1352
1353 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001354}
1355
Matteo Martincighab9e5252019-06-13 17:27:46 +01001356bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1357 const TensorInfo& alpha,
1358 const TensorInfo& output,
1359 Optional<std::string&> reasonIfUnsupported) const
1360{
1361 bool supported = true;
1362
1363 std::array<DataType, 3> supportedTypes
1364 {
1365 DataType::Float32,
1366 DataType::QuantisedAsymm8,
1367 DataType::QuantisedSymm16
1368 };
1369
1370 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1371 "PReLU: input is not a supported type.");
1372
1373 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1374 "PReLU: alpha is not a supported type.");
1375
1376 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1377 "PReLU: output is not a supported type.");
1378
1379 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1380 "PReLU: input, alpha and output types are mismatched");
1381
1382 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1383 "PReLU: shapes are not suitable for implicit broadcast");
1384
1385 return supported;
1386}
1387
arovir011c7c81b2018-10-08 11:34:28 +01001388} // namespace armnn