blob: aeff51d853a964e5950c26397dc2958cc8312483 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Convenience wrapper over IsSupportedForDataTypeGeneric for the reference
// backend: the caller supplies only the Float32 and QuantisedAsymm8 handlers,
// while the remaining dispatch slots are hard-wired to FalseFunc (rejected).
// The slot order mirrors LayerSupportCommon.hpp - presumably
// Float16, Float32, Uint8, Signed32, Boolean; confirm against that header.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,  // Float16: never supported via this helper
                                         floatFuncPtr,           // Float32
                                         uint8FuncPtr,           // QuantisedAsymm8
                                         &FalseFunc<Params...>,  // Signed32: never supported via this helper
                                         &FalseFunc<Params...>,  // Boolean: never supported via this helper
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
Derek Lamberti50db4e82019-03-13 14:16:15 +000050
namespace
{
// Evaluates a support rule; when the rule fails and 'reason' is non-null,
// appends the reason text to the caller-provided string.
// NOTE(review): reasonIfUnsupported.value() is dereferenced without checking
// has_value() - presumably every caller that can fail passes a valid
// reference; confirm, since an empty Optional would throw here.
template<typename F>
bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
{
    bool supported = rule();
    if (!supported && reason)
    {
        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
    }
    return supported;
}

// Base class for all support rules: a callable verdict, true by default.
// Derived constructors compute the verdict into m_Res.
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    bool m_Res = true;
};

// Recursion terminator: a single tensor is trivially "all types equal".
template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

// Checks pairwise that every TensorInfo in the pack has the same DataType.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    // Only the first two arguments are constrained here; the tail is checked
    // on the recursive call.
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}

// Rule: all given tensors share one DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};

// Rule: two tensors have identical quantization scale and offset.
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};

// Rule: the tensor's DataType appears in the given container of DataTypes.
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
        {
            return dt == info.GetDataType();
        });
    }
};

// Rule: the bias tensor's type is exactly the one implied by the weights
// type (via GetBiasTypeFromWeightsType).
struct BiasAndWeightsTypesMatch : public Rule
{
    BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
    {
        m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
    }
};

// Rule: the bias type inferred from 'info' (the weights tensor) is one of
// the types listed in the container.
struct BiasAndWeightsTypesCompatible : public Rule
{
    template<typename Container>
    BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
        {
            return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
        });
    }
};

// Rule: both shapes have the same number of dimensions.
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};

// Rule: both tensors contain the same total number of elements (shape may
// differ, e.g. for reshapes/dequantize).
struct ShapesAreSameTotalSize : public Rule
{
    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetNumElements() == info1.GetNumElements();
    }
};

// Rule: in0 and in1 can be implicitly broadcast to the output shape. Each
// input dimension, right-aligned against the output, must either equal the
// output dimension or be 1.
struct ShapesAreBroadcastCompatible : public Rule
{
    // Size of 'in' along output dimension 'idx', treating missing leading
    // dimensions as 1 (right-aligned broadcasting).
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Loop exits early once a dimension fails.
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};
} // namespace
181
182
arovir011c7c81b2018-10-08 11:34:28 +0100183bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
184 const TensorInfo& output,
185 const ActivationDescriptor& descriptor,
186 Optional<std::string&> reasonIfUnsupported) const
187{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000188 bool supported = true;
189
190 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100191 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000192 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100193 DataType::QuantisedAsymm8,
194 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000195 };
196
197 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
198 "Reference activation: input type not supported.");
199
200 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
201 "Reference activation: output type not supported.");
202
203 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
204 "Reference activation: input and output types mismatched.");
205
206 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
207 "Reference activation: input and output shapes are of different rank.");
208
209
210 struct ActivationFunctionSupported : public Rule
211 {
212 ActivationFunctionSupported(const ActivationDescriptor& desc)
213 {
214 switch(desc.m_Function)
215 {
216 case ActivationFunction::Abs:
217 case ActivationFunction::BoundedReLu:
218 case ActivationFunction::LeakyReLu:
219 case ActivationFunction::Linear:
220 case ActivationFunction::ReLu:
221 case ActivationFunction::Sigmoid:
222 case ActivationFunction::SoftReLu:
223 case ActivationFunction::Sqrt:
224 case ActivationFunction::Square:
225 case ActivationFunction::TanH:
226 {
227 m_Res = true;
228 break;
229 }
230 default:
231 {
232 m_Res = false;
233 break;
234 }
235 }
236 }
237 };
238
239 // Function is supported
240 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
241 "Reference activation: function not supported.");
242
243 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100244}
245
246bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
247 const TensorInfo& input1,
248 const TensorInfo& output,
249 Optional<std::string&> reasonIfUnsupported) const
250{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000251 bool supported = true;
252
Sadik Armagan2999a022019-04-09 14:20:12 +0100253 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000254 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100255 DataType::QuantisedAsymm8,
256 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000257 };
258
259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
260 "Reference addition: input 0 is not a supported type.");
261
262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
263 "Reference addition: input 1 is not a supported type.");
264
265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266 "Reference addition: output is not a supported type.");
267
268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
269 "Reference addition: input 0 and Input 1 types are mismatched");
270
271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
272 "Reference addition: input and output types are mismatched");
273
274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
275 "Reference addition: shapes are not suitable for implicit broadcast.");
276
277 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
280bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100283 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100284 const TensorInfo& beta,
285 const TensorInfo& gamma,
286 const BatchNormalizationDescriptor& descriptor,
287 Optional<std::string&> reasonIfUnsupported) const
288{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100289 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100290
291 std::array<DataType, 2> supportedTypes =
292 {
293 DataType::Float32,
294 DataType::QuantisedAsymm8
295 };
296
297 bool supported = true;
298
299 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
300 "Reference batch normalization: input is not a supported type.");
301
302 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
303 "Reference batch normalization: output is not a supported type.");
304
305 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
306 "Reference batch normalization: input and output types are mismatched");
307
308 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
309 "Reference batch normalization: mean is not a supported type.");
310
311 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
312 "Reference batch normalization: variance is not a supported type.");
313
314 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
315 "Reference batch normalization: beta is not a supported type.");
316
317 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
318 "Reference batch normalization: gamma is not a supported type.");
319
320 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100321}
322
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000323bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
324 const TensorInfo& output,
325 const BatchToSpaceNdDescriptor& descriptor,
326 Optional<std::string&> reasonIfUnsupported) const
327{
328 ignore_unused(descriptor);
329 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
330 input.GetDataType(),
331 &TrueFunc<>,
332 &TrueFunc<>) &&
333 IsSupportedForDataTypeRef(reasonIfUnsupported,
334 output.GetDataType(),
335 &TrueFunc<>,
336 &TrueFunc<>));
337}
338
Jim Flynn906f9462019-05-10 13:55:21 +0100339bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
340 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100341 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100342 Optional<std::string&> reasonIfUnsupported) const
343{
Jim Flynne242f2d2019-05-22 14:24:13 +0100344 ignore_unused(descriptor);
345
346 bool supported = true;
347 std::array<DataType,3> supportedTypes =
348 {
349 DataType::Float32,
350 DataType::QuantisedAsymm8,
351 DataType::QuantisedSymm16
352 };
353
354 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
355 "Reference concatenation: output type not supported");
356 for (const TensorInfo* input : inputs)
357 {
358 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
359 "Reference concatenation: input type not supported");
360
361 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
362 "Reference concatenation: input and output types mismatched.");
363 }
364
365 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100366}
367
arovir011c7c81b2018-10-08 11:34:28 +0100368bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
369 Optional<std::string&> reasonIfUnsupported) const
370{
Jim Flynne242f2d2019-05-22 14:24:13 +0100371 std::array<DataType,4> supportedTypes =
372 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100373 DataType::Float32,
374 DataType::Signed32,
375 DataType::QuantisedAsymm8,
376 DataType::QuantisedSymm16
377 };
378
379 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
380 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100381}
382
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // The conversion is only valid Float16 -> Float32. The function-pointer
    // slots follow the generic dispatcher's order - presumably Float16,
    // Float32, Uint8, Signed32, Boolean; confirm against
    // LayerSupportCommon.hpp before reordering anything here.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,            // input Float16: supported
                                          &FalseInputFuncF32<>,   // input Float32: rejected with message
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,  // output Float16: rejected with message
                                          &TrueFunc<>,            // output Float32: supported
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
402
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror image of IsConvertFp16ToFp32Supported: only Float32 -> Float16
    // is valid. Slot order follows the generic dispatcher - presumably
    // Float16, Float32, Uint8, Signed32, Boolean; confirm against
    // LayerSupportCommon.hpp.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,   // input Float16: rejected with message
                                          &TrueFunc<>,            // input Float32: supported
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,            // output Float16: supported
                                          &FalseOutputFuncF32<>,  // output Float32: rejected with message
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
422
423bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
424 const TensorInfo& output,
425 const Convolution2dDescriptor& descriptor,
426 const TensorInfo& weights,
427 const Optional<TensorInfo>& biases,
428 Optional<std::string&> reasonIfUnsupported) const
429{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100430 bool supported = true;
431
432 // Define supported types.
433 std::array<DataType,3> supportedTypes = {
434 DataType::Float32,
435 DataType::QuantisedAsymm8,
436 DataType::QuantisedSymm16
437 };
438
439 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100440 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100441
442 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100443 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100444
445 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100446 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100447
448 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100449 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100450
451 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100452 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100453
454 if (biases.has_value())
455 {
456 std::array<DataType,3> biasesSupportedTypes = {
457 DataType::Float32,
458 DataType::Signed32
459 };
460 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100461 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100462 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100463 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100464
465 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100466}
467
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000468bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
469 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000470 Optional<std::string&> reasonIfUnsupported) const
471{
472 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000473 return IsSupportedForDataTypeRef(reasonIfUnsupported,
474 input.GetDataType(),
475 &TrueFunc<>,
476 &TrueFunc<>);
477}
478
arovir011c7c81b2018-10-08 11:34:28 +0100479bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
480 const TensorInfo& output,
481 const DepthwiseConvolution2dDescriptor& descriptor,
482 const TensorInfo& weights,
483 const Optional<TensorInfo>& biases,
484 Optional<std::string&> reasonIfUnsupported) const
485{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100486 ignore_unused(output);
487 ignore_unused(descriptor);
488 ignore_unused(weights);
489 ignore_unused(biases);
490 return IsSupportedForDataTypeRef(reasonIfUnsupported,
491 input.GetDataType(),
492 &TrueFunc<>,
493 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100494}
495
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000496bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
497 const TensorInfo& output,
498 Optional<std::string&> reasonIfUnsupported) const
499{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100500 bool supported = true;
501
502 std::array<DataType,2> supportedInputTypes = {
503 DataType::QuantisedAsymm8,
504 DataType::QuantisedSymm16
505 };
506
507 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
508 "Reference dequantize: input type not supported.");
509
510 std::array<DataType,2> supportedOutputTypes = {
511 DataType::Float32,
512 };
513
514 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
515 "Reference dequantize: output type not supported.");
516
517 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
518 "Reference dequantize: input and output shapes have different num total elements.");
519
520 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000521}
522
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000523bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
524 const armnn::TensorInfo& input1,
525 const armnn::DetectionPostProcessDescriptor& descriptor,
526 armnn::Optional<std::string&> reasonIfUnsupported) const
527{
528 ignore_unused(input1);
529 return IsSupportedForDataTypeRef(reasonIfUnsupported,
530 input0.GetDataType(),
531 &TrueFunc<>,
532 &TrueFunc<>);
533}
534
Pablo Tellof0bd6832019-04-26 17:58:13 +0100535bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
536 const TensorInfo& output,
537 const DepthwiseConvolution2dDescriptor& descriptor,
538 const TensorInfo& weights,
539 const Optional<TensorInfo>& biases,
540 Optional<std::string&> reasonIfUnsupported) const
541{
542 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
543 {
544 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
545 }
546 else
547 {
548 if (reasonIfUnsupported)
549 {
550 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
551 }
552 return false;
553 }
554}
555
556
557 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100558 const TensorInfo& input1,
559 const TensorInfo& output,
560 Optional<std::string&> reasonIfUnsupported) const
561{
Sadik Armagan2999a022019-04-09 14:20:12 +0100562 bool supported = true;
563
564 std::array<DataType,3> supportedTypes = {
565 DataType::Float32,
566 DataType::QuantisedAsymm8,
567 DataType::QuantisedSymm16
568 };
569
570 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
571 "Reference division: input 0 is not a supported type.");
572
573 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
574 "Reference division: input 1 is not a supported type.");
575
576 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
577 "Reference division: output is not a supported type.");
578
579 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
580 "Reference division: input 0 and Input 1 types are mismatched");
581
582 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
583 "Reference division: input and output types are mismatched");
584
585 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
586 "Reference division: shapes are not suitable for implicit broadcast.");
587
588 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100589}
590
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000591bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
592 const TensorInfo& input1,
593 const TensorInfo& output,
594 Optional<std::string&> reasonIfUnsupported) const
595{
596 ignore_unused(input0);
597 ignore_unused(input1);
598 ignore_unused(output);
599 ignore_unused(reasonIfUnsupported);
600 return IsSupportedForDataTypeRef(reasonIfUnsupported,
601 input0.GetDataType(),
602 &TrueFunc<>,
603 &TrueFunc<>);
604}
605
arovir011c7c81b2018-10-08 11:34:28 +0100606bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
607 const FakeQuantizationDescriptor& descriptor,
608 Optional<std::string&> reasonIfUnsupported) const
609{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100610 ignore_unused(descriptor);
611 return IsSupportedForDataTypeRef(reasonIfUnsupported,
612 input.GetDataType(),
613 &TrueFunc<>,
614 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100615}
616
617bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
618 const TensorInfo& output,
619 Optional<std::string&> reasonIfUnsupported) const
620{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100621 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100622 bool supported = true;
623
James Conroyb40d7102019-06-04 12:32:09 +0100624 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100625 {
James Conroyb40d7102019-06-04 12:32:09 +0100626 DataType::Float32,
627 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100628 };
629
630 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
631 "Reference Floor: input type not supported.");
632
633 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
634 "Reference Floor: output type not supported.");
635
636 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100637}
638
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    // Bias rules only apply when the layer actually carries a bias tensor;
    // otherwise 'biases' is not inspected at all.
    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 2>
            supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        // Bias type must equal the one inferred from the weights type
        // (GetBiasTypeFromWeightsType), e.g. Signed32 for quantized weights.
        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        // And the inferred bias type must itself be in the supported set.
        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}
694
narpra014951d842019-01-18 16:53:53 +0000695bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
696 const armnn::TensorInfo& input1,
697 const armnn::TensorInfo& output,
698 armnn::Optional<std::string&> reasonIfUnsupported) const
699{
700 ignore_unused(input1);
701 ignore_unused(output);
702 return IsSupportedForDataTypeRef(reasonIfUnsupported,
703 input0.GetDataType(),
704 &TrueFunc<>,
705 &TrueFunc<>);
706}
707
FrancisMurtagh878f0232018-12-19 10:56:15 +0000708bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
709 const TensorInfo& input1,
710 const TensorInfo& output,
711 Optional<std::string&> reasonIfUnsupported) const
712{
713 ignore_unused(input0);
714 ignore_unused(input1);
715 ignore_unused(output);
716 ignore_unused(reasonIfUnsupported);
717 return IsSupportedForDataTypeRef(reasonIfUnsupported,
718 input0.GetDataType(),
719 &TrueFunc<>,
720 &TrueFunc<>);
721}
722
arovir011c7c81b2018-10-08 11:34:28 +0100723bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
724 Optional<std::string&> reasonIfUnsupported) const
725{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100726 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100727}
728
729bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
730 const TensorInfo& output,
731 const L2NormalizationDescriptor& descriptor,
732 Optional<std::string&> reasonIfUnsupported) const
733{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100734 ignore_unused(output);
735 ignore_unused(descriptor);
736 return IsSupportedForDataTypeRef(reasonIfUnsupported,
737 input.GetDataType(),
738 &TrueFunc<>,
739 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100740}
741
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // Only the data-carrying tensors (input, states, scratch, output) are
    // validated. Every weight and bias tensor is deliberately ignored.
    // NOTE(review): weight/bias data types are not checked against the
    // input type at all - confirm this is intentional and not an oversight.
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    // Data types the reference LSTM workload implements.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    // All state/scratch/output tensors must share the input's data type.
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
818
saoste012df12b32018-11-28 16:57:20 +0000819bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
820 const TensorInfo& input1,
821 const TensorInfo& output,
822 Optional<std::string&> reasonIfUnsupported) const
823{
Sadik Armagan2999a022019-04-09 14:20:12 +0100824 bool supported = true;
825
826 std::array<DataType,3> supportedTypes = {
827 DataType::Float32,
828 DataType::QuantisedAsymm8,
829 DataType::QuantisedSymm16
830 };
831
832 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
833 "Reference maximum: input 0 is not a supported type.");
834
835 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
836 "Reference maximum: input 1 is not a supported type.");
837
838 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
839 "Reference maximum: output is not a supported type.");
840
841 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
842 "Reference maximum: input 0 and Input 1 types are mismatched");
843
844 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
845 "Reference maximum: input and output types are mismatched");
846
847 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
848 "Reference maximum: shapes are not suitable for implicit broadcast.");
849
850 return supported;
saoste012df12b32018-11-28 16:57:20 +0000851}
852
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100853bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
854 const TensorInfo& output,
855 const MeanDescriptor& descriptor,
856 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100857{
narpra011e4c31d2018-09-28 11:07:51 +0100858 ignore_unused(output);
859 ignore_unused(descriptor);
860 return IsSupportedForDataTypeRef(reasonIfUnsupported,
861 input.GetDataType(),
862 &TrueFunc<>,
863 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100864}
865
// Merger uses exactly the same support rules as Concat: this overload simply
// forwards all arguments to IsConcatSupported.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
873
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000874bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
875 const TensorInfo &output,
876 Optional<std::string &> reasonIfUnsupported) const
877{
878 ignore_unused(output);
kevmay012b4d88e2019-01-24 14:05:09 +0000879 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
880 input.GetDataType(),
881 &TrueFunc<>,
882 &TrueFunc<>,
883 &TrueFunc<>,
884 &FalseFuncI32<>,
885 &TrueFunc<>);
Matteo Martincigh992d6dc2019-01-10 17:34:20 +0000886}
887
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000888bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
889 const TensorInfo& input1,
890 const TensorInfo& output,
891 Optional<std::string&> reasonIfUnsupported) const
892{
Sadik Armagan2999a022019-04-09 14:20:12 +0100893 bool supported = true;
894
895 std::array<DataType,3> supportedTypes = {
896 DataType::Float32,
897 DataType::QuantisedAsymm8,
898 DataType::QuantisedSymm16
899 };
900
901 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
902 "Reference minimum: input 0 is not a supported type.");
903
904 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
905 "Reference minimum: input 1 is not a supported type.");
906
907 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
908 "Reference minimum: output is not a supported type.");
909
910 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
911 "Reference minimum: input 0 and Input 1 types are mismatched");
912
913 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
914 "Reference minimum: input and output types are mismatched");
915
916 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
917 "Reference minimum: shapes are not suitable for implicit broadcast.");
918
919 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000920}
921
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100922bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
923 const TensorInfo& input1,
924 const TensorInfo& output,
925 Optional<std::string&> reasonIfUnsupported) const
926{
Sadik Armagan2999a022019-04-09 14:20:12 +0100927 bool supported = true;
928
929 std::array<DataType,3> supportedTypes = {
930 DataType::Float32,
931 DataType::QuantisedAsymm8,
932 DataType::QuantisedSymm16
933 };
934
935 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
936 "Reference multiplication: input 0 is not a supported type.");
937
938 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
939 "Reference multiplication: input 1 is not a supported type.");
940
941 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
942 "Reference multiplication: output is not a supported type.");
943
944 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
945 "Reference multiplication: input 0 and Input 1 types are mismatched");
946
947 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
948 "Reference multiplication: input and output types are mismatched");
949
950 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
951 "Reference multiplication: shapes are not suitable for implicit broadcast.");
952
953 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100954}
955
956bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
957 const TensorInfo& output,
958 const NormalizationDescriptor& descriptor,
959 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100960{
961 ignore_unused(output);
962 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100963 return IsSupportedForDataTypeRef(reasonIfUnsupported,
964 input.GetDataType(),
965 &TrueFunc<>,
966 &FalseFuncU8<>);
967}
968
969bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
970 Optional<std::string&> reasonIfUnsupported) const
971{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100972 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100973}
974
975bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
976 const TensorInfo& output,
977 const PadDescriptor& descriptor,
978 Optional<std::string&> reasonIfUnsupported) const
979{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100980 ignore_unused(output);
981 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000982 return IsSupportedForDataTypeRef(reasonIfUnsupported,
983 input.GetDataType(),
984 &TrueFunc<>,
985 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100986}
987
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100988bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
989 const TensorInfo& output,
990 const PermuteDescriptor& descriptor,
991 Optional<std::string&> reasonIfUnsupported) const
992{
993 ignore_unused(output);
994 ignore_unused(descriptor);
995 return IsSupportedForDataTypeRef(reasonIfUnsupported,
996 input.GetDataType(),
997 &TrueFunc<>,
998 &TrueFunc<>);
999}
1000
1001bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1002 const TensorInfo& output,
1003 const Pooling2dDescriptor& descriptor,
1004 Optional<std::string&> reasonIfUnsupported) const
1005{
1006 ignore_unused(output);
1007 ignore_unused(descriptor);
1008 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1009 input.GetDataType(),
1010 &TrueFunc<>,
1011 &TrueFunc<>);
1012}
1013
Derek Lamberti5f400d62019-03-25 15:41:58 +00001014bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1015 const TensorInfo& output,
1016 Optional<std::string&> reasonIfUnsupported) const
1017{
1018 bool supported = true;
1019
1020 // Define supported output types.
1021 std::array<DataType,2> supportedInputTypes = {
1022 DataType::Float32,
1023 };
1024
1025 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1026 "Reference quantize: input type not supported.");
1027
1028 // Define supported output types.
1029 std::array<DataType,2> supportedOutputTypes = {
1030 DataType::QuantisedAsymm8,
1031 DataType::QuantisedSymm16
1032 };
1033 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1034 "Reference quantize: output type not supported.");
1035
1036 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1037 "Reference quantize: input and output shapes have different num total elements.");
1038
1039 return supported;
1040}
1041
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001042bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001043 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001044 Optional<std::string&> reasonIfUnsupported) const
1045{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001046 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001047 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001048 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001049 {
1050 DataType::Float32,
1051 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001052 DataType::QuantisedAsymm8,
1053 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001054 };
1055 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1056 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001057}
1058
1059bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001060 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001061 Optional<std::string&> reasonIfUnsupported) const
1062{
Sadik Armaganc625f002018-12-17 11:32:16 +00001063 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001064 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1065 input.GetDataType(),
1066 &TrueFunc<>,
1067 &TrueFunc<>);
1068}
1069
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001070bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1071 const TensorInfo& output,
1072 Optional<std::string&> reasonIfUnsupported) const
1073{
1074 ignore_unused(output);
1075 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1076 input.GetDataType(),
1077 &TrueFunc<>,
1078 &FalseFuncU8<>);
1079}
1080
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001081bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1082 const TensorInfo& output,
1083 const SoftmaxDescriptor& descriptor,
1084 Optional<std::string&> reasonIfUnsupported) const
1085{
1086 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001087 bool supported = true;
1088 std::array<DataType,3> supportedTypes =
1089 {
1090 DataType::Float32,
1091 DataType::QuantisedAsymm8,
1092 DataType::QuantisedSymm16
1093 };
1094
1095 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1096 "Reference concatenation: output type not supported");
1097
1098 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1099 "Reference concatenation: input type not supported");
1100
1101 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1102 "Reference concatenation: input type not supported");
1103
1104 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001105}
1106
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001107bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1108 const TensorInfo& output,
1109 const SpaceToBatchNdDescriptor& descriptor,
1110 Optional<std::string&> reasonIfUnsupported) const
1111{
1112 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001113 bool supported = true;
1114 std::array<DataType,3> supportedTypes =
1115 {
1116 DataType::Float32,
1117 DataType::QuantisedAsymm8,
1118 DataType::QuantisedSymm16
1119 };
1120
1121 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1122 "Reference SpaceToBatchNd: input type not supported");
1123
1124 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1125 "Reference SpaceToBatchNd: output type not supported");
1126
1127 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1128 "Reference SpaceToBatchNd: input and output types are mismatched");
1129
1130 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001131}
1132
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001133bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1134 const ViewsDescriptor& descriptor,
1135 Optional<std::string&> reasonIfUnsupported) const
1136{
1137 ignore_unused(descriptor);
1138 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1139 input.GetDataType(),
1140 &TrueFunc<>,
1141 &TrueFunc<>);
1142}
1143
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001144bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1145 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1146 const ViewsDescriptor& descriptor,
1147 Optional<std::string&> reasonIfUnsupported) const
1148{
1149 ignore_unused(descriptor);
1150 ignore_unused(outputs);
1151 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1152 input.GetDataType(),
1153 &TrueFunc<>,
1154 &TrueFunc<>);
1155}
1156
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001157bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1158 const TensorInfo& output,
1159 const StridedSliceDescriptor& descriptor,
1160 Optional<std::string&> reasonIfUnsupported) const
1161{
1162 ignore_unused(output);
1163 ignore_unused(descriptor);
1164 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1165 input.GetDataType(),
1166 &TrueFunc<>,
1167 &TrueFunc<>);
1168}
1169
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001170bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1171 const TensorInfo& input1,
1172 const TensorInfo& output,
1173 Optional<std::string&> reasonIfUnsupported) const
1174{
Sadik Armagan2999a022019-04-09 14:20:12 +01001175 bool supported = true;
1176
1177 std::array<DataType,3> supportedTypes = {
1178 DataType::Float32,
1179 DataType::QuantisedAsymm8,
1180 DataType::QuantisedSymm16
1181 };
1182
1183 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1184 "Reference subtraction: input 0 is not a supported type.");
1185
1186 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1187 "Reference subtraction: input 1 is not a supported type.");
1188
1189 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1190 "Reference subtraction: output is not a supported type.");
1191
1192 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1193 "Reference subtraction: input 0 and Input 1 types are mismatched");
1194
1195 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1196 "Reference subtraction: input and output types are mismatched");
1197
1198 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1199 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1200
1201 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001202}
1203
arovir011c7c81b2018-10-08 11:34:28 +01001204} // namespace armnn