blob: 402bd66f02664e097045d535f61515f5fcaf8f32 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Legacy data-type dispatcher for the reference backend: runs floatFuncPtr
// for Float32 inputs and uint8FuncPtr for QuantisedAsymm8 inputs, and
// rejects every other data type by passing &FalseFunc for the remaining
// slots of IsSupportedForDataTypeGeneric (judging by the Convert* functions
// below, those slots are Float16, Signed32 and Boolean — confirm against
// LayerSupportCommon.hpp).
// On failure, a reason may be appended to reasonIfUnsupported.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
namespace
{

/// Builds the error message reported when a tensor has the wrong number of
/// dimensions for a reference layer.
/// @param expected   Number of dimensions the layer requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name (e.g. "Mean").
/// @param tensorName Name of the offending tensor (e.g. "input").
/// @return The formatted error string.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // Parameters are read-only, so take them by const reference instead of
    // mutable lvalue reference (also allows temporaries at the call site).
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
namespace
{
// Evaluates a support rule; when it fails and 'reason' is non-null, appends
// the reason (newline-terminated) to reasonIfUnsupported. Returns whether
// the rule passed.
// NOTE(review): assumes reasonIfUnsupported holds a value whenever 'reason'
// is non-null — .value() on an empty Optional would fail; confirm callers.
template<typename F>
bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
{
    bool supported = rule();
    if (!supported && reason)
    {
        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
    }
    return supported;
}

// Base class for support rules: a callable whose result is computed in the
// derived class's constructor. Defaults to "supported".
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    bool m_Res = true;
};

// Recursion terminator for AllTypesAreEqualImpl: one element is trivially
// consistent.
template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

// True when every TensorInfo in the pack reports the same DataType.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    // Compare adjacent pairs; equality chains across the whole pack.
    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}

// Rule: all given tensors share a single DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};

// Rule: both tensors have identical quantization scale and offset.
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};

// Rule: the tensor's DataType is one of the types listed in container 'c'.
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
            {
                return dt == info.GetDataType();
            });
    }
};

// Rule: the bias type is exactly the one implied by the weights type
// (via GetBiasTypeFromWeightsType).
struct BiasAndWeightsTypesMatch : public Rule
{
    BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
    {
        m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
    }
};

// Rule: the bias type implied by 'info' (the weights tensor) appears in
// container 'c' of acceptable bias types.
struct BiasAndWeightsTypesCompatible : public Rule
{
    template<typename Container>
    BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
            {
                return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
            });
    }
};

// Rule: both tensors have the same number of dimensions.
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};

// Rule: both tensors hold the same total number of elements (shape may
// differ, e.g. for reshape-like layers).
struct ShapesAreSameTotalSize : public Rule
{
    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetNumElements() == info1.GetNumElements();
    }
};

// Rule: the two input shapes can be implicitly broadcast to the output
// shape — each input dimension, right-aligned against the output, must
// either equal the output dimension or be 1.
struct ShapesAreBroadcastCompatible : public Rule
{
    // Size of input dimension idx when 'in' is right-aligned against 'out';
    // missing leading dimensions count as 1.
    // NOTE(review): assumes in.GetNumDimensions() <= out.GetNumDimensions(),
    // otherwise the unsigned subtraction wraps — confirm callers guarantee it.
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Loop stops early once m_Res goes false.
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};

// Rule: the tensor has exactly the expected number of dimensions.
struct TensorNumDimensionsAreCorrect : public Rule
{
    TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
    {
        m_Res = info.GetNumDimensions() == expectedNumDimensions;
    }
};

} // namespace
205
206
// Checks whether the reference backend supports an Activation layer with the
// given tensors and descriptor. On failure, reasons are appended to
// reasonIfUnsupported (one per failed rule).
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
   bool supported = true;

   // Define supported types.
   std::array<DataType,3> supportedTypes = {
       DataType::Float32,
       DataType::QuantisedAsymm8,
       DataType::QuantisedSymm16
   };

   // All rules are evaluated (&=, not short-circuit) so every failure
   // reason is collected.
   supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                 "Reference activation: input type not supported.");

   supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                 "Reference activation: output type not supported.");

   supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                 "Reference activation: input and output types mismatched.");

   supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                 "Reference activation: input and output shapes are of different rank.");


   // Local rule: the requested activation function must be one the reference
   // implementation provides.
   struct ActivationFunctionSupported : public Rule
   {
       ActivationFunctionSupported(const ActivationDescriptor& desc)
       {
           switch(desc.m_Function)
           {
               case ActivationFunction::Abs:
               case ActivationFunction::BoundedReLu:
               case ActivationFunction::LeakyReLu:
               case ActivationFunction::Linear:
               case ActivationFunction::ReLu:
               case ActivationFunction::Sigmoid:
               case ActivationFunction::SoftReLu:
               case ActivationFunction::Sqrt:
               case ActivationFunction::Square:
               case ActivationFunction::TanH:
               {
                   m_Res = true;
                   break;
               }
               default:
               {
                   m_Res = false;
                   break;
               }
           }
       }
   };

   // Function is supported
   supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                 "Reference activation: function not supported.");

   return supported;
}
269
270bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
271 const TensorInfo& input1,
272 const TensorInfo& output,
273 Optional<std::string&> reasonIfUnsupported) const
274{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000275 bool supported = true;
276
Sadik Armagan2999a022019-04-09 14:20:12 +0100277 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000278 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100279 DataType::QuantisedAsymm8,
280 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000281 };
282
283 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
284 "Reference addition: input 0 is not a supported type.");
285
286 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
287 "Reference addition: input 1 is not a supported type.");
288
289 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
290 "Reference addition: output is not a supported type.");
291
292 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
293 "Reference addition: input 0 and Input 1 types are mismatched");
294
295 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
296 "Reference addition: input and output types are mismatched");
297
298 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
299 "Reference addition: shapes are not suitable for implicit broadcast.");
300
301 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100302}
303
// Checks support for a BatchNormalization layer: input/output plus the four
// statistic tensors (mean, variance, beta, gamma) must each have a supported
// data type; input and output types must match. Descriptor contents (eps,
// data layout) are not validated here.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
347
// Checks support for a BatchToSpaceNd layer via the legacy type dispatcher:
// both input and output must be Float32 or QuantisedAsymm8 (the two TrueFunc
// slots). Descriptor contents are not validated.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return (IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      input.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>) &&
            IsSupportedForDataTypeRef(reasonIfUnsupported,
                                      output.GetDataType(),
                                      &TrueFunc<>,
                                      &TrueFunc<>));
}
363
// Checks support for a Concat layer: the output and every input must have a
// supported data type, and each input's type must match the output's.
// Concatenation axis / view origins in the descriptor are not validated.
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    // Validate every input; a reason is appended per failing input.
    for (const TensorInfo* input : inputs)
    {
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}
392
arovir011c7c81b2018-10-08 11:34:28 +0100393bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
394 Optional<std::string&> reasonIfUnsupported) const
395{
Jim Flynne242f2d2019-05-22 14:24:13 +0100396 std::array<DataType,4> supportedTypes =
397 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100398 DataType::Float32,
399 DataType::Signed32,
400 DataType::QuantisedAsymm8,
401 DataType::QuantisedSymm16
402 };
403
404 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
405 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100406}
407
// Checks support for a ConvertFp16ToFp32 layer: the input must be Float16
// and the output Float32; every other combination is rejected with a
// type-specific message (the FalseInputFuncF32 / FalseOutputFuncF16 helpers
// produce the targeted error text).
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Generic dispatcher slot order: Float16, Float32, QuantisedAsymm8,
    // Signed32, Boolean.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
427
// Checks support for a ConvertFp32ToFp16 layer: mirror image of the
// Fp16->Fp32 check — input must be Float32 and output Float16.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Generic dispatcher slot order: Float16, Float32, QuantisedAsymm8,
    // Signed32, Boolean.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
447
448bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
449 const TensorInfo& output,
450 const Convolution2dDescriptor& descriptor,
451 const TensorInfo& weights,
452 const Optional<TensorInfo>& biases,
453 Optional<std::string&> reasonIfUnsupported) const
454{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100455 bool supported = true;
456
457 // Define supported types.
458 std::array<DataType,3> supportedTypes = {
459 DataType::Float32,
460 DataType::QuantisedAsymm8,
461 DataType::QuantisedSymm16
462 };
463
464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100465 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100466
467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100468 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100469
470 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100471 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100472
473 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100474 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100475
476 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100477 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100478
479 if (biases.has_value())
480 {
481 std::array<DataType,3> biasesSupportedTypes = {
482 DataType::Float32,
483 DataType::Signed32
484 };
485 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100486 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100487 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100488 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100489
490 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100491}
492
// Checks support for a Debug layer: input must be Float32 or QuantisedAsymm8
// (legacy dispatcher). The output tensor is not validated.
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
503
// Checks support for a DepthwiseConvolution2d layer. Only the input's data
// type is checked (Float32 or QuantisedAsymm8 via the legacy dispatcher);
// output, descriptor, weights and biases are deliberately ignored.
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    ignore_unused(descriptor);
    ignore_unused(weights);
    ignore_unused(biases);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
520
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000521bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
522 const TensorInfo& output,
523 Optional<std::string&> reasonIfUnsupported) const
524{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100525 bool supported = true;
526
527 std::array<DataType,2> supportedInputTypes = {
528 DataType::QuantisedAsymm8,
529 DataType::QuantisedSymm16
530 };
531
532 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
533 "Reference dequantize: input type not supported.");
534
535 std::array<DataType,2> supportedOutputTypes = {
536 DataType::Float32,
537 };
538
539 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
540 "Reference dequantize: output type not supported.");
541
542 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
543 "Reference dequantize: input and output shapes have different num total elements.");
544
545 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000546}
547
// Checks support for a DetectionPostProcess layer: both inputs (box encodings
// and scores) must have a supported data type. The descriptor and the layer's
// outputs are not validated here.
bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::vector<DataType> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}
570
// Dilated depthwise convolution is checked exactly like the non-dilated
// variant; the dilation parameters live in the same descriptor.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
580
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100581bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100582 const TensorInfo& input1,
583 const TensorInfo& output,
584 Optional<std::string&> reasonIfUnsupported) const
585{
Sadik Armagan2999a022019-04-09 14:20:12 +0100586 bool supported = true;
587
588 std::array<DataType,3> supportedTypes = {
589 DataType::Float32,
590 DataType::QuantisedAsymm8,
591 DataType::QuantisedSymm16
592 };
593
594 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
595 "Reference division: input 0 is not a supported type.");
596
597 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
598 "Reference division: input 1 is not a supported type.");
599
600 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
601 "Reference division: output is not a supported type.");
602
603 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
604 "Reference division: input 0 and Input 1 types are mismatched");
605
606 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
607 "Reference division: input and output types are mismatched");
608
609 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
610 "Reference division: shapes are not suitable for implicit broadcast.");
611
612 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100613}
614
// Checks support for an Equal (element-wise comparison) layer via the legacy
// dispatcher: input0 must be Float32 or QuantisedAsymm8. input1 and output
// are not validated.
// NOTE(review): the ignore_unused calls on input0 and reasonIfUnsupported
// are redundant — both are used immediately below.
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input0);
    ignore_unused(input1);
    ignore_unused(output);
    ignore_unused(reasonIfUnsupported);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
629
// Checks support for a FakeQuantization layer: Float32 input only
// (&FalseFuncU8 rejects QuantisedAsymm8). Descriptor contents are ignored.
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
640
// Checks support for a Floor layer: both input and output must be Float32 or
// QuantisedSymm16 (QuantisedAsymm8 is intentionally absent from the list).
// NOTE(review): ignore_unused(output) is redundant — output is checked below.
bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,2> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}
662
// Checks support for a FullyConnected layer: input, output and weights must
// share a supported data type. When the descriptor enables biases, the bias
// tensor must be Float32 or Signed32, must match the type implied by the
// weights type, and that implied type must itself be acceptable.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 2>
        supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}
718
// Checks support for a Gather layer via the legacy dispatcher: only the data
// tensor (input0) is validated — Float32 or QuantisedAsymm8. The indices
// tensor (input1) and the output are not checked.
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input1);
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
731
// Checks support for a Greater (element-wise comparison) layer via the legacy
// dispatcher: input0 must be Float32 or QuantisedAsymm8. input1 and output
// are not validated.
// NOTE(review): the ignore_unused calls on input0 and reasonIfUnsupported
// are redundant — both are used immediately below.
bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(input0);
    ignore_unused(input1);
    ignore_unused(output);
    ignore_unused(reasonIfUnsupported);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input0.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
746
// Input layers are unconditionally supported by the reference backend; no
// tensor validation is performed (both parameters are intentionally unused).
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}
752
// Checks support for an L2Normalization layer: input and output must share a
// supported data type and describe the same total number of elements.
// Descriptor contents (epsilon, data layout) are not validated.
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}
784
785bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
786 const TensorInfo& outputStateIn,
787 const TensorInfo& cellStateIn,
788 const TensorInfo& scratchBuffer,
789 const TensorInfo& outputStateOut,
790 const TensorInfo& cellStateOut,
791 const TensorInfo& output,
792 const LstmDescriptor& descriptor,
793 const TensorInfo& inputToForgetWeights,
794 const TensorInfo& inputToCellWeights,
795 const TensorInfo& inputToOutputWeights,
796 const TensorInfo& recurrentToForgetWeights,
797 const TensorInfo& recurrentToCellWeights,
798 const TensorInfo& recurrentToOutputWeights,
799 const TensorInfo& forgetGateBias,
800 const TensorInfo& cellBias,
801 const TensorInfo& outputGateBias,
802 const TensorInfo* inputToInputWeights,
803 const TensorInfo* recurrentToInputWeights,
804 const TensorInfo* cellToInputWeights,
805 const TensorInfo* inputGateBias,
806 const TensorInfo* projectionWeights,
807 const TensorInfo* projectionBias,
808 const TensorInfo* cellToForgetWeights,
809 const TensorInfo* cellToOutputWeights,
810 Optional<std::string&> reasonIfUnsupported) const
811{
telsoa01c577f2c2018-08-31 09:22:23 +0100812 ignore_unused(descriptor);
813 ignore_unused(inputToForgetWeights);
814 ignore_unused(inputToCellWeights);
815 ignore_unused(inputToOutputWeights);
816 ignore_unused(recurrentToForgetWeights);
817 ignore_unused(recurrentToCellWeights);
818 ignore_unused(recurrentToOutputWeights);
819 ignore_unused(forgetGateBias);
820 ignore_unused(cellBias);
821 ignore_unused(outputGateBias);
822 ignore_unused(inputToInputWeights);
823 ignore_unused(recurrentToInputWeights);
824 ignore_unused(cellToInputWeights);
825 ignore_unused(inputGateBias);
826 ignore_unused(projectionWeights);
827 ignore_unused(projectionBias);
828 ignore_unused(cellToForgetWeights);
829 ignore_unused(cellToOutputWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100830
831 bool supported = true;
832
833 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100834 DataType::Float32,
835 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100836 };
837
838 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
839 "Reference Lstm: input is not a supported type.");
840
841 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
842 "Reference Lstm: input and outputStateIn types are mismatched");
843
844 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
845 "Reference Lstm: input and cellStateIn types are mismatched");
846
847 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
848 "Reference Lstm: input and scratchBuffer types are mismatched");
849
850 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
851 "Reference Lstm: input and outputStateOut types are mismatched");
852
853 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
854 "Reference Lstm: input and cellStateOut types are mismatched");
855
856 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
857 "Reference Lstm: input and output types are mismatched");
858
859 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100860}
861
saoste012df12b32018-11-28 16:57:20 +0000862bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
863 const TensorInfo& input1,
864 const TensorInfo& output,
865 Optional<std::string&> reasonIfUnsupported) const
866{
Sadik Armagan2999a022019-04-09 14:20:12 +0100867 bool supported = true;
868
869 std::array<DataType,3> supportedTypes = {
870 DataType::Float32,
871 DataType::QuantisedAsymm8,
872 DataType::QuantisedSymm16
873 };
874
875 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
876 "Reference maximum: input 0 is not a supported type.");
877
878 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
879 "Reference maximum: input 1 is not a supported type.");
880
881 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
882 "Reference maximum: output is not a supported type.");
883
884 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
885 "Reference maximum: input 0 and Input 1 types are mismatched");
886
887 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
888 "Reference maximum: input and output types are mismatched");
889
890 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
891 "Reference maximum: shapes are not suitable for implicit broadcast.");
892
893 return supported;
saoste012df12b32018-11-28 16:57:20 +0000894}
895
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100896bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
897 const TensorInfo& output,
898 const MeanDescriptor& descriptor,
899 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100900{
James Conroy4d1ff582019-06-10 17:06:39 +0100901 bool supported = true;
902 std::string meanLayerStr = "Mean";
903 std::string outputTensorStr = "output";
904
905 std::array<DataType,2> supportedTypes =
906 {
907 DataType::Float32,
908 DataType::QuantisedAsymm8
909 };
910
911 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
912 "Reference Mean: input type not supported.");
913
914 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
915 "Reference Mean: input and output types are mismatched");
916
917 if (descriptor.m_KeepDims)
918 {
919 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
920 reasonIfUnsupported,
921 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
922 output.GetNumDimensions(),
923 meanLayerStr, outputTensorStr).data());
924 }
925 else if (descriptor.m_Axis.empty())
926 {
927 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
928 reasonIfUnsupported,
929 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
930 meanLayerStr, outputTensorStr).data());
931 }
932 else
933 {
934 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
935
936 if (outputDim > 0)
937 {
938 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
939 reasonIfUnsupported,
940 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
941 meanLayerStr, outputTensorStr).data());
942 }
943 else
944 {
945 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
946 reasonIfUnsupported,
947 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
948 meanLayerStr, outputTensorStr).data());
949 }
950 }
951
952 return supported;
narpra0132b90462018-09-13 11:07:48 +0100953}
954
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100955bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000956 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100957 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100958 Optional<std::string&> reasonIfUnsupported) const
959{
Jim Flynne242f2d2019-05-22 14:24:13 +0100960 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100961}
962
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000963bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
964 const TensorInfo &output,
965 Optional<std::string &> reasonIfUnsupported) const
966{
967 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000968 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
969 input.GetDataType(),
970 &TrueFunc<>,
971 &TrueFunc<>,
972 &TrueFunc<>,
973 &FalseFuncI32<>,
974 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000975}
976
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000977bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
978 const TensorInfo& input1,
979 const TensorInfo& output,
980 Optional<std::string&> reasonIfUnsupported) const
981{
Sadik Armagan2999a022019-04-09 14:20:12 +0100982 bool supported = true;
983
984 std::array<DataType,3> supportedTypes = {
985 DataType::Float32,
986 DataType::QuantisedAsymm8,
987 DataType::QuantisedSymm16
988 };
989
990 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
991 "Reference minimum: input 0 is not a supported type.");
992
993 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
994 "Reference minimum: input 1 is not a supported type.");
995
996 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
997 "Reference minimum: output is not a supported type.");
998
999 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1000 "Reference minimum: input 0 and Input 1 types are mismatched");
1001
1002 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1003 "Reference minimum: input and output types are mismatched");
1004
1005 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1006 "Reference minimum: shapes are not suitable for implicit broadcast.");
1007
1008 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001009}
1010
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001011bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1012 const TensorInfo& input1,
1013 const TensorInfo& output,
1014 Optional<std::string&> reasonIfUnsupported) const
1015{
Sadik Armagan2999a022019-04-09 14:20:12 +01001016 bool supported = true;
1017
1018 std::array<DataType,3> supportedTypes = {
1019 DataType::Float32,
1020 DataType::QuantisedAsymm8,
1021 DataType::QuantisedSymm16
1022 };
1023
1024 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1025 "Reference multiplication: input 0 is not a supported type.");
1026
1027 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1028 "Reference multiplication: input 1 is not a supported type.");
1029
1030 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1031 "Reference multiplication: output is not a supported type.");
1032
1033 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1034 "Reference multiplication: input 0 and Input 1 types are mismatched");
1035
1036 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1037 "Reference multiplication: input and output types are mismatched");
1038
1039 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1040 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1041
1042 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001043}
1044
1045bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1046 const TensorInfo& output,
1047 const NormalizationDescriptor& descriptor,
1048 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001049{
Nina Drozd661dfa72018-10-02 11:14:17 +01001050 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001051
1052 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001053 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001054 {
1055 DataType::Float16,
1056 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001057 DataType::QuantisedAsymm8,
1058 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001059 };
1060
1061 bool supported = true;
1062
1063 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1064 "Reference normalization: input type not supported.");
1065
1066 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1067 "Reference normalization: output type not supported.");
1068
1069 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1070 "Reference normalization: input and output shapes have different "
1071 "num total elements.");
1072
1073 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001074}
1075
1076bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1077 Optional<std::string&> reasonIfUnsupported) const
1078{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001079 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001080}
1081
bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    // Only the input data type is checked; output and descriptor are ignored.
    ignore_unused(output);
    ignore_unused(descriptor);
    // The two TrueFunc slots accept the helper's Float32 and Uint8 positions;
    // all other data types are rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1094
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001095bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1096 const TensorInfo& output,
1097 const PermuteDescriptor& descriptor,
1098 Optional<std::string&> reasonIfUnsupported) const
1099{
1100 ignore_unused(output);
1101 ignore_unused(descriptor);
1102 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1103 input.GetDataType(),
1104 &TrueFunc<>,
1105 &TrueFunc<>);
1106}
1107
1108bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1109 const TensorInfo& output,
1110 const Pooling2dDescriptor& descriptor,
1111 Optional<std::string&> reasonIfUnsupported) const
1112{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001113 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001114 bool supported = true;
1115
1116 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001117 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001118 {
1119 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001120 DataType::QuantisedAsymm8,
1121 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001122 };
1123
1124 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1125 "Reference poolind2d: input is not a supported type.");
1126
1127 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1128 "Reference poolind2d: output is not a supported type.");
1129
1130 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1131 "Reference poolind2d: input and output types are mismatched.");
1132
1133 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001134}
1135
Derek Lamberti5f400d62019-03-25 15:41:58 +00001136bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1137 const TensorInfo& output,
1138 Optional<std::string&> reasonIfUnsupported) const
1139{
1140 bool supported = true;
1141
1142 // Define supported output types.
1143 std::array<DataType,2> supportedInputTypes = {
1144 DataType::Float32,
1145 };
1146
1147 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1148 "Reference quantize: input type not supported.");
1149
1150 // Define supported output types.
1151 std::array<DataType,2> supportedOutputTypes = {
1152 DataType::QuantisedAsymm8,
1153 DataType::QuantisedSymm16
1154 };
1155 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1156 "Reference quantize: output type not supported.");
1157
1158 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1159 "Reference quantize: input and output shapes have different num total elements.");
1160
1161 return supported;
1162}
1163
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001164bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001165 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001166 Optional<std::string&> reasonIfUnsupported) const
1167{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001168 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001169 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001170 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001171 {
1172 DataType::Float32,
1173 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001174 DataType::QuantisedAsymm8,
1175 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001176 };
1177 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1178 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001179}
1180
bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    // Only the input data type is checked; the output tensor is ignored.
    ignore_unused(output);
    // The two TrueFunc slots accept the helper's Float32 and Uint8 positions;
    // all other data types are rejected.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1191
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001192bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1193 const TensorInfo& output,
1194 Optional<std::string&> reasonIfUnsupported) const
1195{
nikraj010421e7f2019-06-14 09:40:34 +01001196 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001197 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001198 {
1199 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001200 DataType::QuantisedAsymm8,
1201 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001202 };
1203
1204 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1205 "Reference rsqrt: input type not supported");
1206
1207 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1208 "Reference rsqrt: output type not supported");
1209
1210 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1211 "Reference rsqrt: input and output types not matching");
1212
1213 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1214 "Reference Rsqrt: input and output shapes have different number of total elements");
1215
1216 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001217}
1218
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001219bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1220 const TensorInfo& output,
1221 const SoftmaxDescriptor& descriptor,
1222 Optional<std::string&> reasonIfUnsupported) const
1223{
1224 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001225 bool supported = true;
1226 std::array<DataType,3> supportedTypes =
1227 {
1228 DataType::Float32,
1229 DataType::QuantisedAsymm8,
1230 DataType::QuantisedSymm16
1231 };
1232
1233 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1234 "Reference concatenation: output type not supported");
1235
1236 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1237 "Reference concatenation: input type not supported");
1238
1239 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1240 "Reference concatenation: input type not supported");
1241
1242 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001243}
1244
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001245bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1246 const TensorInfo& output,
1247 const SpaceToBatchNdDescriptor& descriptor,
1248 Optional<std::string&> reasonIfUnsupported) const
1249{
1250 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001251 bool supported = true;
1252 std::array<DataType,3> supportedTypes =
1253 {
1254 DataType::Float32,
1255 DataType::QuantisedAsymm8,
1256 DataType::QuantisedSymm16
1257 };
1258
1259 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1260 "Reference SpaceToBatchNd: input type not supported");
1261
1262 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1263 "Reference SpaceToBatchNd: output type not supported");
1264
1265 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1266 "Reference SpaceToBatchNd: input and output types are mismatched");
1267
1268 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001269}
1270
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001271bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1272 const ViewsDescriptor& descriptor,
1273 Optional<std::string&> reasonIfUnsupported) const
1274{
1275 ignore_unused(descriptor);
1276 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1277 input.GetDataType(),
1278 &TrueFunc<>,
1279 &TrueFunc<>);
1280}
1281
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001282bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1283 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1284 const ViewsDescriptor& descriptor,
1285 Optional<std::string&> reasonIfUnsupported) const
1286{
1287 ignore_unused(descriptor);
1288 ignore_unused(outputs);
1289 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1290 input.GetDataType(),
1291 &TrueFunc<>,
1292 &TrueFunc<>);
1293}
1294
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001295bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1296 const TensorInfo& output,
1297 const StridedSliceDescriptor& descriptor,
1298 Optional<std::string&> reasonIfUnsupported) const
1299{
1300 ignore_unused(output);
1301 ignore_unused(descriptor);
1302 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1303 input.GetDataType(),
1304 &TrueFunc<>,
1305 &TrueFunc<>);
1306}
1307
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001308bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1309 const TensorInfo& input1,
1310 const TensorInfo& output,
1311 Optional<std::string&> reasonIfUnsupported) const
1312{
Sadik Armagan2999a022019-04-09 14:20:12 +01001313 bool supported = true;
1314
1315 std::array<DataType,3> supportedTypes = {
1316 DataType::Float32,
1317 DataType::QuantisedAsymm8,
1318 DataType::QuantisedSymm16
1319 };
1320
1321 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1322 "Reference subtraction: input 0 is not a supported type.");
1323
1324 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1325 "Reference subtraction: input 1 is not a supported type.");
1326
1327 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1328 "Reference subtraction: output is not a supported type.");
1329
1330 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1331 "Reference subtraction: input 0 and Input 1 types are mismatched");
1332
1333 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1334 "Reference subtraction: input and output types are mismatched");
1335
1336 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1337 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1338
1339 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001340}
1341
arovir011c7c81b2018-10-08 11:34:28 +01001342} // namespace armnn