blob: 1d0b230c777e8b2ea86db8ce137c8a0b9d547c73 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Dispatches a layer-support query on 'dataType' to the matching per-type
// functor via IsSupportedForDataTypeGeneric. Only the Float32 and
// QuantisedAsymm8 slots are caller-supplied; the remaining data-type slots
// (the functor order is fixed by IsSupportedForDataTypeGeneric — presumably
// Float16, Signed32 and Boolean; confirm against LayerSupportCommon.hpp)
// are hard-wired to FalseFunc, i.e. always unsupported.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
Derek Lamberti50db4e82019-03-13 14:16:15 +000050
51namespace
52{
53template<typename F>
54bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
55{
56 bool supported = rule();
57 if (!supported && reason)
58 {
59 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
60 }
61 return supported;
62}
63
64struct Rule
65{
66 bool operator()() const
67 {
68 return m_Res;
69 }
70
71 bool m_Res = true;
72};
73
Derek Lamberti2a434a82019-03-20 13:07:57 +000074template<typename T>
75bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000076{
77 return true;
78}
79
80template<typename T, typename... Rest>
81bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
82{
83 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
84
Derek Lamberti2a434a82019-03-20 13:07:57 +000085 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +000086}
87
88struct TypesAreEqual : public Rule
89{
90 template<typename ... Ts>
91 TypesAreEqual(const Ts&... ts)
92 {
93 m_Res = AllTypesAreEqualImpl(ts...);
94 }
95};
96
97struct QuantizationParametersAreEqual : public Rule
98{
99 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
100 {
101 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
102 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
103 }
104};
105
106struct TypeAnyOf : public Rule
107{
108 template<typename Container>
109 TypeAnyOf(const TensorInfo& info, const Container& c)
110 {
111 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100112 {
113 return dt == info.GetDataType();
114 });
115 }
116};
117
118struct BiasAndWeightsTypesMatch : public Rule
119{
120 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121 {
122 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123 }
124};
125
126struct BiasAndWeightsTypesCompatible : public Rule
127{
128 template<typename Container>
129 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130 {
131 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132 {
133 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
134 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000135 }
136};
137
138struct ShapesAreSameRank : public Rule
139{
140 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141 {
142 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143 }
144};
145
Derek Lamberti5f400d62019-03-25 15:41:58 +0000146struct ShapesAreSameTotalSize : public Rule
147{
148 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149 {
150 m_Res = info0.GetNumElements() == info1.GetNumElements();
151 }
152};
153
Derek Lamberti50db4e82019-03-13 14:16:15 +0000154struct ShapesAreBroadcastCompatible : public Rule
155{
156 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157 {
158 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160 return sizeIn;
161 }
162
163 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164 {
165 const TensorShape& shape0 = in0.GetShape();
166 const TensorShape& shape1 = in1.GetShape();
167 const TensorShape& outShape = out.GetShape();
168
169 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170 {
171 unsigned int sizeOut = outShape[i];
172 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174
175 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177 }
178 }
179};
180} // namespace
181
182
arovir011c7c81b2018-10-08 11:34:28 +0100183bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
184 const TensorInfo& output,
185 const ActivationDescriptor& descriptor,
186 Optional<std::string&> reasonIfUnsupported) const
187{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000188 bool supported = true;
189
190 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100191 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000192 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100193 DataType::QuantisedAsymm8,
194 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000195 };
196
197 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
198 "Reference activation: input type not supported.");
199
200 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
201 "Reference activation: output type not supported.");
202
203 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
204 "Reference activation: input and output types mismatched.");
205
206 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
207 "Reference activation: input and output shapes are of different rank.");
208
209
210 struct ActivationFunctionSupported : public Rule
211 {
212 ActivationFunctionSupported(const ActivationDescriptor& desc)
213 {
214 switch(desc.m_Function)
215 {
216 case ActivationFunction::Abs:
217 case ActivationFunction::BoundedReLu:
218 case ActivationFunction::LeakyReLu:
219 case ActivationFunction::Linear:
220 case ActivationFunction::ReLu:
221 case ActivationFunction::Sigmoid:
222 case ActivationFunction::SoftReLu:
223 case ActivationFunction::Sqrt:
224 case ActivationFunction::Square:
225 case ActivationFunction::TanH:
226 {
227 m_Res = true;
228 break;
229 }
230 default:
231 {
232 m_Res = false;
233 break;
234 }
235 }
236 }
237 };
238
239 // Function is supported
240 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
241 "Reference activation: function not supported.");
242
243 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100244}
245
246bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
247 const TensorInfo& input1,
248 const TensorInfo& output,
249 Optional<std::string&> reasonIfUnsupported) const
250{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000251 bool supported = true;
252
Sadik Armagan2999a022019-04-09 14:20:12 +0100253 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000254 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100255 DataType::QuantisedAsymm8,
256 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000257 };
258
259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
260 "Reference addition: input 0 is not a supported type.");
261
262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
263 "Reference addition: input 1 is not a supported type.");
264
265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266 "Reference addition: output is not a supported type.");
267
268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
269 "Reference addition: input 0 and Input 1 types are mismatched");
270
271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
272 "Reference addition: input and output types are mismatched");
273
274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
275 "Reference addition: shapes are not suitable for implicit broadcast.");
276
277 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
280bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const TensorInfo& mean,
283 const TensorInfo& var,
284 const TensorInfo& beta,
285 const TensorInfo& gamma,
286 const BatchNormalizationDescriptor& descriptor,
287 Optional<std::string&> reasonIfUnsupported) const
288{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100289 ignore_unused(output);
290 ignore_unused(mean);
291 ignore_unused(var);
292 ignore_unused(beta);
293 ignore_unused(gamma);
294 ignore_unused(descriptor);
295 return IsSupportedForDataTypeRef(reasonIfUnsupported,
296 input.GetDataType(),
297 &TrueFunc<>,
298 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100299}
300
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000301bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
302 const TensorInfo& output,
303 const BatchToSpaceNdDescriptor& descriptor,
304 Optional<std::string&> reasonIfUnsupported) const
305{
306 ignore_unused(descriptor);
307 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
308 input.GetDataType(),
309 &TrueFunc<>,
310 &TrueFunc<>) &&
311 IsSupportedForDataTypeRef(reasonIfUnsupported,
312 output.GetDataType(),
313 &TrueFunc<>,
314 &TrueFunc<>));
315}
316
Jim Flynn906f9462019-05-10 13:55:21 +0100317bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
318 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100319 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100320 Optional<std::string&> reasonIfUnsupported) const
321{
Jim Flynne242f2d2019-05-22 14:24:13 +0100322 ignore_unused(descriptor);
323
324 bool supported = true;
325 std::array<DataType,3> supportedTypes =
326 {
327 DataType::Float32,
328 DataType::QuantisedAsymm8,
329 DataType::QuantisedSymm16
330 };
331
332 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
333 "Reference concatenation: output type not supported");
334 for (const TensorInfo* input : inputs)
335 {
336 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
337 "Reference concatenation: input type not supported");
338
339 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
340 "Reference concatenation: input and output types mismatched.");
341 }
342
343 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100344}
345
arovir011c7c81b2018-10-08 11:34:28 +0100346bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
347 Optional<std::string&> reasonIfUnsupported) const
348{
Jim Flynne242f2d2019-05-22 14:24:13 +0100349 std::array<DataType,4> supportedTypes =
350 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100351 DataType::Float32,
352 DataType::Signed32,
353 DataType::QuantisedAsymm8,
354 DataType::QuantisedSymm16
355 };
356
357 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
358 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100359}
360
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // The input must be Float16 and the output Float32; all other data types
    // are rejected via the False* functors. The functor order is fixed by
    // IsSupportedForDataTypeGeneric. Note && short-circuits: the output check
    // only runs if the input check passes.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
380
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // The input must be Float32 and the output Float16; all other data types
    // are rejected via the False* functors. The functor order is fixed by
    // IsSupportedForDataTypeGeneric. Note && short-circuits: the output check
    // only runs if the input check passes.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
400
401bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
402 const TensorInfo& output,
403 const Convolution2dDescriptor& descriptor,
404 const TensorInfo& weights,
405 const Optional<TensorInfo>& biases,
406 Optional<std::string&> reasonIfUnsupported) const
407{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100408 bool supported = true;
409
410 // Define supported types.
411 std::array<DataType,3> supportedTypes = {
412 DataType::Float32,
413 DataType::QuantisedAsymm8,
414 DataType::QuantisedSymm16
415 };
416
417 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100418 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100419
420 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100421 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100422
423 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100424 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100425
426 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100427 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100428
429 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100430 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100431
432 if (biases.has_value())
433 {
434 std::array<DataType,3> biasesSupportedTypes = {
435 DataType::Float32,
436 DataType::Signed32
437 };
438 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100439 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100440 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100441 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100442
443 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100444}
445
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000446bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
447 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000448 Optional<std::string&> reasonIfUnsupported) const
449{
450 ignore_unused(output);
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000451 return IsSupportedForDataTypeRef(reasonIfUnsupported,
452 input.GetDataType(),
453 &TrueFunc<>,
454 &TrueFunc<>);
455}
456
arovir011c7c81b2018-10-08 11:34:28 +0100457bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
458 const TensorInfo& output,
459 const DepthwiseConvolution2dDescriptor& descriptor,
460 const TensorInfo& weights,
461 const Optional<TensorInfo>& biases,
462 Optional<std::string&> reasonIfUnsupported) const
463{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100464 ignore_unused(output);
465 ignore_unused(descriptor);
466 ignore_unused(weights);
467 ignore_unused(biases);
468 return IsSupportedForDataTypeRef(reasonIfUnsupported,
469 input.GetDataType(),
470 &TrueFunc<>,
471 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100472}
473
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000474bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
475 const TensorInfo& output,
476 Optional<std::string&> reasonIfUnsupported) const
477{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100478 bool supported = true;
479
480 std::array<DataType,2> supportedInputTypes = {
481 DataType::QuantisedAsymm8,
482 DataType::QuantisedSymm16
483 };
484
485 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
486 "Reference dequantize: input type not supported.");
487
488 std::array<DataType,2> supportedOutputTypes = {
489 DataType::Float32,
490 };
491
492 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
493 "Reference dequantize: output type not supported.");
494
495 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
496 "Reference dequantize: input and output shapes have different num total elements.");
497
498 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000499}
500
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000501bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
502 const armnn::TensorInfo& input1,
503 const armnn::DetectionPostProcessDescriptor& descriptor,
504 armnn::Optional<std::string&> reasonIfUnsupported) const
505{
506 ignore_unused(input1);
507 return IsSupportedForDataTypeRef(reasonIfUnsupported,
508 input0.GetDataType(),
509 &TrueFunc<>,
510 &TrueFunc<>);
511}
512
Pablo Tellof0bd6832019-04-26 17:58:13 +0100513bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
514 const TensorInfo& output,
515 const DepthwiseConvolution2dDescriptor& descriptor,
516 const TensorInfo& weights,
517 const Optional<TensorInfo>& biases,
518 Optional<std::string&> reasonIfUnsupported) const
519{
520 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
521 {
522 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
523 }
524 else
525 {
526 if (reasonIfUnsupported)
527 {
528 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
529 }
530 return false;
531 }
532}
533
534
535 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100536 const TensorInfo& input1,
537 const TensorInfo& output,
538 Optional<std::string&> reasonIfUnsupported) const
539{
Sadik Armagan2999a022019-04-09 14:20:12 +0100540 bool supported = true;
541
542 std::array<DataType,3> supportedTypes = {
543 DataType::Float32,
544 DataType::QuantisedAsymm8,
545 DataType::QuantisedSymm16
546 };
547
548 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
549 "Reference division: input 0 is not a supported type.");
550
551 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
552 "Reference division: input 1 is not a supported type.");
553
554 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
555 "Reference division: output is not a supported type.");
556
557 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
558 "Reference division: input 0 and Input 1 types are mismatched");
559
560 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
561 "Reference division: input and output types are mismatched");
562
563 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
564 "Reference division: shapes are not suitable for implicit broadcast.");
565
566 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100567}
568
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000569bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
570 const TensorInfo& input1,
571 const TensorInfo& output,
572 Optional<std::string&> reasonIfUnsupported) const
573{
574 ignore_unused(input0);
575 ignore_unused(input1);
576 ignore_unused(output);
577 ignore_unused(reasonIfUnsupported);
578 return IsSupportedForDataTypeRef(reasonIfUnsupported,
579 input0.GetDataType(),
580 &TrueFunc<>,
581 &TrueFunc<>);
582}
583
arovir011c7c81b2018-10-08 11:34:28 +0100584bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
585 const FakeQuantizationDescriptor& descriptor,
586 Optional<std::string&> reasonIfUnsupported) const
587{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100588 ignore_unused(descriptor);
589 return IsSupportedForDataTypeRef(reasonIfUnsupported,
590 input.GetDataType(),
591 &TrueFunc<>,
592 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100593}
594
595bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
596 const TensorInfo& output,
597 Optional<std::string&> reasonIfUnsupported) const
598{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100599 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100600 bool supported = true;
601
James Conroyb40d7102019-06-04 12:32:09 +0100602 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100603 {
James Conroyb40d7102019-06-04 12:32:09 +0100604 DataType::Float32,
605 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100606 };
607
608 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
609 "Reference Floor: input type not supported.");
610
611 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
612 "Reference Floor: output type not supported.");
613
614 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100615}
616
617bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
618 const TensorInfo& output,
619 const TensorInfo& weights,
620 const TensorInfo& biases,
621 const FullyConnectedDescriptor& descriptor,
622 Optional<std::string&> reasonIfUnsupported) const
623{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100624 bool supported = true;
625
626 // Define supported types.
627 std::array<DataType,3> supportedTypes =
628 {
629 DataType::Float32,
630 DataType::QuantisedAsymm8,
631 DataType::QuantisedSymm16
632 };
633
634 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
635 "Reference Fully Connected: input type not supported.");
636
637 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
638 "Reference Fully Connected: output type not supported.");
639
640 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
641 "Reference Fully Connected: input and output types mismatched.");
642
643 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
644 "Reference Fully Connected: weights type not supported.");
645
646 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
647 "Reference Fully Connected: input and weight types mismatched.");
648
649 if (descriptor.m_BiasEnabled)
650 {
651 // Defined supported types for bias
652 std::array<DataType, 2>
653 supportedBiasTypes =
654 {
655 DataType::Float32,
656 DataType::Signed32
657 };
658
659 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
660 "Reference Fully Connected: bias type not supported.");
661
662 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
663 "Reference Fully Connected: bias and weight types mismatch.");
664
665 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
666 "Reference Fully Connected: bias type inferred from weights is incompatible.");
667
668 }
669
670 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100671}
672
narpra014951d842019-01-18 16:53:53 +0000673bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
674 const armnn::TensorInfo& input1,
675 const armnn::TensorInfo& output,
676 armnn::Optional<std::string&> reasonIfUnsupported) const
677{
678 ignore_unused(input1);
679 ignore_unused(output);
680 return IsSupportedForDataTypeRef(reasonIfUnsupported,
681 input0.GetDataType(),
682 &TrueFunc<>,
683 &TrueFunc<>);
684}
685
FrancisMurtagh878f0232018-12-19 10:56:15 +0000686bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
687 const TensorInfo& input1,
688 const TensorInfo& output,
689 Optional<std::string&> reasonIfUnsupported) const
690{
691 ignore_unused(input0);
692 ignore_unused(input1);
693 ignore_unused(output);
694 ignore_unused(reasonIfUnsupported);
695 return IsSupportedForDataTypeRef(reasonIfUnsupported,
696 input0.GetDataType(),
697 &TrueFunc<>,
698 &TrueFunc<>);
699}
700
arovir011c7c81b2018-10-08 11:34:28 +0100701bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
702 Optional<std::string&> reasonIfUnsupported) const
703{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100704 return IsSupportedForDataTypeRef(reasonIfUnsupported,
705 input.GetDataType(),
706 &TrueFunc<>,
707 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100708}
709
710bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
711 const TensorInfo& output,
712 const L2NormalizationDescriptor& descriptor,
713 Optional<std::string&> reasonIfUnsupported) const
714{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100715 ignore_unused(output);
716 ignore_unused(descriptor);
717 return IsSupportedForDataTypeRef(reasonIfUnsupported,
718 input.GetDataType(),
719 &TrueFunc<>,
720 &FalseFuncU8<>);
arovir011c7c81b2018-10-08 11:34:28 +0100721}
722
723bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
724 const TensorInfo& outputStateIn,
725 const TensorInfo& cellStateIn,
726 const TensorInfo& scratchBuffer,
727 const TensorInfo& outputStateOut,
728 const TensorInfo& cellStateOut,
729 const TensorInfo& output,
730 const LstmDescriptor& descriptor,
731 const TensorInfo& inputToForgetWeights,
732 const TensorInfo& inputToCellWeights,
733 const TensorInfo& inputToOutputWeights,
734 const TensorInfo& recurrentToForgetWeights,
735 const TensorInfo& recurrentToCellWeights,
736 const TensorInfo& recurrentToOutputWeights,
737 const TensorInfo& forgetGateBias,
738 const TensorInfo& cellBias,
739 const TensorInfo& outputGateBias,
740 const TensorInfo* inputToInputWeights,
741 const TensorInfo* recurrentToInputWeights,
742 const TensorInfo* cellToInputWeights,
743 const TensorInfo* inputGateBias,
744 const TensorInfo* projectionWeights,
745 const TensorInfo* projectionBias,
746 const TensorInfo* cellToForgetWeights,
747 const TensorInfo* cellToOutputWeights,
748 Optional<std::string&> reasonIfUnsupported) const
749{
telsoa01c577f2c2018-08-31 09:22:23 +0100750 ignore_unused(descriptor);
751 ignore_unused(inputToForgetWeights);
752 ignore_unused(inputToCellWeights);
753 ignore_unused(inputToOutputWeights);
754 ignore_unused(recurrentToForgetWeights);
755 ignore_unused(recurrentToCellWeights);
756 ignore_unused(recurrentToOutputWeights);
757 ignore_unused(forgetGateBias);
758 ignore_unused(cellBias);
759 ignore_unused(outputGateBias);
760 ignore_unused(inputToInputWeights);
761 ignore_unused(recurrentToInputWeights);
762 ignore_unused(cellToInputWeights);
763 ignore_unused(inputGateBias);
764 ignore_unused(projectionWeights);
765 ignore_unused(projectionBias);
766 ignore_unused(cellToForgetWeights);
767 ignore_unused(cellToOutputWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100768
769 bool supported = true;
770
771 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100772 DataType::Float32,
773 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100774 };
775
776 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
777 "Reference Lstm: input is not a supported type.");
778
779 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
780 "Reference Lstm: input and outputStateIn types are mismatched");
781
782 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
783 "Reference Lstm: input and cellStateIn types are mismatched");
784
785 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
786 "Reference Lstm: input and scratchBuffer types are mismatched");
787
788 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
789 "Reference Lstm: input and outputStateOut types are mismatched");
790
791 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
792 "Reference Lstm: input and cellStateOut types are mismatched");
793
794 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795 "Reference Lstm: input and output types are mismatched");
796
797 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100798}
799
saoste012df12b32018-11-28 16:57:20 +0000800bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
801 const TensorInfo& input1,
802 const TensorInfo& output,
803 Optional<std::string&> reasonIfUnsupported) const
804{
Sadik Armagan2999a022019-04-09 14:20:12 +0100805 bool supported = true;
806
807 std::array<DataType,3> supportedTypes = {
808 DataType::Float32,
809 DataType::QuantisedAsymm8,
810 DataType::QuantisedSymm16
811 };
812
813 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
814 "Reference maximum: input 0 is not a supported type.");
815
816 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
817 "Reference maximum: input 1 is not a supported type.");
818
819 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
820 "Reference maximum: output is not a supported type.");
821
822 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
823 "Reference maximum: input 0 and Input 1 types are mismatched");
824
825 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
826 "Reference maximum: input and output types are mismatched");
827
828 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
829 "Reference maximum: shapes are not suitable for implicit broadcast.");
830
831 return supported;
saoste012df12b32018-11-28 16:57:20 +0000832}
833
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100834bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
835 const TensorInfo& output,
836 const MeanDescriptor& descriptor,
837 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100838{
narpra011e4c31d2018-09-28 11:07:51 +0100839 ignore_unused(output);
840 ignore_unused(descriptor);
841 return IsSupportedForDataTypeRef(reasonIfUnsupported,
842 input.GetDataType(),
843 &TrueFunc<>,
844 &TrueFunc<>);
narpra0132b90462018-09-13 11:07:48 +0100845}
846
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100847bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +0000848 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100849 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100850 Optional<std::string&> reasonIfUnsupported) const
851{
Jim Flynne242f2d2019-05-22 14:24:13 +0100852 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100853}
854
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    // MemCopy only moves bytes, so only the input's data type is checked.
    ignore_unused(output);
    // The positional callbacks below map to one data type each. The Int32 slot
    // is evident from FalseFuncI32; the remaining order (Float16, Float32,
    // Uint8, Int32, Boolean) is inferred from the forwarding in
    // IsSupportedForDataTypeRef -- confirm against LayerSupportCommon.hpp.
    // Net effect: everything except Int32 is supported.
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
868
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000869bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
870 const TensorInfo& input1,
871 const TensorInfo& output,
872 Optional<std::string&> reasonIfUnsupported) const
873{
Sadik Armagan2999a022019-04-09 14:20:12 +0100874 bool supported = true;
875
876 std::array<DataType,3> supportedTypes = {
877 DataType::Float32,
878 DataType::QuantisedAsymm8,
879 DataType::QuantisedSymm16
880 };
881
882 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
883 "Reference minimum: input 0 is not a supported type.");
884
885 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
886 "Reference minimum: input 1 is not a supported type.");
887
888 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
889 "Reference minimum: output is not a supported type.");
890
891 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
892 "Reference minimum: input 0 and Input 1 types are mismatched");
893
894 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
895 "Reference minimum: input and output types are mismatched");
896
897 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
898 "Reference minimum: shapes are not suitable for implicit broadcast.");
899
900 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +0000901}
902
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100903bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
904 const TensorInfo& input1,
905 const TensorInfo& output,
906 Optional<std::string&> reasonIfUnsupported) const
907{
Sadik Armagan2999a022019-04-09 14:20:12 +0100908 bool supported = true;
909
910 std::array<DataType,3> supportedTypes = {
911 DataType::Float32,
912 DataType::QuantisedAsymm8,
913 DataType::QuantisedSymm16
914 };
915
916 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
917 "Reference multiplication: input 0 is not a supported type.");
918
919 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
920 "Reference multiplication: input 1 is not a supported type.");
921
922 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
923 "Reference multiplication: output is not a supported type.");
924
925 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
926 "Reference multiplication: input 0 and Input 1 types are mismatched");
927
928 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
929 "Reference multiplication: input and output types are mismatched");
930
931 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
932 "Reference multiplication: shapes are not suitable for implicit broadcast.");
933
934 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100935}
936
937bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
938 const TensorInfo& output,
939 const NormalizationDescriptor& descriptor,
940 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +0100941{
942 ignore_unused(output);
943 ignore_unused(descriptor);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100944 return IsSupportedForDataTypeRef(reasonIfUnsupported,
945 input.GetDataType(),
946 &TrueFunc<>,
947 &FalseFuncU8<>);
948}
949
950bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
951 Optional<std::string&> reasonIfUnsupported) const
952{
kevmay012b4d88e2019-01-24 14:05:09 +0000953 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
954 output.GetDataType(),
955 &TrueFunc<>,
956 &TrueFunc<>,
957 &TrueFunc<>,
958 &FalseFuncI32<>,
959 &TrueFunc<>);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100960}
961
962bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
963 const TensorInfo& output,
964 const PadDescriptor& descriptor,
965 Optional<std::string&> reasonIfUnsupported) const
966{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100967 ignore_unused(output);
968 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +0000969 return IsSupportedForDataTypeRef(reasonIfUnsupported,
970 input.GetDataType(),
971 &TrueFunc<>,
972 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +0100973}
974
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100975bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
976 const TensorInfo& output,
977 const PermuteDescriptor& descriptor,
978 Optional<std::string&> reasonIfUnsupported) const
979{
980 ignore_unused(output);
981 ignore_unused(descriptor);
982 return IsSupportedForDataTypeRef(reasonIfUnsupported,
983 input.GetDataType(),
984 &TrueFunc<>,
985 &TrueFunc<>);
986}
987
988bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
989 const TensorInfo& output,
990 const Pooling2dDescriptor& descriptor,
991 Optional<std::string&> reasonIfUnsupported) const
992{
993 ignore_unused(output);
994 ignore_unused(descriptor);
995 return IsSupportedForDataTypeRef(reasonIfUnsupported,
996 input.GetDataType(),
997 &TrueFunc<>,
998 &TrueFunc<>);
999}
1000
Derek Lamberti5f400d62019-03-25 15:41:58 +00001001bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1002 const TensorInfo& output,
1003 Optional<std::string&> reasonIfUnsupported) const
1004{
1005 bool supported = true;
1006
1007 // Define supported output types.
1008 std::array<DataType,2> supportedInputTypes = {
1009 DataType::Float32,
1010 };
1011
1012 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1013 "Reference quantize: input type not supported.");
1014
1015 // Define supported output types.
1016 std::array<DataType,2> supportedOutputTypes = {
1017 DataType::QuantisedAsymm8,
1018 DataType::QuantisedSymm16
1019 };
1020 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1021 "Reference quantize: output type not supported.");
1022
1023 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1024 "Reference quantize: input and output shapes have different num total elements.");
1025
1026 return supported;
1027}
1028
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001029bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001030 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001031 Optional<std::string&> reasonIfUnsupported) const
1032{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001033 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001034 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001035 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001036 {
1037 DataType::Float32,
1038 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001039 DataType::QuantisedAsymm8,
1040 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001041 };
1042 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1043 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001044}
1045
1046bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001047 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001048 Optional<std::string&> reasonIfUnsupported) const
1049{
Sadik Armaganc625f002018-12-17 11:32:16 +00001050 ignore_unused(output);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001051 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1052 input.GetDataType(),
1053 &TrueFunc<>,
1054 &TrueFunc<>);
1055}
1056
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001057bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1058 const TensorInfo& output,
1059 Optional<std::string&> reasonIfUnsupported) const
1060{
1061 ignore_unused(output);
1062 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1063 input.GetDataType(),
1064 &TrueFunc<>,
1065 &FalseFuncU8<>);
1066}
1067
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001068bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1069 const TensorInfo& output,
1070 const SoftmaxDescriptor& descriptor,
1071 Optional<std::string&> reasonIfUnsupported) const
1072{
1073 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001074 bool supported = true;
1075 std::array<DataType,3> supportedTypes =
1076 {
1077 DataType::Float32,
1078 DataType::QuantisedAsymm8,
1079 DataType::QuantisedSymm16
1080 };
1081
1082 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1083 "Reference concatenation: output type not supported");
1084
1085 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1086 "Reference concatenation: input type not supported");
1087
1088 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1089 "Reference concatenation: input type not supported");
1090
1091 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001092}
1093
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001094bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1095 const TensorInfo& output,
1096 const SpaceToBatchNdDescriptor& descriptor,
1097 Optional<std::string&> reasonIfUnsupported) const
1098{
1099 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001100 bool supported = true;
1101 std::array<DataType,3> supportedTypes =
1102 {
1103 DataType::Float32,
1104 DataType::QuantisedAsymm8,
1105 DataType::QuantisedSymm16
1106 };
1107
1108 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1109 "Reference SpaceToBatchNd: input type not supported");
1110
1111 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1112 "Reference SpaceToBatchNd: output type not supported");
1113
1114 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1115 "Reference SpaceToBatchNd: input and output types are mismatched");
1116
1117 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001118}
1119
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001120bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1121 const ViewsDescriptor& descriptor,
1122 Optional<std::string&> reasonIfUnsupported) const
1123{
1124 ignore_unused(descriptor);
1125 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1126 input.GetDataType(),
1127 &TrueFunc<>,
1128 &TrueFunc<>);
1129}
1130
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001131bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1132 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1133 const ViewsDescriptor& descriptor,
1134 Optional<std::string&> reasonIfUnsupported) const
1135{
1136 ignore_unused(descriptor);
1137 ignore_unused(outputs);
1138 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1139 input.GetDataType(),
1140 &TrueFunc<>,
1141 &TrueFunc<>);
1142}
1143
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001144bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1145 const TensorInfo& output,
1146 const StridedSliceDescriptor& descriptor,
1147 Optional<std::string&> reasonIfUnsupported) const
1148{
1149 ignore_unused(output);
1150 ignore_unused(descriptor);
1151 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1152 input.GetDataType(),
1153 &TrueFunc<>,
1154 &TrueFunc<>);
1155}
1156
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001157bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1158 const TensorInfo& input1,
1159 const TensorInfo& output,
1160 Optional<std::string&> reasonIfUnsupported) const
1161{
Sadik Armagan2999a022019-04-09 14:20:12 +01001162 bool supported = true;
1163
1164 std::array<DataType,3> supportedTypes = {
1165 DataType::Float32,
1166 DataType::QuantisedAsymm8,
1167 DataType::QuantisedSymm16
1168 };
1169
1170 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1171 "Reference subtraction: input 0 is not a supported type.");
1172
1173 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1174 "Reference subtraction: input 1 is not a supported type.");
1175
1176 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1177 "Reference subtraction: output is not a supported type.");
1178
1179 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1180 "Reference subtraction: input 0 and Input 1 types are mismatched");
1181
1182 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1183 "Reference subtraction: input and output types are mismatched");
1184
1185 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1186 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1187
1188 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001189}
1190
arovir011c7c81b2018-10-08 11:34:28 +01001191} // namespace armnn