blob: 4e9a67879fbc696d6b160de9a35d80fca6367332 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010028namespace
29{
30
31template<typename Float32Func, typename Uint8Func, typename ... Params>
32bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
33 DataType dataType,
34 Float32Func floatFuncPtr,
35 Uint8Func uint8FuncPtr,
36 Params&&... params)
37{
38 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
39 dataType,
40 &FalseFunc<Params...>,
41 floatFuncPtr,
42 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000043 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000044 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010045 std::forward<Params>(params)...);
46}
47
48} // anonymous namespace
49
namespace
{

/// Builds the standard "wrong number of dimensions" diagnostic used by the
/// reference backend's layer-support checks.
///
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name inserted into the message.
/// @param tensorName Name of the offending tensor.
/// @return           The formatted error message.
///
/// Fixed: the string parameters were previously taken by non-const reference,
/// which falsely implied mutation and prevented passing literals/temporaries.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) +
                           " dimensions but got " + std::to_string(actual) +
                           " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
66namespace
67{
68template<typename F>
69bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
70{
71 bool supported = rule();
72 if (!supported && reason)
73 {
74 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
75 }
76 return supported;
77}
78
79struct Rule
80{
81 bool operator()() const
82 {
83 return m_Res;
84 }
85
86 bool m_Res = true;
87};
88
Derek Lamberti2a434a82019-03-20 13:07:57 +000089template<typename T>
90bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000091{
92 return true;
93}
94
95template<typename T, typename... Rest>
96bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
97{
98 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
99
Derek Lamberti2a434a82019-03-20 13:07:57 +0000100 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +0000101}
102
103struct TypesAreEqual : public Rule
104{
105 template<typename ... Ts>
106 TypesAreEqual(const Ts&... ts)
107 {
108 m_Res = AllTypesAreEqualImpl(ts...);
109 }
110};
111
112struct QuantizationParametersAreEqual : public Rule
113{
114 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
115 {
116 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
117 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
118 }
119};
120
121struct TypeAnyOf : public Rule
122{
123 template<typename Container>
124 TypeAnyOf(const TensorInfo& info, const Container& c)
125 {
126 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100127 {
128 return dt == info.GetDataType();
129 });
130 }
131};
132
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100133struct TypeIs : public Rule
134{
135 TypeIs(const TensorInfo& info, DataType dt)
136 {
137 m_Res = dt == info.GetDataType();
138 }
139};
140
Francis Murtagh46c09d02019-05-28 08:15:28 +0100141struct BiasAndWeightsTypesMatch : public Rule
142{
143 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
144 {
145 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
146 }
147};
148
149struct BiasAndWeightsTypesCompatible : public Rule
150{
151 template<typename Container>
152 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
153 {
154 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
155 {
156 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
157 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000158 }
159};
160
161struct ShapesAreSameRank : public Rule
162{
163 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
164 {
165 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
166 }
167};
168
Derek Lamberti5f400d62019-03-25 15:41:58 +0000169struct ShapesAreSameTotalSize : public Rule
170{
171 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
172 {
173 m_Res = info0.GetNumElements() == info1.GetNumElements();
174 }
175};
176
Derek Lamberti50db4e82019-03-13 14:16:15 +0000177struct ShapesAreBroadcastCompatible : public Rule
178{
179 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
180 {
181 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
182 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
183 return sizeIn;
184 }
185
186 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
187 {
188 const TensorShape& shape0 = in0.GetShape();
189 const TensorShape& shape1 = in1.GetShape();
190 const TensorShape& outShape = out.GetShape();
191
192 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
193 {
194 unsigned int sizeOut = outShape[i];
195 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
196 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
197
198 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
199 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
200 }
201 }
202};
James Conroy4d1ff582019-06-10 17:06:39 +0100203
204struct TensorNumDimensionsAreCorrect : public Rule
205{
206 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
207 {
208 m_Res = info.GetNumDimensions() == expectedNumDimensions;
209 }
210};
211
Derek Lamberti50db4e82019-03-13 14:16:15 +0000212} // namespace
213
214
arovir011c7c81b2018-10-08 11:34:28 +0100215bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
216 const TensorInfo& output,
217 const ActivationDescriptor& descriptor,
218 Optional<std::string&> reasonIfUnsupported) const
219{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000220 bool supported = true;
221
222 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100223 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000224 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100225 DataType::QuantisedAsymm8,
226 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000227 };
228
229 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
230 "Reference activation: input type not supported.");
231
232 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
233 "Reference activation: output type not supported.");
234
235 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
236 "Reference activation: input and output types mismatched.");
237
238 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
239 "Reference activation: input and output shapes are of different rank.");
240
241
242 struct ActivationFunctionSupported : public Rule
243 {
244 ActivationFunctionSupported(const ActivationDescriptor& desc)
245 {
246 switch(desc.m_Function)
247 {
248 case ActivationFunction::Abs:
249 case ActivationFunction::BoundedReLu:
250 case ActivationFunction::LeakyReLu:
251 case ActivationFunction::Linear:
252 case ActivationFunction::ReLu:
253 case ActivationFunction::Sigmoid:
254 case ActivationFunction::SoftReLu:
255 case ActivationFunction::Sqrt:
256 case ActivationFunction::Square:
257 case ActivationFunction::TanH:
258 {
259 m_Res = true;
260 break;
261 }
262 default:
263 {
264 m_Res = false;
265 break;
266 }
267 }
268 }
269 };
270
271 // Function is supported
272 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
273 "Reference activation: function not supported.");
274
275 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100276}
277
278bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
279 const TensorInfo& input1,
280 const TensorInfo& output,
281 Optional<std::string&> reasonIfUnsupported) const
282{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000283 bool supported = true;
284
Sadik Armagan2999a022019-04-09 14:20:12 +0100285 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000286 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100287 DataType::QuantisedAsymm8,
288 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000289 };
290
291 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
292 "Reference addition: input 0 is not a supported type.");
293
294 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
295 "Reference addition: input 1 is not a supported type.");
296
297 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
298 "Reference addition: output is not a supported type.");
299
300 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
301 "Reference addition: input 0 and Input 1 types are mismatched");
302
303 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
304 "Reference addition: input and output types are mismatched");
305
306 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
307 "Reference addition: shapes are not suitable for implicit broadcast.");
308
309 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100310}
311
312bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
313 const TensorInfo& output,
314 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100315 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100316 const TensorInfo& beta,
317 const TensorInfo& gamma,
318 const BatchNormalizationDescriptor& descriptor,
319 Optional<std::string&> reasonIfUnsupported) const
320{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100321 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100322
Matteo Martincighf5507132019-06-04 10:59:47 +0100323 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100324 {
325 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100326 DataType::QuantisedAsymm8,
327 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100328 };
329
330 bool supported = true;
331
332 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
333 "Reference batch normalization: input is not a supported type.");
334
335 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
336 "Reference batch normalization: output is not a supported type.");
337
338 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
339 "Reference batch normalization: input and output types are mismatched");
340
341 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
342 "Reference batch normalization: mean is not a supported type.");
343
344 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
345 "Reference batch normalization: variance is not a supported type.");
346
347 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
348 "Reference batch normalization: beta is not a supported type.");
349
350 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
351 "Reference batch normalization: gamma is not a supported type.");
352
353 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100354}
355
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000356bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
357 const TensorInfo& output,
358 const BatchToSpaceNdDescriptor& descriptor,
359 Optional<std::string&> reasonIfUnsupported) const
360{
361 ignore_unused(descriptor);
362 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
363 input.GetDataType(),
364 &TrueFunc<>,
365 &TrueFunc<>) &&
366 IsSupportedForDataTypeRef(reasonIfUnsupported,
367 output.GetDataType(),
368 &TrueFunc<>,
369 &TrueFunc<>));
370}
371
Jim Flynn906f9462019-05-10 13:55:21 +0100372bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
373 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100374 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100375 Optional<std::string&> reasonIfUnsupported) const
376{
Jim Flynne242f2d2019-05-22 14:24:13 +0100377 ignore_unused(descriptor);
378
379 bool supported = true;
380 std::array<DataType,3> supportedTypes =
381 {
382 DataType::Float32,
383 DataType::QuantisedAsymm8,
384 DataType::QuantisedSymm16
385 };
386
387 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
388 "Reference concatenation: output type not supported");
389 for (const TensorInfo* input : inputs)
390 {
391 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
392 "Reference concatenation: input type not supported");
393
394 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
395 "Reference concatenation: input and output types mismatched.");
396 }
397
398 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100399}
400
arovir011c7c81b2018-10-08 11:34:28 +0100401bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
402 Optional<std::string&> reasonIfUnsupported) const
403{
Jim Flynne242f2d2019-05-22 14:24:13 +0100404 std::array<DataType,4> supportedTypes =
405 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100406 DataType::Float32,
407 DataType::Signed32,
408 DataType::QuantisedAsymm8,
409 DataType::QuantisedSymm16
410 };
411
412 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
413 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100414}
415
416bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
417 const TensorInfo& output,
418 Optional<std::string&> reasonIfUnsupported) const
419{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100420 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
421 input.GetDataType(),
422 &TrueFunc<>,
423 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000424 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000425 &FalseFuncI32<>,
426 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100427 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
428 output.GetDataType(),
429 &FalseOutputFuncF16<>,
430 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000431 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000432 &FalseFuncI32<>,
433 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100434}
435
436bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
437 const TensorInfo& output,
438 Optional<std::string&> reasonIfUnsupported) const
439{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100440 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
441 input.GetDataType(),
442 &FalseInputFuncF16<>,
443 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000444 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000445 &FalseFuncI32<>,
446 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100447 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
448 output.GetDataType(),
449 &TrueFunc<>,
450 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000451 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000452 &FalseFuncI32<>,
453 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100454}
455
456bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
457 const TensorInfo& output,
458 const Convolution2dDescriptor& descriptor,
459 const TensorInfo& weights,
460 const Optional<TensorInfo>& biases,
461 Optional<std::string&> reasonIfUnsupported) const
462{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100463 bool supported = true;
464
465 // Define supported types.
466 std::array<DataType,3> supportedTypes = {
467 DataType::Float32,
468 DataType::QuantisedAsymm8,
469 DataType::QuantisedSymm16
470 };
471
472 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100473 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100474
475 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100476 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100477
478 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100479 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100480
481 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100482 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100483
484 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100485 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100486
487 if (biases.has_value())
488 {
489 std::array<DataType,3> biasesSupportedTypes = {
490 DataType::Float32,
491 DataType::Signed32
492 };
493 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100494 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100495 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100496 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100497
498 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100499}
500
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000501bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
502 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000503 Optional<std::string&> reasonIfUnsupported) const
504{
505 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000506 return IsSupportedForDataTypeRef(reasonIfUnsupported,
507 input.GetDataType(),
508 &TrueFunc<>,
509 &TrueFunc<>);
510}
511
arovir011c7c81b2018-10-08 11:34:28 +0100512bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
513 const TensorInfo& output,
514 const DepthwiseConvolution2dDescriptor& descriptor,
515 const TensorInfo& weights,
516 const Optional<TensorInfo>& biases,
517 Optional<std::string&> reasonIfUnsupported) const
518{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100519 ignore_unused(output);
520 ignore_unused(descriptor);
521 ignore_unused(weights);
522 ignore_unused(biases);
523 return IsSupportedForDataTypeRef(reasonIfUnsupported,
524 input.GetDataType(),
525 &TrueFunc<>,
526 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100527}
528
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000529bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
530 const TensorInfo& output,
531 Optional<std::string&> reasonIfUnsupported) const
532{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100533 bool supported = true;
534
535 std::array<DataType,2> supportedInputTypes = {
536 DataType::QuantisedAsymm8,
537 DataType::QuantisedSymm16
538 };
539
540 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
541 "Reference dequantize: input type not supported.");
542
543 std::array<DataType,2> supportedOutputTypes = {
544 DataType::Float32,
545 };
546
547 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
548 "Reference dequantize: output type not supported.");
549
550 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
551 "Reference dequantize: input and output shapes have different num total elements.");
552
553 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000554}
555
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000556bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
557 const armnn::TensorInfo& input1,
558 const armnn::DetectionPostProcessDescriptor& descriptor,
559 armnn::Optional<std::string&> reasonIfUnsupported) const
560{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100561 bool supported = true;
562
563 std::vector<DataType> supportedInputTypes =
564 {
565 DataType::Float32,
566 DataType::QuantisedAsymm8,
567 DataType::QuantisedSymm16
568 };
569
570 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
571 "Reference DetectionPostProcess: input 0 is not a supported type.");
572
573 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
574 "Reference DetectionPostProcess: input 1 is not a supported type.");
575
576 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000577}
578
Pablo Tellof0bd6832019-04-26 17:58:13 +0100579bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
580 const TensorInfo& output,
581 const DepthwiseConvolution2dDescriptor& descriptor,
582 const TensorInfo& weights,
583 const Optional<TensorInfo>& biases,
584 Optional<std::string&> reasonIfUnsupported) const
585{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100586 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100587}
588
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100589bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100590 const TensorInfo& input1,
591 const TensorInfo& output,
592 Optional<std::string&> reasonIfUnsupported) const
593{
Sadik Armagan2999a022019-04-09 14:20:12 +0100594 bool supported = true;
595
596 std::array<DataType,3> supportedTypes = {
597 DataType::Float32,
598 DataType::QuantisedAsymm8,
599 DataType::QuantisedSymm16
600 };
601
602 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
603 "Reference division: input 0 is not a supported type.");
604
605 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
606 "Reference division: input 1 is not a supported type.");
607
608 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
609 "Reference division: output is not a supported type.");
610
611 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
612 "Reference division: input 0 and Input 1 types are mismatched");
613
614 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
615 "Reference division: input and output types are mismatched");
616
617 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
618 "Reference division: shapes are not suitable for implicit broadcast.");
619
620 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100621}
622
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000623bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
624 const TensorInfo& input1,
625 const TensorInfo& output,
626 Optional<std::string&> reasonIfUnsupported) const
627{
628 ignore_unused(input0);
629 ignore_unused(input1);
630 ignore_unused(output);
631 ignore_unused(reasonIfUnsupported);
632 return IsSupportedForDataTypeRef(reasonIfUnsupported,
633 input0.GetDataType(),
634 &TrueFunc<>,
635 &TrueFunc<>);
636}
637
arovir011c7c81b2018-10-08 11:34:28 +0100638bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
639 const FakeQuantizationDescriptor& descriptor,
640 Optional<std::string&> reasonIfUnsupported) const
641{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100642 ignore_unused(descriptor);
643 return IsSupportedForDataTypeRef(reasonIfUnsupported,
644 input.GetDataType(),
645 &TrueFunc<>,
646 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100647}
648
649bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
650 const TensorInfo& output,
651 Optional<std::string&> reasonIfUnsupported) const
652{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100653 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100654 bool supported = true;
655
James Conroyb40d7102019-06-04 12:32:09 +0100656 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100657 {
James Conroyb40d7102019-06-04 12:32:09 +0100658 DataType::Float32,
659 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100660 };
661
662 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
663 "Reference Floor: input type not supported.");
664
665 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
666 "Reference Floor: output type not supported.");
667
668 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100669}
670
671bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
672 const TensorInfo& output,
673 const TensorInfo& weights,
674 const TensorInfo& biases,
675 const FullyConnectedDescriptor& descriptor,
676 Optional<std::string&> reasonIfUnsupported) const
677{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100678 bool supported = true;
679
680 // Define supported types.
681 std::array<DataType,3> supportedTypes =
682 {
683 DataType::Float32,
684 DataType::QuantisedAsymm8,
685 DataType::QuantisedSymm16
686 };
687
688 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
689 "Reference Fully Connected: input type not supported.");
690
691 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
692 "Reference Fully Connected: output type not supported.");
693
694 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
695 "Reference Fully Connected: input and output types mismatched.");
696
697 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
698 "Reference Fully Connected: weights type not supported.");
699
700 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
701 "Reference Fully Connected: input and weight types mismatched.");
702
703 if (descriptor.m_BiasEnabled)
704 {
705 // Defined supported types for bias
706 std::array<DataType, 2>
707 supportedBiasTypes =
708 {
709 DataType::Float32,
710 DataType::Signed32
711 };
712
713 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
714 "Reference Fully Connected: bias type not supported.");
715
716 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
717 "Reference Fully Connected: bias and weight types mismatch.");
718
719 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
720 "Reference Fully Connected: bias type inferred from weights is incompatible.");
721
722 }
723
724 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100725}
726
narpra014951d842019-01-18 16:53:53 +0000727bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
728 const armnn::TensorInfo& input1,
729 const armnn::TensorInfo& output,
730 armnn::Optional<std::string&> reasonIfUnsupported) const
731{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100732 bool supported = true;
733 std::array<DataType,3> supportedTypes =
734 {
735 DataType::Float32,
736 DataType::QuantisedAsymm8,
737 DataType::QuantisedSymm16
738 };
739
740 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
741 "Reference Gather: input type not supported");
742
743 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
744 "Reference Gather: output type not supported");
745
746 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
747 "Reference Gather: indices (input1) type not supported");
748
749 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
750 "Reference Gather: input and output types not matching");
751
752 return supported;
narpra014951d842019-01-18 16:53:53 +0000753}
754
FrancisMurtagh878f0232018-12-19 10:56:15 +0000755bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
756 const TensorInfo& input1,
757 const TensorInfo& output,
758 Optional<std::string&> reasonIfUnsupported) const
759{
760 ignore_unused(input0);
761 ignore_unused(input1);
762 ignore_unused(output);
763 ignore_unused(reasonIfUnsupported);
764 return IsSupportedForDataTypeRef(reasonIfUnsupported,
765 input0.GetDataType(),
766 &TrueFunc<>,
767 &TrueFunc<>);
768}
769
arovir011c7c81b2018-10-08 11:34:28 +0100770bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
771 Optional<std::string&> reasonIfUnsupported) const
772{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100773 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100774}
775
776bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
777 const TensorInfo& output,
778 const L2NormalizationDescriptor& descriptor,
779 Optional<std::string&> reasonIfUnsupported) const
780{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100781 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100782 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100783 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100784 {
785 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100786 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100787 DataType::QuantisedSymm16
788 };
789
790 bool supported = true;
791
792 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
793 "Reference L2normalization: input type not supported.");
794
795 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
796 "Reference L2normalization: output type not supported.");
797
798 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
799 "Reference L2normalization: input and output types mismatched.");
800
801 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
802 "Reference L2normalization: input and output shapes have different "
803 "num total elements.");
804
805 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100806}
807
// Checks whether the reference backend can run an LSTM layer.
// Only the data types of the activation/state tensors are validated here;
// the descriptor and every weight/bias tensor are deliberately not inspected.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // None of the weights, biases or descriptor fields influence the result;
    // mark them unused to keep the compiler quiet.
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    // Reference LSTM supports float and 16-bit symmetric quantized tensors only.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    // All state/scratch/output tensors must share the input's data type.
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
884
saoste012df12b32018-11-28 16:57:20 +0000885bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
886 const TensorInfo& input1,
887 const TensorInfo& output,
888 Optional<std::string&> reasonIfUnsupported) const
889{
Sadik Armagan2999a022019-04-09 14:20:12 +0100890 bool supported = true;
891
892 std::array<DataType,3> supportedTypes = {
893 DataType::Float32,
894 DataType::QuantisedAsymm8,
895 DataType::QuantisedSymm16
896 };
897
898 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
899 "Reference maximum: input 0 is not a supported type.");
900
901 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
902 "Reference maximum: input 1 is not a supported type.");
903
904 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
905 "Reference maximum: output is not a supported type.");
906
907 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
908 "Reference maximum: input 0 and Input 1 types are mismatched");
909
910 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
911 "Reference maximum: input and output types are mismatched");
912
913 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
914 "Reference maximum: shapes are not suitable for implicit broadcast.");
915
916 return supported;
saoste012df12b32018-11-28 16:57:20 +0000917}
918
// Checks whether the reference backend supports a Mean layer.
// Validates the input/output data types and that the output rank is
// consistent with the reduction described by the descriptor.
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    // Strings fed into CreateIncorrectDimensionsErrorMsg for error reporting.
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        // Reduced axes are kept as size-1 dimensions: output rank == input rank.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        // No axes given: everything is reduced, so the output must be rank 1.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        // Each listed axis removes one dimension.
        // NOTE(review): unsigned subtraction - if m_Axis.size() exceeded the
        // input rank this would wrap to a huge value and fail the rank rule
        // below with a confusing message; presumably earlier validation
        // prevents that case - confirm.
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            // All dimensions reduced away: output collapses to rank 1.
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
978
// Merger shares its descriptor with Concat; delegate so the two checks
// can never diverge.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
986
// Checks whether a MemCopy between two tensors is supported.
// Only the input data type is dispatched on; the output tensor is ignored.
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    // All type branches return true except the FalseFuncI32 slot, i.e. every
    // data type except 32-bit signed integers can be mem-copied.
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
1000
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001001bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1002 const TensorInfo& input1,
1003 const TensorInfo& output,
1004 Optional<std::string&> reasonIfUnsupported) const
1005{
Sadik Armagan2999a022019-04-09 14:20:12 +01001006 bool supported = true;
1007
1008 std::array<DataType,3> supportedTypes = {
1009 DataType::Float32,
1010 DataType::QuantisedAsymm8,
1011 DataType::QuantisedSymm16
1012 };
1013
1014 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1015 "Reference minimum: input 0 is not a supported type.");
1016
1017 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1018 "Reference minimum: input 1 is not a supported type.");
1019
1020 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1021 "Reference minimum: output is not a supported type.");
1022
1023 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1024 "Reference minimum: input 0 and Input 1 types are mismatched");
1025
1026 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1027 "Reference minimum: input and output types are mismatched");
1028
1029 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1030 "Reference minimum: shapes are not suitable for implicit broadcast.");
1031
1032 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001033}
1034
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001035bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1036 const TensorInfo& input1,
1037 const TensorInfo& output,
1038 Optional<std::string&> reasonIfUnsupported) const
1039{
Sadik Armagan2999a022019-04-09 14:20:12 +01001040 bool supported = true;
1041
1042 std::array<DataType,3> supportedTypes = {
1043 DataType::Float32,
1044 DataType::QuantisedAsymm8,
1045 DataType::QuantisedSymm16
1046 };
1047
1048 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1049 "Reference multiplication: input 0 is not a supported type.");
1050
1051 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1052 "Reference multiplication: input 1 is not a supported type.");
1053
1054 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1055 "Reference multiplication: output is not a supported type.");
1056
1057 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1058 "Reference multiplication: input 0 and Input 1 types are mismatched");
1059
1060 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1061 "Reference multiplication: input and output types are mismatched");
1062
1063 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1064 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1065
1066 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001067}
1068
1069bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1070 const TensorInfo& output,
1071 const NormalizationDescriptor& descriptor,
1072 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001073{
Nina Drozd661dfa72018-10-02 11:14:17 +01001074 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001075
1076 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001077 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001078 {
1079 DataType::Float16,
1080 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001081 DataType::QuantisedAsymm8,
1082 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001083 };
1084
1085 bool supported = true;
1086
1087 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1088 "Reference normalization: input type not supported.");
1089
1090 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1091 "Reference normalization: output type not supported.");
1092
1093 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1094 "Reference normalization: input and output shapes have different "
1095 "num total elements.");
1096
1097 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001098}
1099
1100bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1101 Optional<std::string&> reasonIfUnsupported) const
1102{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001103 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001104}
1105
// Checks whether the reference backend supports a Pad layer.
// Only the input data type is validated; the output tensor and the pad
// list/value in the descriptor are not inspected.
bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    // TrueFunc for both the float and uint8 branches of the legacy dispatcher.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1118
// Checks whether the reference backend supports a Permute layer.
// Only the input data type is validated; the output tensor and the
// permutation vector in the descriptor are not inspected.
bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    // TrueFunc for both the float and uint8 branches of the legacy dispatcher.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1131
1132bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1133 const TensorInfo& output,
1134 const Pooling2dDescriptor& descriptor,
1135 Optional<std::string&> reasonIfUnsupported) const
1136{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001137 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001138 bool supported = true;
1139
1140 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001141 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001142 {
1143 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001144 DataType::QuantisedAsymm8,
1145 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001146 };
1147
1148 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1149 "Reference poolind2d: input is not a supported type.");
1150
1151 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1152 "Reference poolind2d: output is not a supported type.");
1153
1154 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1155 "Reference poolind2d: input and output types are mismatched.");
1156
1157 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001158}
1159
Derek Lamberti5f400d62019-03-25 15:41:58 +00001160bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1161 const TensorInfo& output,
1162 Optional<std::string&> reasonIfUnsupported) const
1163{
1164 bool supported = true;
1165
1166 // Define supported output types.
1167 std::array<DataType,2> supportedInputTypes = {
1168 DataType::Float32,
1169 };
1170
1171 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1172 "Reference quantize: input type not supported.");
1173
1174 // Define supported output types.
1175 std::array<DataType,2> supportedOutputTypes = {
1176 DataType::QuantisedAsymm8,
1177 DataType::QuantisedSymm16
1178 };
1179 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1180 "Reference quantize: output type not supported.");
1181
1182 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1183 "Reference quantize: input and output shapes have different num total elements.");
1184
1185 return supported;
1186}
1187
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001188bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001189 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001190 Optional<std::string&> reasonIfUnsupported) const
1191{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001192 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001193 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001194 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001195 {
1196 DataType::Float32,
1197 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001198 DataType::QuantisedAsymm8,
1199 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001200 };
1201 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1202 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001203}
1204
1205bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001206 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001207 Optional<std::string&> reasonIfUnsupported) const
1208{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001209 bool supported = true;
1210 std::array<DataType,3> supportedTypes =
1211 {
1212 DataType::Float32,
1213 DataType::QuantisedAsymm8,
1214 DataType::QuantisedSymm16
1215 };
1216
1217 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1218 "Reference ResizeBilinear: input type not supported");
1219
1220 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1221 "Reference ResizeBilinear: output type not supported");
1222
1223 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1224 "Reference ResizeBilinear: input and output types not matching");
1225
1226 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001227}
1228
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001229bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1230 const TensorInfo& output,
1231 Optional<std::string&> reasonIfUnsupported) const
1232{
nikraj010421e7f2019-06-14 09:40:34 +01001233 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001234 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001235 {
1236 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001237 DataType::QuantisedAsymm8,
1238 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001239 };
1240
1241 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1242 "Reference rsqrt: input type not supported");
1243
1244 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1245 "Reference rsqrt: output type not supported");
1246
1247 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1248 "Reference rsqrt: input and output types not matching");
1249
1250 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1251 "Reference Rsqrt: input and output shapes have different number of total elements");
1252
1253 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001254}
1255
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001256bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1257 const TensorInfo& output,
1258 const SoftmaxDescriptor& descriptor,
1259 Optional<std::string&> reasonIfUnsupported) const
1260{
1261 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001262 bool supported = true;
1263 std::array<DataType,3> supportedTypes =
1264 {
1265 DataType::Float32,
1266 DataType::QuantisedAsymm8,
1267 DataType::QuantisedSymm16
1268 };
1269
1270 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1271 "Reference concatenation: output type not supported");
1272
1273 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1274 "Reference concatenation: input type not supported");
1275
1276 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1277 "Reference concatenation: input type not supported");
1278
1279 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001280}
1281
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001282bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1283 const TensorInfo& output,
1284 const SpaceToBatchNdDescriptor& descriptor,
1285 Optional<std::string&> reasonIfUnsupported) const
1286{
1287 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001288 bool supported = true;
1289 std::array<DataType,3> supportedTypes =
1290 {
1291 DataType::Float32,
1292 DataType::QuantisedAsymm8,
1293 DataType::QuantisedSymm16
1294 };
1295
1296 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1297 "Reference SpaceToBatchNd: input type not supported");
1298
1299 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1300 "Reference SpaceToBatchNd: output type not supported");
1301
1302 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1303 "Reference SpaceToBatchNd: input and output types are mismatched");
1304
1305 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001306}
1307
Keith Davisa57eccb2019-06-14 17:33:22 +01001308bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1309 const TensorInfo& output,
1310 const SpaceToDepthDescriptor& descriptor,
1311 Optional<std::string&> reasonIfUnsupported) const
1312{
1313
1314 ignore_unused(descriptor);
1315 bool supported = true;
1316
1317 std::array<DataType,2> supportedTypes =
1318 {
1319 DataType::Float32,
1320 DataType::QuantisedAsymm8,
1321 };
1322
1323 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1324 "Reference SpaceToDepth: input type not supported");
1325
1326 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1327 "Reference SpaceToDepth: output type not supported");
1328
1329 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1330 "Reference SpaceToDepth: input and output types are mismatched");
1331
1332 return supported;
1333}
1334
// Checks whether the reference backend supports a Splitter layer
// (single-output legacy overload). Only the input data type is validated;
// the view configuration in the descriptor is not inspected.
bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // TrueFunc for both the float and uint8 branches of the legacy dispatcher.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1345
// Checks whether the reference backend supports a Splitter layer
// (multi-output overload). Only the input data type is validated; the
// output tensors and the view configuration are not inspected.
bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(outputs);
    // TrueFunc for both the float and uint8 branches of the legacy dispatcher.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1358
// Checks whether the reference backend supports a StridedSlice layer.
// Only the input data type is validated; the output tensor and the
// begin/end/stride settings in the descriptor are not inspected.
bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const StridedSliceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    // TrueFunc for both the float and uint8 branches of the legacy dispatcher.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1371
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001372bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1373 const TensorInfo& input1,
1374 const TensorInfo& output,
1375 Optional<std::string&> reasonIfUnsupported) const
1376{
Sadik Armagan2999a022019-04-09 14:20:12 +01001377 bool supported = true;
1378
1379 std::array<DataType,3> supportedTypes = {
1380 DataType::Float32,
1381 DataType::QuantisedAsymm8,
1382 DataType::QuantisedSymm16
1383 };
1384
1385 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1386 "Reference subtraction: input 0 is not a supported type.");
1387
1388 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1389 "Reference subtraction: input 1 is not a supported type.");
1390
1391 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1392 "Reference subtraction: output is not a supported type.");
1393
1394 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1395 "Reference subtraction: input 0 and Input 1 types are mismatched");
1396
1397 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1398 "Reference subtraction: input and output types are mismatched");
1399
1400 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1401 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1402
1403 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001404}
1405
Matteo Martincighab9e5252019-06-13 17:27:46 +01001406bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1407 const TensorInfo& alpha,
1408 const TensorInfo& output,
1409 Optional<std::string&> reasonIfUnsupported) const
1410{
1411 bool supported = true;
1412
1413 std::array<DataType, 3> supportedTypes
1414 {
1415 DataType::Float32,
1416 DataType::QuantisedAsymm8,
1417 DataType::QuantisedSymm16
1418 };
1419
1420 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1421 "PReLU: input is not a supported type.");
1422
1423 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1424 "PReLU: alpha is not a supported type.");
1425
1426 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1427 "PReLU: output is not a supported type.");
1428
1429 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1430 "PReLU: input, alpha and output types are mismatched");
1431
1432 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1433 "PReLU: shapes are not suitable for implicit broadcast");
1434
1435 return supported;
1436}
1437
arovir011c7c81b2018-10-08 11:34:28 +01001438} // namespace armnn