blob: d42404d25b71353b14f3291dc8bbe336d2ca2dc7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01009#include <DataLayoutIndexed.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <InternalTypes.hpp>
11#include <LayerSupportCommon.hpp>
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +010012
telsoa014fcda012018-03-09 14:13:49 +000013#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000014#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
David Beck111b5d92018-11-12 14:59:37 +000016#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010017#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010018
telsoa014fcda012018-03-09 14:13:49 +000019#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
Derek Lamberti50db4e82019-03-13 14:16:15 +000021#include <vector>
22#include <algorithm>
23#include <array>
24
telsoa014fcda012018-03-09 14:13:49 +000025using namespace boost;
26
27namespace armnn
28{
29
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010030namespace
31{
32
33template<typename Float32Func, typename Uint8Func, typename ... Params>
34bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
35 DataType dataType,
36 Float32Func floatFuncPtr,
37 Uint8Func uint8FuncPtr,
38 Params&&... params)
39{
40 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
41 dataType,
42 &FalseFunc<Params...>,
43 floatFuncPtr,
44 uint8FuncPtr,
narpra01db2b1602019-01-23 15:23:11 +000045 &FalseFunc<Params...>,
kevmay012b4d88e2019-01-24 14:05:09 +000046 &FalseFunc<Params...>,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +010047 std::forward<Params>(params)...);
48}
49
50} // anonymous namespace
51
namespace
{

/// Builds the standard "wrong number of dimensions" diagnostic used by the
/// reference-backend support checks.
/// @param expected   number of dimensions the layer requires
/// @param actual     number of dimensions the tensor actually has
/// @param layerStr   human-readable layer name, e.g. "batchToSpaceNd"
/// @param tensorName role of the offending tensor, e.g. "input" or "output"
/// @return the formatted error message
///
/// The string parameters are read-only, so they are now taken by const
/// reference (previously non-const references), which also lets callers pass
/// temporaries. Existing callers are unaffected.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000067
68namespace
69{
70template<typename F>
71bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
72{
73 bool supported = rule();
74 if (!supported && reason)
75 {
76 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
77 }
78 return supported;
79}
80
// Base class for all support-rule predicates: derived constructors compute a
// verdict into m_Res, and the call operator reports it.
struct Rule
{
    // Verdict of the check; defaults to passing so an empty rule is
    // trivially satisfied.
    bool m_Res = true;

    // Returns whether the rule held.
    bool operator()() const
    {
        return m_Res;
    }
};
90
Derek Lamberti2a434a82019-03-20 13:07:57 +000091template<typename T>
92bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000093{
94 return true;
95}
96
97template<typename T, typename... Rest>
98bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
99{
100 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
101
Derek Lamberti2a434a82019-03-20 13:07:57 +0000102 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +0000103}
104
105struct TypesAreEqual : public Rule
106{
107 template<typename ... Ts>
108 TypesAreEqual(const Ts&... ts)
109 {
110 m_Res = AllTypesAreEqualImpl(ts...);
111 }
112};
113
114struct QuantizationParametersAreEqual : public Rule
115{
116 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
117 {
118 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
119 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
120 }
121};
122
123struct TypeAnyOf : public Rule
124{
125 template<typename Container>
126 TypeAnyOf(const TensorInfo& info, const Container& c)
127 {
128 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100129 {
130 return dt == info.GetDataType();
131 });
132 }
133};
134
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100135struct TypeIs : public Rule
136{
137 TypeIs(const TensorInfo& info, DataType dt)
138 {
139 m_Res = dt == info.GetDataType();
140 }
141};
142
Francis Murtagh46c09d02019-05-28 08:15:28 +0100143struct BiasAndWeightsTypesMatch : public Rule
144{
145 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
146 {
147 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
148 }
149};
150
151struct BiasAndWeightsTypesCompatible : public Rule
152{
153 template<typename Container>
154 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
155 {
156 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
157 {
158 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
159 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000160 }
161};
162
163struct ShapesAreSameRank : public Rule
164{
165 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
166 {
167 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
168 }
169};
170
Derek Lamberti5f400d62019-03-25 15:41:58 +0000171struct ShapesAreSameTotalSize : public Rule
172{
173 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
174 {
175 m_Res = info0.GetNumElements() == info1.GetNumElements();
176 }
177};
178
Derek Lamberti50db4e82019-03-13 14:16:15 +0000179struct ShapesAreBroadcastCompatible : public Rule
180{
181 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
182 {
183 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
184 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
185 return sizeIn;
186 }
187
188 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
189 {
190 const TensorShape& shape0 = in0.GetShape();
191 const TensorShape& shape1 = in1.GetShape();
192 const TensorShape& outShape = out.GetShape();
193
194 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
195 {
196 unsigned int sizeOut = outShape[i];
197 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
198 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
199
200 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
201 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
202 }
203 }
204};
James Conroy4d1ff582019-06-10 17:06:39 +0100205
206struct TensorNumDimensionsAreCorrect : public Rule
207{
208 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
209 {
210 m_Res = info.GetNumDimensions() == expectedNumDimensions;
211 }
212};
213
Derek Lamberti50db4e82019-03-13 14:16:15 +0000214} // namespace
215
216
arovir011c7c81b2018-10-08 11:34:28 +0100217bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
218 const TensorInfo& output,
219 const ActivationDescriptor& descriptor,
220 Optional<std::string&> reasonIfUnsupported) const
221{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000222 bool supported = true;
223
224 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100225 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000226 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100227 DataType::QuantisedAsymm8,
228 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000229 };
230
231 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
232 "Reference activation: input type not supported.");
233
234 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
235 "Reference activation: output type not supported.");
236
237 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
238 "Reference activation: input and output types mismatched.");
239
240 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
241 "Reference activation: input and output shapes are of different rank.");
242
243
244 struct ActivationFunctionSupported : public Rule
245 {
246 ActivationFunctionSupported(const ActivationDescriptor& desc)
247 {
248 switch(desc.m_Function)
249 {
250 case ActivationFunction::Abs:
251 case ActivationFunction::BoundedReLu:
252 case ActivationFunction::LeakyReLu:
253 case ActivationFunction::Linear:
254 case ActivationFunction::ReLu:
255 case ActivationFunction::Sigmoid:
256 case ActivationFunction::SoftReLu:
257 case ActivationFunction::Sqrt:
258 case ActivationFunction::Square:
259 case ActivationFunction::TanH:
260 {
261 m_Res = true;
262 break;
263 }
264 default:
265 {
266 m_Res = false;
267 break;
268 }
269 }
270 }
271 };
272
273 // Function is supported
274 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
275 "Reference activation: function not supported.");
276
277 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100278}
279
280bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
281 const TensorInfo& input1,
282 const TensorInfo& output,
283 Optional<std::string&> reasonIfUnsupported) const
284{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000285 bool supported = true;
286
Sadik Armagan2999a022019-04-09 14:20:12 +0100287 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000288 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100289 DataType::QuantisedAsymm8,
290 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000291 };
292
293 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
294 "Reference addition: input 0 is not a supported type.");
295
296 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
297 "Reference addition: input 1 is not a supported type.");
298
299 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
300 "Reference addition: output is not a supported type.");
301
302 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
303 "Reference addition: input 0 and Input 1 types are mismatched");
304
305 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
306 "Reference addition: input and output types are mismatched");
307
308 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
309 "Reference addition: shapes are not suitable for implicit broadcast.");
310
311 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100312}
313
314bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
315 const TensorInfo& output,
316 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100317 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100318 const TensorInfo& beta,
319 const TensorInfo& gamma,
320 const BatchNormalizationDescriptor& descriptor,
321 Optional<std::string&> reasonIfUnsupported) const
322{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100323 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100324
Matteo Martincighf5507132019-06-04 10:59:47 +0100325 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100326 {
327 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100328 DataType::QuantisedAsymm8,
329 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100330 };
331
332 bool supported = true;
333
334 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
335 "Reference batch normalization: input is not a supported type.");
336
337 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
338 "Reference batch normalization: output is not a supported type.");
339
340 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
341 "Reference batch normalization: input and output types are mismatched");
342
343 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
344 "Reference batch normalization: mean is not a supported type.");
345
346 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
347 "Reference batch normalization: variance is not a supported type.");
348
349 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
350 "Reference batch normalization: beta is not a supported type.");
351
352 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
353 "Reference batch normalization: gamma is not a supported type.");
354
355 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100356}
357
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000358bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
359 const TensorInfo& output,
360 const BatchToSpaceNdDescriptor& descriptor,
361 Optional<std::string&> reasonIfUnsupported) const
362{
363 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100364
365 bool supported = true;
366
367 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
368 std::string inputTensorStr = "input";
369 std::string outputTensorStr = "output";
370
371 // Define supported types.
372 std::array<DataType,3> supportedTypes =
373 {
374 DataType::Float32,
375 DataType::QuantisedAsymm8,
376 DataType::QuantisedSymm16
377 };
378
379 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
380 "Reference BatchToSpaceNd: input type not supported.");
381
382 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
383 "Reference BatchToSpaceNd: output type not supported.");
384
385 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
386 "Reference BatchToSpaceNd: input and output types mismatched.");
387
388 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
389 reasonIfUnsupported,
390 CreateIncorrectDimensionsErrorMsg(4,
391 output.GetNumDimensions(),
392 batchToSpaceNdLayerStr,
393 outputTensorStr).data());
394
395 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
396 reasonIfUnsupported,
397 CreateIncorrectDimensionsErrorMsg(4,
398 input.GetNumDimensions(),
399 batchToSpaceNdLayerStr,
400 inputTensorStr).data());
401
402 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000403}
404
Jim Flynn906f9462019-05-10 13:55:21 +0100405bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
406 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100407 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100408 Optional<std::string&> reasonIfUnsupported) const
409{
Jim Flynne242f2d2019-05-22 14:24:13 +0100410 ignore_unused(descriptor);
411
412 bool supported = true;
413 std::array<DataType,3> supportedTypes =
414 {
415 DataType::Float32,
416 DataType::QuantisedAsymm8,
417 DataType::QuantisedSymm16
418 };
419
420 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
421 "Reference concatenation: output type not supported");
422 for (const TensorInfo* input : inputs)
423 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100424 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100425 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
426 "Reference concatenation: input type not supported");
427
428 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
429 "Reference concatenation: input and output types mismatched.");
430 }
431
432 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100433}
434
arovir011c7c81b2018-10-08 11:34:28 +0100435bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
436 Optional<std::string&> reasonIfUnsupported) const
437{
Jim Flynne242f2d2019-05-22 14:24:13 +0100438 std::array<DataType,4> supportedTypes =
439 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100440 DataType::Float32,
441 DataType::Signed32,
442 DataType::QuantisedAsymm8,
443 DataType::QuantisedSymm16
444 };
445
446 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
447 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100448}
449
450bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
451 const TensorInfo& output,
452 Optional<std::string&> reasonIfUnsupported) const
453{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100454 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
455 input.GetDataType(),
456 &TrueFunc<>,
457 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000458 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000459 &FalseFuncI32<>,
460 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100461 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
462 output.GetDataType(),
463 &FalseOutputFuncF16<>,
464 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000465 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000466 &FalseFuncI32<>,
467 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100468}
469
470bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
471 const TensorInfo& output,
472 Optional<std::string&> reasonIfUnsupported) const
473{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100474 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
475 input.GetDataType(),
476 &FalseInputFuncF16<>,
477 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000478 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000479 &FalseFuncI32<>,
480 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100481 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
482 output.GetDataType(),
483 &TrueFunc<>,
484 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000485 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000486 &FalseFuncI32<>,
487 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100488}
489
490bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
491 const TensorInfo& output,
492 const Convolution2dDescriptor& descriptor,
493 const TensorInfo& weights,
494 const Optional<TensorInfo>& biases,
495 Optional<std::string&> reasonIfUnsupported) const
496{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100497 bool supported = true;
498
499 // Define supported types.
500 std::array<DataType,3> supportedTypes = {
501 DataType::Float32,
502 DataType::QuantisedAsymm8,
503 DataType::QuantisedSymm16
504 };
505
506 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100507 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100508
509 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100510 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100511
512 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100513 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100514
515 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100516 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100517
518 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100519 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100520
521 if (biases.has_value())
522 {
523 std::array<DataType,3> biasesSupportedTypes = {
524 DataType::Float32,
525 DataType::Signed32
526 };
527 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100528 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100529 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100530 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100531
532 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100533}
534
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000535bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
536 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000537 Optional<std::string&> reasonIfUnsupported) const
538{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100539 bool supported = true;
540
541 std::array<DataType,3> supportedTypes =
542 {
543 DataType::Float32,
544 DataType::QuantisedAsymm8,
545 DataType::QuantisedSymm16
546 };
547
548 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
549 "Reference debug: input type not supported");
550
551 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
552 "Reference debug: output type not supported");
553
554 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
555 "Reference debug: input and output types are mismatched");
556
557 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000558}
559
arovir011c7c81b2018-10-08 11:34:28 +0100560bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
561 const TensorInfo& output,
562 const DepthwiseConvolution2dDescriptor& descriptor,
563 const TensorInfo& weights,
564 const Optional<TensorInfo>& biases,
565 Optional<std::string&> reasonIfUnsupported) const
566{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100567 bool supported = true;
568
569 // Define supported types.
570 std::array<DataType,3> supportedTypes =
571 {
572 DataType::Float32,
573 DataType::QuantisedAsymm8,
574 DataType::QuantisedSymm16
575 };
576
577 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
578 "Reference DepthwiseConvolution2d: input is not a supported type.");
579
580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
581 "Reference DepthwiseConvolution2d: output is not a supported type.");
582
583 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
584 "Reference DepthwiseConvolution2d: weights is not a supported type.");
585
586 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
587 "Reference DepthwiseConvolution2d: input and output types mismatched.");
588
589 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
590 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
591
592 if (biases.has_value())
593 {
594 std::array<DataType,2> biasesSupportedTypes =
595 {
596 DataType::Float32,
597 DataType::Signed32
598 };
599 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
600 "Reference DepthwiseConvolution2d: biases is not a supported type.");
601 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100602 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100603
604 return supported;
605
arovir011c7c81b2018-10-08 11:34:28 +0100606}
607
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000608bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
609 const TensorInfo& output,
610 Optional<std::string&> reasonIfUnsupported) const
611{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100612 bool supported = true;
613
614 std::array<DataType,2> supportedInputTypes = {
615 DataType::QuantisedAsymm8,
616 DataType::QuantisedSymm16
617 };
618
619 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
620 "Reference dequantize: input type not supported.");
621
622 std::array<DataType,2> supportedOutputTypes = {
623 DataType::Float32,
624 };
625
626 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
627 "Reference dequantize: output type not supported.");
628
629 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
630 "Reference dequantize: input and output shapes have different num total elements.");
631
632 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000633}
634
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000635bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
636 const armnn::TensorInfo& input1,
637 const armnn::DetectionPostProcessDescriptor& descriptor,
638 armnn::Optional<std::string&> reasonIfUnsupported) const
639{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100640 bool supported = true;
641
642 std::vector<DataType> supportedInputTypes =
643 {
644 DataType::Float32,
645 DataType::QuantisedAsymm8,
646 DataType::QuantisedSymm16
647 };
648
649 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
650 "Reference DetectionPostProcess: input 0 is not a supported type.");
651
652 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
653 "Reference DetectionPostProcess: input 1 is not a supported type.");
654
655 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000656}
657
Pablo Tellof0bd6832019-04-26 17:58:13 +0100658bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
659 const TensorInfo& output,
660 const DepthwiseConvolution2dDescriptor& descriptor,
661 const TensorInfo& weights,
662 const Optional<TensorInfo>& biases,
663 Optional<std::string&> reasonIfUnsupported) const
664{
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100665 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
Pablo Tellof0bd6832019-04-26 17:58:13 +0100666}
667
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100668bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100669 const TensorInfo& input1,
670 const TensorInfo& output,
671 Optional<std::string&> reasonIfUnsupported) const
672{
Sadik Armagan2999a022019-04-09 14:20:12 +0100673 bool supported = true;
674
675 std::array<DataType,3> supportedTypes = {
676 DataType::Float32,
677 DataType::QuantisedAsymm8,
678 DataType::QuantisedSymm16
679 };
680
681 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
682 "Reference division: input 0 is not a supported type.");
683
684 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
685 "Reference division: input 1 is not a supported type.");
686
687 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
688 "Reference division: output is not a supported type.");
689
690 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
691 "Reference division: input 0 and Input 1 types are mismatched");
692
693 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
694 "Reference division: input and output types are mismatched");
695
696 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
697 "Reference division: shapes are not suitable for implicit broadcast.");
698
699 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100700}
701
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000702bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
703 const TensorInfo& input1,
704 const TensorInfo& output,
705 Optional<std::string&> reasonIfUnsupported) const
706{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100707 bool supported = true;
708
709 std::array<DataType,3> supportedTypes =
710 {
711 DataType::Float32,
712 DataType::QuantisedAsymm8,
713 DataType::QuantisedSymm16
714 };
715
716 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
717 "Reference equal: input 0 is not a supported type.");
718
719 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
720 "Reference equal: input 1 is not a supported type.");
721
722 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
723 "Reference equal: input 0 and Input 1 types are mismatched");
724
725 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
726 "Reference equal: shapes are not suitable for implicit broadcast.");
727
728 return supported;
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000729}
730
arovir011c7c81b2018-10-08 11:34:28 +0100731bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
732 const FakeQuantizationDescriptor& descriptor,
733 Optional<std::string&> reasonIfUnsupported) const
734{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100735 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100736 bool supported = true;
737
738 std::array<DataType,1> supportedTypes =
739 {
740 DataType::Float32
741 };
742
743 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
744 "Reference fake quantization: input type not supported.");
745
746 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100747}
748
749bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
750 const TensorInfo& output,
751 Optional<std::string&> reasonIfUnsupported) const
752{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100753 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100754 bool supported = true;
755
James Conroyb40d7102019-06-04 12:32:09 +0100756 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100757 {
James Conroyb40d7102019-06-04 12:32:09 +0100758 DataType::Float32,
759 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100760 };
761
762 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
763 "Reference Floor: input type not supported.");
764
765 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
766 "Reference Floor: output type not supported.");
767
768 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100769}
770
771bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
772 const TensorInfo& output,
773 const TensorInfo& weights,
774 const TensorInfo& biases,
775 const FullyConnectedDescriptor& descriptor,
776 Optional<std::string&> reasonIfUnsupported) const
777{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100778 bool supported = true;
779
780 // Define supported types.
781 std::array<DataType,3> supportedTypes =
782 {
783 DataType::Float32,
784 DataType::QuantisedAsymm8,
785 DataType::QuantisedSymm16
786 };
787
788 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
789 "Reference Fully Connected: input type not supported.");
790
791 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
792 "Reference Fully Connected: output type not supported.");
793
794 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
795 "Reference Fully Connected: input and output types mismatched.");
796
797 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
798 "Reference Fully Connected: weights type not supported.");
799
800 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
801 "Reference Fully Connected: input and weight types mismatched.");
802
803 if (descriptor.m_BiasEnabled)
804 {
805 // Defined supported types for bias
806 std::array<DataType, 2>
807 supportedBiasTypes =
808 {
809 DataType::Float32,
810 DataType::Signed32
811 };
812
813 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
814 "Reference Fully Connected: bias type not supported.");
815
816 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
817 "Reference Fully Connected: bias and weight types mismatch.");
818
819 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
820 "Reference Fully Connected: bias type inferred from weights is incompatible.");
821
822 }
823
824 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100825}
826
narpra014951d842019-01-18 16:53:53 +0000827bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
828 const armnn::TensorInfo& input1,
829 const armnn::TensorInfo& output,
830 armnn::Optional<std::string&> reasonIfUnsupported) const
831{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100832 bool supported = true;
833 std::array<DataType,3> supportedTypes =
834 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100835 DataType::Float32,
836 DataType::QuantisedAsymm8,
837 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100838 };
839
840 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
841 "Reference Gather: input type not supported");
842
843 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
844 "Reference Gather: output type not supported");
845
846 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
847 "Reference Gather: indices (input1) type not supported");
848
849 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
850 "Reference Gather: input and output types not matching");
851
852 return supported;
narpra014951d842019-01-18 16:53:53 +0000853}
854
FrancisMurtagh878f0232018-12-19 10:56:15 +0000855bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
856 const TensorInfo& input1,
857 const TensorInfo& output,
858 Optional<std::string&> reasonIfUnsupported) const
859{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100860 bool supported = true;
861
862 std::array<DataType,3> supportedTypes =
863 {
864 DataType::Float32,
865 DataType::QuantisedAsymm8,
866 DataType::QuantisedSymm16
867 };
868
869 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
870 "Reference greater: input 0 is not a supported type.");
871
872 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
873 "Reference greater: input 1 is not a supported type.");
874
875 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
876 "Reference greater: input 0 and Input 1 types are mismatched");
877
878 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
879 "Reference greater: shapes are not suitable for implicit broadcast.");
880
881 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000882}
883
arovir011c7c81b2018-10-08 11:34:28 +0100884bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
885 Optional<std::string&> reasonIfUnsupported) const
886{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100887 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100888}
889
890bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
891 const TensorInfo& output,
892 const L2NormalizationDescriptor& descriptor,
893 Optional<std::string&> reasonIfUnsupported) const
894{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100895 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100896 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100897 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100898 {
899 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100900 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100901 DataType::QuantisedSymm16
902 };
903
904 bool supported = true;
905
906 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
907 "Reference L2normalization: input type not supported.");
908
909 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
910 "Reference L2normalization: output type not supported.");
911
912 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
913 "Reference L2normalization: input and output types mismatched.");
914
915 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
916 "Reference L2normalization: input and output shapes have different "
917 "num total elements.");
918
919 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100920}
921
922bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
923 const TensorInfo& outputStateIn,
924 const TensorInfo& cellStateIn,
925 const TensorInfo& scratchBuffer,
926 const TensorInfo& outputStateOut,
927 const TensorInfo& cellStateOut,
928 const TensorInfo& output,
929 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100930 const LstmInputParamsInfo& paramsInfo,
931 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +0100932{
telsoa01c577f2c2018-08-31 09:22:23 +0100933 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100934 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100935
936 bool supported = true;
937
938 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100939 DataType::Float32,
940 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100941 };
942
Jan Eilersd01a83c2019-07-03 18:20:40 +0100943 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100944 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
945 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100946 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
947 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100948 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
949 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100950 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
951 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100952 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
953 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100954 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
955 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100956 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
957 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +0100958 // check layer parameters
959 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
960 "Reference Lstm: input and InputToForgetWeights types are mismatched");
961 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
962 "Reference Lstm: input and InputToCellWeights types are mismatched");
963 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
964 "Reference Lstm: input and InputToOutputWeights types are mismatched");
965 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
966 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
967 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
968 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
969 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
970 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
971 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
972 "Reference Lstm: input and ForgetGateBias types are mismatched");
973 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
974 "Reference Lstm: input and CellBias types are mismatched");
975 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
976 "Reference Lstm: input and OutputGateBias types are mismatched");
977 if (!descriptor.m_CifgEnabled)
978 {
979 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
980 "Reference Lstm: input and InputToInputWeights types are mismatched");
981 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
982 reasonIfUnsupported,
983 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
984 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
985 "Reference Lstm: input and InputGateBias types are mismatched");
986 if (descriptor.m_PeepholeEnabled)
987 {
988 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
989 reasonIfUnsupported,
990 "Reference Lstm: input and CellToInputWeights types are mismatched");
991 }
992 }
993 if (descriptor.m_PeepholeEnabled)
994 {
995 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
996 "Reference Lstm: input and CellToForgetWeights types are mismatched");
997 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
998 "Reference Lstm: input and CellToOutputWeights types are mismatched");
999 }
1000 if (descriptor.m_ProjectionEnabled)
1001 {
1002 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
1003 "Reference Lstm: input and mProjectionWeights types are mismatched");
1004 if (paramsInfo.m_ProjectionBias != nullptr)
1005 {
1006 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
1007 "Reference Lstm: input and ProjectionBias types are mismatched");
1008 }
1009 }
1010 if (descriptor.m_LayerNormEnabled)
1011 {
1012 if (!descriptor.m_CifgEnabled)
1013 {
1014 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
1015 reasonIfUnsupported,
1016 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1017 }
1018 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
1019 reasonIfUnsupported,
1020 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1021 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
1022 reasonIfUnsupported,
1023 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1024 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
1025 reasonIfUnsupported,
1026 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1027 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001028
1029 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001030}
1031
saoste012df12b32018-11-28 16:57:20 +00001032bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1033 const TensorInfo& input1,
1034 const TensorInfo& output,
1035 Optional<std::string&> reasonIfUnsupported) const
1036{
Sadik Armagan2999a022019-04-09 14:20:12 +01001037 bool supported = true;
1038
Mike Kelly1da02362019-08-01 08:43:57 +01001039 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001040 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001041 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001042 DataType::QuantisedAsymm8,
1043 DataType::QuantisedSymm16
1044 };
1045
1046 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1047 "Reference maximum: input 0 is not a supported type.");
1048
1049 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1050 "Reference maximum: input 1 is not a supported type.");
1051
1052 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1053 "Reference maximum: output is not a supported type.");
1054
1055 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1056 "Reference maximum: input 0 and Input 1 types are mismatched");
1057
1058 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1059 "Reference maximum: input and output types are mismatched");
1060
1061 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1062 "Reference maximum: shapes are not suitable for implicit broadcast.");
1063
1064 return supported;
saoste012df12b32018-11-28 16:57:20 +00001065}
1066
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // Checks the data type of input/output and that the output rank matches
    // what the reduction described by the descriptor must produce.
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        // Reduced axes are kept as size-1 dimensions, so the rank is unchanged.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        // No axis given: every dimension is reduced, leaving a 1-D result.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        // Each listed axis removes one dimension from the input rank.
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            // All dimensions reduced: the output collapses to rank 1.
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1126
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Merger is the legacy name for Concat; this simply forwards to the
    // Concat support check with identical arguments.
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
1134
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001135bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1136 const TensorInfo &output,
1137 Optional<std::string &> reasonIfUnsupported) const
1138{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001139 bool supported = true;
1140
1141 std::array<DataType,5> supportedTypes =
1142 {
1143 DataType::Float32,
1144 DataType::Float16,
1145 DataType::QuantisedAsymm8,
1146 DataType::QuantisedSymm16,
1147 DataType::Boolean
1148 };
1149
1150 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1151 "Reference MemCopy: input type not supported");
1152
1153 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1154 "Reference MemCopy: output type not supported");
1155
1156 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1157 "Reference MemCopy: input and output types are mismatched");
1158
1159 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001160}
1161
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001162bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1163 const TensorInfo& input1,
1164 const TensorInfo& output,
1165 Optional<std::string&> reasonIfUnsupported) const
1166{
Sadik Armagan2999a022019-04-09 14:20:12 +01001167 bool supported = true;
1168
Mike Kelly1da02362019-08-01 08:43:57 +01001169 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001170 DataType::Float32,
Mike Kelly1da02362019-08-01 08:43:57 +01001171 DataType::Signed32,
Sadik Armagan2999a022019-04-09 14:20:12 +01001172 DataType::QuantisedAsymm8,
1173 DataType::QuantisedSymm16
1174 };
1175
1176 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1177 "Reference minimum: input 0 is not a supported type.");
1178
1179 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1180 "Reference minimum: input 1 is not a supported type.");
1181
1182 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1183 "Reference minimum: output is not a supported type.");
1184
1185 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1186 "Reference minimum: input 0 and Input 1 types are mismatched");
1187
1188 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1189 "Reference minimum: input and output types are mismatched");
1190
1191 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1192 "Reference minimum: shapes are not suitable for implicit broadcast.");
1193
1194 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001195}
1196
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001197bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1198 const TensorInfo& input1,
1199 const TensorInfo& output,
1200 Optional<std::string&> reasonIfUnsupported) const
1201{
Sadik Armagan2999a022019-04-09 14:20:12 +01001202 bool supported = true;
1203
1204 std::array<DataType,3> supportedTypes = {
1205 DataType::Float32,
1206 DataType::QuantisedAsymm8,
1207 DataType::QuantisedSymm16
1208 };
1209
1210 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1211 "Reference multiplication: input 0 is not a supported type.");
1212
1213 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1214 "Reference multiplication: input 1 is not a supported type.");
1215
1216 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1217 "Reference multiplication: output is not a supported type.");
1218
1219 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1220 "Reference multiplication: input 0 and Input 1 types are mismatched");
1221
1222 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1223 "Reference multiplication: input and output types are mismatched");
1224
1225 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1226 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1227
1228 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001229}
1230
1231bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1232 const TensorInfo& output,
1233 const NormalizationDescriptor& descriptor,
1234 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001235{
Nina Drozd661dfa72018-10-02 11:14:17 +01001236 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001237
1238 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001239 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001240 {
1241 DataType::Float16,
1242 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001243 DataType::QuantisedAsymm8,
1244 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001245 };
1246
1247 bool supported = true;
1248
1249 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1250 "Reference normalization: input type not supported.");
1251
1252 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1253 "Reference normalization: output type not supported.");
1254
1255 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1256 "Reference normalization: input and output shapes have different "
1257 "num total elements.");
1258
1259 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001260}
1261
1262bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1263 Optional<std::string&> reasonIfUnsupported) const
1264{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001265 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001266}
1267
1268bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1269 const TensorInfo& output,
1270 const PadDescriptor& descriptor,
1271 Optional<std::string&> reasonIfUnsupported) const
1272{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001273 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001274 bool supported = true;
1275
1276 // Define supported output and inputs types.
1277 std::array<DataType,3> supportedTypes =
1278 {
1279 DataType::Float32,
1280 DataType::QuantisedAsymm8,
1281 DataType::QuantisedSymm16
1282 };
1283
1284 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1285 "Reference pad: input is not a supported type.");
1286
1287 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1288 "Reference pad: output is not a supported type.");
1289
1290 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1291 "Reference pad: input and output types are mismatched.");
1292
1293 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001294}
1295
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001296bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1297 const TensorInfo& output,
1298 const PermuteDescriptor& descriptor,
1299 Optional<std::string&> reasonIfUnsupported) const
1300{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001301 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001302 bool supported = true;
1303
1304 // Define supported output and inputs types.
1305 std::array<DataType,3> supportedTypes =
1306 {
1307 DataType::Float32,
1308 DataType::QuantisedAsymm8,
1309 DataType::QuantisedSymm16
1310 };
1311
1312 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1313 "Reference permute: input is not a supported type.");
1314
1315 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1316 "Reference permute: output is not a supported type.");
1317
1318 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1319 "Reference permute: input and output types are mismatched.");
1320
1321 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001322}
1323
1324bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1325 const TensorInfo& output,
1326 const Pooling2dDescriptor& descriptor,
1327 Optional<std::string&> reasonIfUnsupported) const
1328{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001329 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001330 bool supported = true;
1331
1332 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001333 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001334 {
1335 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001336 DataType::QuantisedAsymm8,
1337 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001338 };
1339
1340 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1341 "Reference poolind2d: input is not a supported type.");
1342
1343 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1344 "Reference poolind2d: output is not a supported type.");
1345
1346 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1347 "Reference poolind2d: input and output types are mismatched.");
1348
1349 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001350}
1351
bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    // Quantize converts a float tensor into one of the quantized types, so the
    // input and output type sets are intentionally different.
    bool supported = true;

    // Define supported input types (comment previously said "output" - the array
    // below constrains the INPUT).
    std::array<DataType,1> supportedInputTypes = {
        DataType::Float32,
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference quantize: input type not supported.");

    // Define supported output types.
    std::array<DataType,2> supportedOutputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };
    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference quantize: output type not supported.");

    // Quantization is element-wise, so the element counts must match.
    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference quantize: input and output shapes have different num total elements.");

    return supported;
}
1379
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001380bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001381 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001382 Optional<std::string&> reasonIfUnsupported) const
1383{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001384 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001385 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001386 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001387 {
1388 DataType::Float32,
1389 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001390 DataType::QuantisedAsymm8,
1391 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001392 };
1393 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1394 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001395}
1396
1397bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001398 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001399 Optional<std::string&> reasonIfUnsupported) const
1400{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001401 bool supported = true;
1402 std::array<DataType,3> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001403 {
1404 DataType::Float32,
1405 DataType::QuantisedAsymm8,
1406 DataType::QuantisedSymm16
1407 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001408
1409 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1410 "Reference ResizeBilinear: input type not supported");
1411
1412 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1413 "Reference ResizeBilinear: output type not supported");
1414
1415 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1416 "Reference ResizeBilinear: input and output types not matching");
1417
1418 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001419}
1420
Teresa Charlin970f43b2019-07-01 13:51:07 +01001421bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1422 const TensorInfo& output,
1423 const ResizeDescriptor& descriptor,
1424 Optional<std::string&> reasonIfUnsupported) const
1425{
1426 bool supported = true;
1427 std::array<DataType,3> supportedTypes =
1428 {
1429 DataType::Float32,
1430 DataType::QuantisedAsymm8,
1431 DataType::QuantisedSymm16
1432 };
1433
1434 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1435 "Reference Resize: input type not supported");
1436
1437 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1438 "Reference Resize: output type not supported");
1439
1440 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1441 "Reference Resize: input and output types not matching");
1442
1443 return supported;
1444}
1445
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001446bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1447 const TensorInfo& output,
1448 Optional<std::string&> reasonIfUnsupported) const
1449{
nikraj010421e7f2019-06-14 09:40:34 +01001450 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001451 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001452 {
1453 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001454 DataType::QuantisedAsymm8,
1455 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001456 };
1457
1458 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1459 "Reference rsqrt: input type not supported");
1460
1461 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1462 "Reference rsqrt: output type not supported");
1463
1464 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1465 "Reference rsqrt: input and output types not matching");
1466
1467 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1468 "Reference Rsqrt: input and output shapes have different number of total elements");
1469
1470 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001471}
1472
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001473bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1474 const TensorInfo& output,
1475 const SoftmaxDescriptor& descriptor,
1476 Optional<std::string&> reasonIfUnsupported) const
1477{
1478 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001479 bool supported = true;
1480 std::array<DataType,3> supportedTypes =
1481 {
1482 DataType::Float32,
1483 DataType::QuantisedAsymm8,
1484 DataType::QuantisedSymm16
1485 };
1486
1487 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1488 "Reference concatenation: output type not supported");
1489
1490 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1491 "Reference concatenation: input type not supported");
1492
1493 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1494 "Reference concatenation: input type not supported");
1495
1496 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001497}
1498
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001499bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1500 const TensorInfo& output,
1501 const SpaceToBatchNdDescriptor& descriptor,
1502 Optional<std::string&> reasonIfUnsupported) const
1503{
1504 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001505 bool supported = true;
1506 std::array<DataType,3> supportedTypes =
1507 {
1508 DataType::Float32,
1509 DataType::QuantisedAsymm8,
1510 DataType::QuantisedSymm16
1511 };
1512
1513 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1514 "Reference SpaceToBatchNd: input type not supported");
1515
1516 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1517 "Reference SpaceToBatchNd: output type not supported");
1518
1519 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1520 "Reference SpaceToBatchNd: input and output types are mismatched");
1521
1522 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001523}
1524
Keith Davisa57eccb2019-06-14 17:33:22 +01001525bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001526 const TensorInfo& output,
1527 const SpaceToDepthDescriptor& descriptor,
1528 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001529{
1530
1531 ignore_unused(descriptor);
1532 bool supported = true;
1533
James Conroyd2aa85e2019-07-01 17:12:40 +01001534 std::array<DataType,3> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001535 {
1536 DataType::Float32,
1537 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001538 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001539 };
1540
1541 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1542 "Reference SpaceToDepth: input type not supported");
1543
1544 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1545 "Reference SpaceToDepth: output type not supported");
1546
1547 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1548 "Reference SpaceToDepth: input and output types are mismatched");
1549
1550 return supported;
1551}
1552
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001553bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1554 const ViewsDescriptor& descriptor,
1555 Optional<std::string&> reasonIfUnsupported) const
1556{
1557 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001558 bool supported = true;
1559 std::array<DataType,3> supportedTypes =
1560 {
1561 DataType::Float32,
1562 DataType::QuantisedAsymm8,
1563 DataType::QuantisedSymm16
1564 };
1565
1566 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1567 "Reference splitter: input type not supported");
1568
1569 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001570}
1571
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001572bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1573 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1574 const ViewsDescriptor& descriptor,
1575 Optional<std::string&> reasonIfUnsupported) const
1576{
1577 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001578 bool supported = true;
1579 std::array<DataType,3> supportedTypes =
1580 {
1581 DataType::Float32,
1582 DataType::QuantisedAsymm8,
1583 DataType::QuantisedSymm16
1584 };
1585
1586 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1587 "Reference splitter: output type not supported");
1588 for (const TensorInfo output : outputs)
1589 {
1590 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1591 "Reference splitter: input type not supported");
1592
1593 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1594 "Reference splitter: input and output types mismatched.");
1595 }
1596
1597 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001598}
1599
Matthew Jackson81e601c2019-07-11 12:07:09 +01001600bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1601 const TensorInfo& output,
1602 const StackDescriptor& descriptor,
1603 Optional<std::string&> reasonIfUnsupported) const
1604{
1605 ignore_unused(descriptor);
1606
1607 bool supported = true;
1608 std::array<DataType,3> supportedTypes =
1609 {
1610 DataType::Float32,
1611 DataType::QuantisedAsymm8,
1612 DataType::QuantisedSymm16
1613 };
1614
1615 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1616 "Reference stack: output type not supported");
1617 for (const TensorInfo* input : inputs)
1618 {
1619 BOOST_ASSERT(input != nullptr);
1620 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1621 "Reference stack: input type not supported");
1622
1623 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1624 "Reference stack: input and output types mismatched.");
1625 }
1626
1627 return supported;
1628}
1629
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001630bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1631 const TensorInfo& output,
1632 const StridedSliceDescriptor& descriptor,
1633 Optional<std::string&> reasonIfUnsupported) const
1634{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001635 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001636 bool supported = true;
1637
1638 std::array<DataType,3> supportedTypes =
1639 {
1640 DataType::Float32,
1641 DataType::QuantisedAsymm8,
1642 DataType::QuantisedSymm16
1643 };
1644
1645 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1646 "Reference StridedSlice: input type not supported");
1647
1648 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1649 "Reference StridedSlice: output type not supported");
1650
1651 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1652 "Reference StridedSlice: input and output types are mismatched");
1653
1654 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001655}
1656
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001657bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1658 const TensorInfo& input1,
1659 const TensorInfo& output,
1660 Optional<std::string&> reasonIfUnsupported) const
1661{
Sadik Armagan2999a022019-04-09 14:20:12 +01001662 bool supported = true;
1663
1664 std::array<DataType,3> supportedTypes = {
1665 DataType::Float32,
1666 DataType::QuantisedAsymm8,
1667 DataType::QuantisedSymm16
1668 };
1669
1670 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1671 "Reference subtraction: input 0 is not a supported type.");
1672
1673 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1674 "Reference subtraction: input 1 is not a supported type.");
1675
1676 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1677 "Reference subtraction: output is not a supported type.");
1678
1679 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1680 "Reference subtraction: input 0 and Input 1 types are mismatched");
1681
1682 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1683 "Reference subtraction: input and output types are mismatched");
1684
1685 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1686 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1687
1688 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001689}
1690
Matteo Martincighab9e5252019-06-13 17:27:46 +01001691bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1692 const TensorInfo& alpha,
1693 const TensorInfo& output,
1694 Optional<std::string&> reasonIfUnsupported) const
1695{
1696 bool supported = true;
1697
1698 std::array<DataType, 3> supportedTypes
1699 {
1700 DataType::Float32,
1701 DataType::QuantisedAsymm8,
1702 DataType::QuantisedSymm16
1703 };
1704
1705 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1706 "PReLU: input is not a supported type.");
1707
1708 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1709 "PReLU: alpha is not a supported type.");
1710
1711 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1712 "PReLU: output is not a supported type.");
1713
1714 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1715 "PReLU: input, alpha and output types are mismatched");
1716
1717 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1718 "PReLU: shapes are not suitable for implicit broadcast");
1719
1720 return supported;
1721}
1722
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001723bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1724 const TensorInfo& output,
1725 const TransposeConvolution2dDescriptor& descriptor,
1726 const TensorInfo& weights,
1727 const Optional<TensorInfo>& biases,
1728 Optional<std::string&> reasonIfUnsupported) const
1729{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001730 bool supported = true;
1731
1732 std::array<DataType,3> supportedTypes =
1733 {
1734 DataType::Float32,
1735 DataType::QuantisedAsymm8,
1736 DataType::QuantisedSymm16
1737 };
1738
1739 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1740 "Reference TransposeConvolution2d: input is not a supported type.");
1741
1742 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1743 "Reference TransposeConvolution2d: output is not a supported type.");
1744
1745 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1746 "Reference TransposeConvolution2d: weights is not a supported type.");
1747
1748 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1749 "Reference TransposeConvolution2d: input and output types mismatched.");
1750
1751 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1752 "Reference TransposeConvolution2d: input and weights types mismatched.");
1753
1754 if (biases.has_value())
1755 {
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001756 std::array<DataType,3> biasesSupportedTypes =
1757 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001758 DataType::Float32,
1759 DataType::Signed32
1760 };
1761 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1762 "Reference TransposeConvolution2d: biases is not a supported type.");
1763 }
1764
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001765 // NOTE: Temporary restriction; should be removed as soon as support for channel
1766 // multiplier different from 1 (input channels != output channels) has been added
1767 struct ChannelsAreEqual : public Rule
1768 {
1769 ChannelsAreEqual(const TensorInfo& input,
1770 const TensorInfo& output,
1771 const TransposeConvolution2dDescriptor& descriptor)
1772 {
1773 armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
1774 const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
1775
1776 m_Res = (input.GetShape()[channelsIndex] == output.GetShape()[channelsIndex]);
1777 }
1778 };
1779
1780 supported &= CheckSupportRule(ChannelsAreEqual(input, output, descriptor), reasonIfUnsupported,
1781 "Reference TransposeConvolution2d: inputChannels != outputChannels");
1782
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001783 return supported;
1784}
1785
arovir011c7c81b2018-10-08 11:34:28 +01001786} // namespace armnn