blob: 59c14c44900cdb82016c68334677acc4c6cca337 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Convenience wrapper around IsSupportedForDataTypeGeneric for the reference
// backend: only the Float32 and QuantisedAsymm8 checks are forwarded to the
// caller-supplied functors; every other data type is rejected outright via
// FalseFunc.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    // Functor slots appear to be, in order: Float16, Float32, Uint8, Int32,
    // Boolean -- NOTE(review): confirm against LayerSupportCommon.hpp.
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
namespace
{

// Builds the diagnostic emitted when a tensor has the wrong rank, e.g.
// "Reference batchToSpaceNd: Expected 4 dimensions but got 2 dimensions
// instead, for the 'output' tensor."
//
// @param expected   Rank the layer requires.
// @param actual     Rank of the tensor that was actually supplied.
// @param layerStr   Human-readable layer name spliced into the message.
// @param tensorName Name of the offending tensor ("input", "output", ...).
// @return The formatted error message.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    // Parameters are now const& (they were non-const std::string&): the
    // strings are never modified, and this also lets callers pass
    // temporaries or literals. Backward compatible for existing callers.
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
66namespace
67{
68template<typename F>
69bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
70{
71 bool supported = rule();
72 if (!supported && reason)
73 {
74 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
75 }
76 return supported;
77}
78
79struct Rule
80{
81 bool operator()() const
82 {
83 return m_Res;
84 }
85
86 bool m_Res = true;
87};
88
Derek Lamberti2a434a82019-03-20 13:07:57 +000089template<typename T>
90bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000091{
92 return true;
93}
94
95template<typename T, typename... Rest>
96bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
97{
98 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
99
Derek Lamberti2a434a82019-03-20 13:07:57 +0000100 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +0000101}
102
103struct TypesAreEqual : public Rule
104{
105 template<typename ... Ts>
106 TypesAreEqual(const Ts&... ts)
107 {
108 m_Res = AllTypesAreEqualImpl(ts...);
109 }
110};
111
112struct QuantizationParametersAreEqual : public Rule
113{
114 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
115 {
116 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
117 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
118 }
119};
120
121struct TypeAnyOf : public Rule
122{
123 template<typename Container>
124 TypeAnyOf(const TensorInfo& info, const Container& c)
125 {
126 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100127 {
128 return dt == info.GetDataType();
129 });
130 }
131};
132
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100133struct TypeIs : public Rule
134{
135 TypeIs(const TensorInfo& info, DataType dt)
136 {
137 m_Res = dt == info.GetDataType();
138 }
139};
140
Francis Murtagh46c09d02019-05-28 08:15:28 +0100141struct BiasAndWeightsTypesMatch : public Rule
142{
143 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
144 {
145 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
146 }
147};
148
149struct BiasAndWeightsTypesCompatible : public Rule
150{
151 template<typename Container>
152 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
153 {
154 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
155 {
156 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
157 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000158 }
159};
160
161struct ShapesAreSameRank : public Rule
162{
163 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
164 {
165 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
166 }
167};
168
Derek Lamberti5f400d62019-03-25 15:41:58 +0000169struct ShapesAreSameTotalSize : public Rule
170{
171 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
172 {
173 m_Res = info0.GetNumElements() == info1.GetNumElements();
174 }
175};
176
Derek Lamberti50db4e82019-03-13 14:16:15 +0000177struct ShapesAreBroadcastCompatible : public Rule
178{
179 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
180 {
181 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
182 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
183 return sizeIn;
184 }
185
186 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
187 {
188 const TensorShape& shape0 = in0.GetShape();
189 const TensorShape& shape1 = in1.GetShape();
190 const TensorShape& outShape = out.GetShape();
191
192 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
193 {
194 unsigned int sizeOut = outShape[i];
195 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
196 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
197
198 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
199 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
200 }
201 }
202};
James Conroy4d1ff582019-06-10 17:06:39 +0100203
204struct TensorNumDimensionsAreCorrect : public Rule
205{
206 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
207 {
208 m_Res = info.GetNumDimensions() == expectedNumDimensions;
209 }
210};
211
Derek Lamberti50db4e82019-03-13 14:16:15 +0000212} // namespace
213
214
arovir011c7c81b2018-10-08 11:34:28 +0100215bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
216 const TensorInfo& output,
217 const ActivationDescriptor& descriptor,
218 Optional<std::string&> reasonIfUnsupported) const
219{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000220 bool supported = true;
221
222 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100223 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000224 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100225 DataType::QuantisedAsymm8,
226 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000227 };
228
229 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
230 "Reference activation: input type not supported.");
231
232 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
233 "Reference activation: output type not supported.");
234
235 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
236 "Reference activation: input and output types mismatched.");
237
238 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
239 "Reference activation: input and output shapes are of different rank.");
240
241
242 struct ActivationFunctionSupported : public Rule
243 {
244 ActivationFunctionSupported(const ActivationDescriptor& desc)
245 {
246 switch(desc.m_Function)
247 {
248 case ActivationFunction::Abs:
249 case ActivationFunction::BoundedReLu:
250 case ActivationFunction::LeakyReLu:
251 case ActivationFunction::Linear:
252 case ActivationFunction::ReLu:
253 case ActivationFunction::Sigmoid:
254 case ActivationFunction::SoftReLu:
255 case ActivationFunction::Sqrt:
256 case ActivationFunction::Square:
257 case ActivationFunction::TanH:
258 {
259 m_Res = true;
260 break;
261 }
262 default:
263 {
264 m_Res = false;
265 break;
266 }
267 }
268 }
269 };
270
271 // Function is supported
272 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
273 "Reference activation: function not supported.");
274
275 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100276}
277
278bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
279 const TensorInfo& input1,
280 const TensorInfo& output,
281 Optional<std::string&> reasonIfUnsupported) const
282{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000283 bool supported = true;
284
Sadik Armagan2999a022019-04-09 14:20:12 +0100285 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000286 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100287 DataType::QuantisedAsymm8,
288 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000289 };
290
291 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
292 "Reference addition: input 0 is not a supported type.");
293
294 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
295 "Reference addition: input 1 is not a supported type.");
296
297 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
298 "Reference addition: output is not a supported type.");
299
300 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
301 "Reference addition: input 0 and Input 1 types are mismatched");
302
303 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
304 "Reference addition: input and output types are mismatched");
305
306 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
307 "Reference addition: shapes are not suitable for implicit broadcast.");
308
309 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100310}
311
312bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
313 const TensorInfo& output,
314 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100315 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100316 const TensorInfo& beta,
317 const TensorInfo& gamma,
318 const BatchNormalizationDescriptor& descriptor,
319 Optional<std::string&> reasonIfUnsupported) const
320{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100321 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100322
Matteo Martincighf5507132019-06-04 10:59:47 +0100323 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100324 {
325 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100326 DataType::QuantisedAsymm8,
327 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100328 };
329
330 bool supported = true;
331
332 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
333 "Reference batch normalization: input is not a supported type.");
334
335 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
336 "Reference batch normalization: output is not a supported type.");
337
338 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
339 "Reference batch normalization: input and output types are mismatched");
340
341 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
342 "Reference batch normalization: mean is not a supported type.");
343
344 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
345 "Reference batch normalization: variance is not a supported type.");
346
347 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
348 "Reference batch normalization: beta is not a supported type.");
349
350 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
351 "Reference batch normalization: gamma is not a supported type.");
352
353 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100354}
355
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000356bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
357 const TensorInfo& output,
358 const BatchToSpaceNdDescriptor& descriptor,
359 Optional<std::string&> reasonIfUnsupported) const
360{
361 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100362
363 bool supported = true;
364
365 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
366 std::string inputTensorStr = "input";
367 std::string outputTensorStr = "output";
368
369 // Define supported types.
370 std::array<DataType,3> supportedTypes =
371 {
372 DataType::Float32,
373 DataType::QuantisedAsymm8,
374 DataType::QuantisedSymm16
375 };
376
377 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
378 "Reference BatchToSpaceNd: input type not supported.");
379
380 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381 "Reference BatchToSpaceNd: output type not supported.");
382
383 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
384 "Reference BatchToSpaceNd: input and output types mismatched.");
385
386 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
387 reasonIfUnsupported,
388 CreateIncorrectDimensionsErrorMsg(4,
389 output.GetNumDimensions(),
390 batchToSpaceNdLayerStr,
391 outputTensorStr).data());
392
393 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
394 reasonIfUnsupported,
395 CreateIncorrectDimensionsErrorMsg(4,
396 input.GetNumDimensions(),
397 batchToSpaceNdLayerStr,
398 inputTensorStr).data());
399
400 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000401}
402
Jim Flynn906f9462019-05-10 13:55:21 +0100403bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
404 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100405 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100406 Optional<std::string&> reasonIfUnsupported) const
407{
Jim Flynne242f2d2019-05-22 14:24:13 +0100408 ignore_unused(descriptor);
409
410 bool supported = true;
411 std::array<DataType,3> supportedTypes =
412 {
413 DataType::Float32,
414 DataType::QuantisedAsymm8,
415 DataType::QuantisedSymm16
416 };
417
418 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
419 "Reference concatenation: output type not supported");
420 for (const TensorInfo* input : inputs)
421 {
422 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
423 "Reference concatenation: input type not supported");
424
425 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
426 "Reference concatenation: input and output types mismatched.");
427 }
428
429 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100430}
431
arovir011c7c81b2018-10-08 11:34:28 +0100432bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
433 Optional<std::string&> reasonIfUnsupported) const
434{
Jim Flynne242f2d2019-05-22 14:24:13 +0100435 std::array<DataType,4> supportedTypes =
436 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100437 DataType::Float32,
438 DataType::Signed32,
439 DataType::QuantisedAsymm8,
440 DataType::QuantisedSymm16
441 };
442
443 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
444 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100445}
446
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Input must be Float16 and output Float32; every other combination is
    // rejected. Functor slots appear to map to Float16, Float32, Uint8,
    // Int32, Boolean in that order -- NOTE(review): confirm against
    // LayerSupportCommon.hpp before reordering.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
466
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror image of IsConvertFp16ToFp32Supported: input must be Float32
    // and output Float16. Functor slot order as noted there -- confirm
    // against LayerSupportCommon.hpp before reordering.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
486
487bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
488 const TensorInfo& output,
489 const Convolution2dDescriptor& descriptor,
490 const TensorInfo& weights,
491 const Optional<TensorInfo>& biases,
492 Optional<std::string&> reasonIfUnsupported) const
493{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100494 bool supported = true;
495
496 // Define supported types.
497 std::array<DataType,3> supportedTypes = {
498 DataType::Float32,
499 DataType::QuantisedAsymm8,
500 DataType::QuantisedSymm16
501 };
502
503 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100504 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100505
506 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100507 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100508
509 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100510 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100511
512 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100513 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100514
515 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100516 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100517
518 if (biases.has_value())
519 {
520 std::array<DataType,3> biasesSupportedTypes = {
521 DataType::Float32,
522 DataType::Signed32
523 };
524 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100525 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100526 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100527 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100528
529 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100530}
531
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000532bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
533 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000534 Optional<std::string&> reasonIfUnsupported) const
535{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100536 bool supported = true;
537
538 std::array<DataType,3> supportedTypes =
539 {
540 DataType::Float32,
541 DataType::QuantisedAsymm8,
542 DataType::QuantisedSymm16
543 };
544
545 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
546 "Reference debug: input type not supported");
547
548 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
549 "Reference debug: output type not supported");
550
551 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
552 "Reference debug: input and output types are mismatched");
553
554 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000555}
556
arovir011c7c81b2018-10-08 11:34:28 +0100557bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
558 const TensorInfo& output,
559 const DepthwiseConvolution2dDescriptor& descriptor,
560 const TensorInfo& weights,
561 const Optional<TensorInfo>& biases,
562 Optional<std::string&> reasonIfUnsupported) const
563{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100564 bool supported = true;
565
566 // Define supported types.
567 std::array<DataType,3> supportedTypes =
568 {
569 DataType::Float32,
570 DataType::QuantisedAsymm8,
571 DataType::QuantisedSymm16
572 };
573
574 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
575 "Reference DepthwiseConvolution2d: input is not a supported type.");
576
577 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
578 "Reference DepthwiseConvolution2d: output is not a supported type.");
579
580 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
581 "Reference DepthwiseConvolution2d: weights is not a supported type.");
582
583 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
584 "Reference DepthwiseConvolution2d: input and output types mismatched.");
585
586 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
587 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
588
589 if (biases.has_value())
590 {
591 std::array<DataType,2> biasesSupportedTypes =
592 {
593 DataType::Float32,
594 DataType::Signed32
595 };
596 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
597 "Reference DepthwiseConvolution2d: biases is not a supported type.");
598 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100599 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100600
601 return supported;
602
arovir011c7c81b2018-10-08 11:34:28 +0100603}
604
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000605bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
606 const TensorInfo& output,
607 Optional<std::string&> reasonIfUnsupported) const
608{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100609 bool supported = true;
610
611 std::array<DataType,2> supportedInputTypes = {
612 DataType::QuantisedAsymm8,
613 DataType::QuantisedSymm16
614 };
615
616 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
617 "Reference dequantize: input type not supported.");
618
619 std::array<DataType,2> supportedOutputTypes = {
620 DataType::Float32,
621 };
622
623 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
624 "Reference dequantize: output type not supported.");
625
626 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
627 "Reference dequantize: input and output shapes have different num total elements.");
628
629 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000630}
631
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000632bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
633 const armnn::TensorInfo& input1,
634 const armnn::DetectionPostProcessDescriptor& descriptor,
635 armnn::Optional<std::string&> reasonIfUnsupported) const
636{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100637 bool supported = true;
638
639 std::vector<DataType> supportedInputTypes =
640 {
641 DataType::Float32,
642 DataType::QuantisedAsymm8,
643 DataType::QuantisedSymm16
644 };
645
646 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
647 "Reference DetectionPostProcess: input 0 is not a supported type.");
648
649 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
650 "Reference DetectionPostProcess: input 1 is not a supported type.");
651
652 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000653}
654
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // The dilated case is delegated wholesale: the reference backend applies
    // the same type/shape rules as plain depthwise convolution.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
664
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100665bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100666 const TensorInfo& input1,
667 const TensorInfo& output,
668 Optional<std::string&> reasonIfUnsupported) const
669{
Sadik Armagan2999a022019-04-09 14:20:12 +0100670 bool supported = true;
671
672 std::array<DataType,3> supportedTypes = {
673 DataType::Float32,
674 DataType::QuantisedAsymm8,
675 DataType::QuantisedSymm16
676 };
677
678 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
679 "Reference division: input 0 is not a supported type.");
680
681 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
682 "Reference division: input 1 is not a supported type.");
683
684 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
685 "Reference division: output is not a supported type.");
686
687 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
688 "Reference division: input 0 and Input 1 types are mismatched");
689
690 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
691 "Reference division: input and output types are mismatched");
692
693 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
694 "Reference division: shapes are not suitable for implicit broadcast.");
695
696 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100697}
698
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000699bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
700 const TensorInfo& input1,
701 const TensorInfo& output,
702 Optional<std::string&> reasonIfUnsupported) const
703{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100704 bool supported = true;
705
706 std::array<DataType,3> supportedTypes =
707 {
708 DataType::Float32,
709 DataType::QuantisedAsymm8,
710 DataType::QuantisedSymm16
711 };
712
713 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
714 "Reference equal: input 0 is not a supported type.");
715
716 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
717 "Reference equal: input 1 is not a supported type.");
718
719 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
720 "Reference equal: input 0 and Input 1 types are mismatched");
721
722 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
723 "Reference equal: shapes are not suitable for implicit broadcast.");
724
725 return supported;
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000726}
727
arovir011c7c81b2018-10-08 11:34:28 +0100728bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
729 const FakeQuantizationDescriptor& descriptor,
730 Optional<std::string&> reasonIfUnsupported) const
731{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100732 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100733 bool supported = true;
734
735 std::array<DataType,1> supportedTypes =
736 {
737 DataType::Float32
738 };
739
740 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
741 "Reference fake quantization: input type not supported.");
742
743 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100744}
745
746bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
747 const TensorInfo& output,
748 Optional<std::string&> reasonIfUnsupported) const
749{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100750 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100751 bool supported = true;
752
James Conroyb40d7102019-06-04 12:32:09 +0100753 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100754 {
James Conroyb40d7102019-06-04 12:32:09 +0100755 DataType::Float32,
756 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100757 };
758
759 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
760 "Reference Floor: input type not supported.");
761
762 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
763 "Reference Floor: output type not supported.");
764
765 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100766}
767
768bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
769 const TensorInfo& output,
770 const TensorInfo& weights,
771 const TensorInfo& biases,
772 const FullyConnectedDescriptor& descriptor,
773 Optional<std::string&> reasonIfUnsupported) const
774{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100775 bool supported = true;
776
777 // Define supported types.
778 std::array<DataType,3> supportedTypes =
779 {
780 DataType::Float32,
781 DataType::QuantisedAsymm8,
782 DataType::QuantisedSymm16
783 };
784
785 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
786 "Reference Fully Connected: input type not supported.");
787
788 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
789 "Reference Fully Connected: output type not supported.");
790
791 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
792 "Reference Fully Connected: input and output types mismatched.");
793
794 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
795 "Reference Fully Connected: weights type not supported.");
796
797 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
798 "Reference Fully Connected: input and weight types mismatched.");
799
800 if (descriptor.m_BiasEnabled)
801 {
802 // Defined supported types for bias
803 std::array<DataType, 2>
804 supportedBiasTypes =
805 {
806 DataType::Float32,
807 DataType::Signed32
808 };
809
810 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
811 "Reference Fully Connected: bias type not supported.");
812
813 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
814 "Reference Fully Connected: bias and weight types mismatch.");
815
816 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
817 "Reference Fully Connected: bias type inferred from weights is incompatible.");
818
819 }
820
821 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100822}
823
narpra014951d842019-01-18 16:53:53 +0000824bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
825 const armnn::TensorInfo& input1,
826 const armnn::TensorInfo& output,
827 armnn::Optional<std::string&> reasonIfUnsupported) const
828{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100829 bool supported = true;
830 std::array<DataType,3> supportedTypes =
831 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100832 DataType::Float32,
833 DataType::QuantisedAsymm8,
834 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100835 };
836
837 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
838 "Reference Gather: input type not supported");
839
840 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
841 "Reference Gather: output type not supported");
842
843 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
844 "Reference Gather: indices (input1) type not supported");
845
846 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
847 "Reference Gather: input and output types not matching");
848
849 return supported;
narpra014951d842019-01-18 16:53:53 +0000850}
851
FrancisMurtagh878f0232018-12-19 10:56:15 +0000852bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
853 const TensorInfo& input1,
854 const TensorInfo& output,
855 Optional<std::string&> reasonIfUnsupported) const
856{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100857 bool supported = true;
858
859 std::array<DataType,3> supportedTypes =
860 {
861 DataType::Float32,
862 DataType::QuantisedAsymm8,
863 DataType::QuantisedSymm16
864 };
865
866 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
867 "Reference greater: input 0 is not a supported type.");
868
869 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
870 "Reference greater: input 1 is not a supported type.");
871
872 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
873 "Reference greater: input 0 and Input 1 types are mismatched");
874
875 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
876 "Reference greater: shapes are not suitable for implicit broadcast.");
877
878 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000879}
880
// Input layers simply expose graph inputs to the backend, so the reference
// backend accepts any input tensor unconditionally.
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}
886
887bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
888 const TensorInfo& output,
889 const L2NormalizationDescriptor& descriptor,
890 Optional<std::string&> reasonIfUnsupported) const
891{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100892 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100893 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100894 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100895 {
896 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100897 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100898 DataType::QuantisedSymm16
899 };
900
901 bool supported = true;
902
903 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
904 "Reference L2normalization: input type not supported.");
905
906 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
907 "Reference L2normalization: output type not supported.");
908
909 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
910 "Reference L2normalization: input and output types mismatched.");
911
912 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
913 "Reference L2normalization: input and output shapes have different "
914 "num total elements.");
915
916 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100917}
918
919bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
920 const TensorInfo& outputStateIn,
921 const TensorInfo& cellStateIn,
922 const TensorInfo& scratchBuffer,
923 const TensorInfo& outputStateOut,
924 const TensorInfo& cellStateOut,
925 const TensorInfo& output,
926 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100927 const LstmInputParamsInfo& paramsInfo,
928 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +0100929{
telsoa01c577f2c2018-08-31 09:22:23 +0100930 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100931 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100932
933 bool supported = true;
934
935 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100936 DataType::Float32,
937 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100938 };
939
Jan Eilersd01a83c2019-07-03 18:20:40 +0100940 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100941 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
942 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100943 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
944 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100945 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
946 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100947 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
948 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100949 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
950 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100951 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
952 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100953 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
954 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +0100955 // check layer parameters
956 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
957 "Reference Lstm: input and InputToForgetWeights types are mismatched");
958 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
959 "Reference Lstm: input and InputToCellWeights types are mismatched");
960 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
961 "Reference Lstm: input and InputToOutputWeights types are mismatched");
962 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
963 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
964 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
965 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
966 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
967 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
968 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
969 "Reference Lstm: input and ForgetGateBias types are mismatched");
970 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
971 "Reference Lstm: input and CellBias types are mismatched");
972 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
973 "Reference Lstm: input and OutputGateBias types are mismatched");
974 if (!descriptor.m_CifgEnabled)
975 {
976 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
977 "Reference Lstm: input and InputToInputWeights types are mismatched");
978 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
979 reasonIfUnsupported,
980 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
981 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
982 "Reference Lstm: input and InputGateBias types are mismatched");
983 if (descriptor.m_PeepholeEnabled)
984 {
985 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
986 reasonIfUnsupported,
987 "Reference Lstm: input and CellToInputWeights types are mismatched");
988 }
989 }
990 if (descriptor.m_PeepholeEnabled)
991 {
992 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
993 "Reference Lstm: input and CellToForgetWeights types are mismatched");
994 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
995 "Reference Lstm: input and CellToOutputWeights types are mismatched");
996 }
997 if (descriptor.m_ProjectionEnabled)
998 {
999 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
1000 "Reference Lstm: input and mProjectionWeights types are mismatched");
1001 if (paramsInfo.m_ProjectionBias != nullptr)
1002 {
1003 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
1004 "Reference Lstm: input and ProjectionBias types are mismatched");
1005 }
1006 }
1007 if (descriptor.m_LayerNormEnabled)
1008 {
1009 if (!descriptor.m_CifgEnabled)
1010 {
1011 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
1012 reasonIfUnsupported,
1013 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1014 }
1015 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
1016 reasonIfUnsupported,
1017 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1018 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
1019 reasonIfUnsupported,
1020 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1021 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
1022 reasonIfUnsupported,
1023 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1024 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001025
1026 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001027}
1028
saoste012df12b32018-11-28 16:57:20 +00001029bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1030 const TensorInfo& input1,
1031 const TensorInfo& output,
1032 Optional<std::string&> reasonIfUnsupported) const
1033{
Sadik Armagan2999a022019-04-09 14:20:12 +01001034 bool supported = true;
1035
1036 std::array<DataType,3> supportedTypes = {
1037 DataType::Float32,
1038 DataType::QuantisedAsymm8,
1039 DataType::QuantisedSymm16
1040 };
1041
1042 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1043 "Reference maximum: input 0 is not a supported type.");
1044
1045 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1046 "Reference maximum: input 1 is not a supported type.");
1047
1048 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1049 "Reference maximum: output is not a supported type.");
1050
1051 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1052 "Reference maximum: input 0 and Input 1 types are mismatched");
1053
1054 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1055 "Reference maximum: input and output types are mismatched");
1056
1057 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1058 "Reference maximum: shapes are not suitable for implicit broadcast.");
1059
1060 return supported;
saoste012df12b32018-11-28 16:57:20 +00001061}
1062
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001063bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1064 const TensorInfo& output,
1065 const MeanDescriptor& descriptor,
1066 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001067{
James Conroy4d1ff582019-06-10 17:06:39 +01001068 bool supported = true;
1069 std::string meanLayerStr = "Mean";
1070 std::string outputTensorStr = "output";
1071
James Conroyb80775f2019-06-11 11:25:30 +01001072 std::array<DataType,3> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001073 {
1074 DataType::Float32,
James Conroyb80775f2019-06-11 11:25:30 +01001075 DataType::QuantisedAsymm8,
1076 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +01001077 };
1078
1079 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1080 "Reference Mean: input type not supported.");
1081
1082 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1083 "Reference Mean: input and output types are mismatched");
1084
1085 if (descriptor.m_KeepDims)
1086 {
1087 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1088 reasonIfUnsupported,
1089 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1090 output.GetNumDimensions(),
1091 meanLayerStr, outputTensorStr).data());
1092 }
1093 else if (descriptor.m_Axis.empty())
1094 {
1095 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1096 reasonIfUnsupported,
1097 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1098 meanLayerStr, outputTensorStr).data());
1099 }
1100 else
1101 {
1102 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1103
1104 if (outputDim > 0)
1105 {
1106 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1107 reasonIfUnsupported,
1108 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1109 meanLayerStr, outputTensorStr).data());
1110 }
1111 else
1112 {
1113 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1114 reasonIfUnsupported,
1115 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1116 meanLayerStr, outputTensorStr).data());
1117 }
1118 }
1119
1120 return supported;
narpra0132b90462018-09-13 11:07:48 +01001121}
1122
// Merger is the legacy name for Concat; delegate to the Concat check so the
// two code paths cannot drift apart.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
1130
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001131bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1132 const TensorInfo &output,
1133 Optional<std::string &> reasonIfUnsupported) const
1134{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001135 bool supported = true;
1136
1137 std::array<DataType,5> supportedTypes =
1138 {
1139 DataType::Float32,
1140 DataType::Float16,
1141 DataType::QuantisedAsymm8,
1142 DataType::QuantisedSymm16,
1143 DataType::Boolean
1144 };
1145
1146 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1147 "Reference MemCopy: input type not supported");
1148
1149 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1150 "Reference MemCopy: output type not supported");
1151
1152 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1153 "Reference MemCopy: input and output types are mismatched");
1154
1155 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001156}
1157
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001158bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1159 const TensorInfo& input1,
1160 const TensorInfo& output,
1161 Optional<std::string&> reasonIfUnsupported) const
1162{
Sadik Armagan2999a022019-04-09 14:20:12 +01001163 bool supported = true;
1164
1165 std::array<DataType,3> supportedTypes = {
1166 DataType::Float32,
1167 DataType::QuantisedAsymm8,
1168 DataType::QuantisedSymm16
1169 };
1170
1171 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1172 "Reference minimum: input 0 is not a supported type.");
1173
1174 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1175 "Reference minimum: input 1 is not a supported type.");
1176
1177 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1178 "Reference minimum: output is not a supported type.");
1179
1180 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1181 "Reference minimum: input 0 and Input 1 types are mismatched");
1182
1183 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1184 "Reference minimum: input and output types are mismatched");
1185
1186 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1187 "Reference minimum: shapes are not suitable for implicit broadcast.");
1188
1189 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001190}
1191
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001192bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1193 const TensorInfo& input1,
1194 const TensorInfo& output,
1195 Optional<std::string&> reasonIfUnsupported) const
1196{
Sadik Armagan2999a022019-04-09 14:20:12 +01001197 bool supported = true;
1198
1199 std::array<DataType,3> supportedTypes = {
1200 DataType::Float32,
1201 DataType::QuantisedAsymm8,
1202 DataType::QuantisedSymm16
1203 };
1204
1205 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1206 "Reference multiplication: input 0 is not a supported type.");
1207
1208 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1209 "Reference multiplication: input 1 is not a supported type.");
1210
1211 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1212 "Reference multiplication: output is not a supported type.");
1213
1214 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1215 "Reference multiplication: input 0 and Input 1 types are mismatched");
1216
1217 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1218 "Reference multiplication: input and output types are mismatched");
1219
1220 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1221 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1222
1223 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001224}
1225
1226bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1227 const TensorInfo& output,
1228 const NormalizationDescriptor& descriptor,
1229 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001230{
Nina Drozd661dfa72018-10-02 11:14:17 +01001231 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001232
1233 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001234 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001235 {
1236 DataType::Float16,
1237 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001238 DataType::QuantisedAsymm8,
1239 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001240 };
1241
1242 bool supported = true;
1243
1244 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1245 "Reference normalization: input type not supported.");
1246
1247 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1248 "Reference normalization: output type not supported.");
1249
1250 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1251 "Reference normalization: input and output shapes have different "
1252 "num total elements.");
1253
1254 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001255}
1256
// Output layers simply hand graph results back to the caller, so the
// reference backend accepts any output tensor unconditionally.
bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}
1262
1263bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1264 const TensorInfo& output,
1265 const PadDescriptor& descriptor,
1266 Optional<std::string&> reasonIfUnsupported) const
1267{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001268 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001269 bool supported = true;
1270
1271 // Define supported output and inputs types.
1272 std::array<DataType,3> supportedTypes =
1273 {
1274 DataType::Float32,
1275 DataType::QuantisedAsymm8,
1276 DataType::QuantisedSymm16
1277 };
1278
1279 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1280 "Reference pad: input is not a supported type.");
1281
1282 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1283 "Reference pad: output is not a supported type.");
1284
1285 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1286 "Reference pad: input and output types are mismatched.");
1287
1288 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001289}
1290
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001291bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1292 const TensorInfo& output,
1293 const PermuteDescriptor& descriptor,
1294 Optional<std::string&> reasonIfUnsupported) const
1295{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001296 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001297 bool supported = true;
1298
1299 // Define supported output and inputs types.
1300 std::array<DataType,3> supportedTypes =
1301 {
1302 DataType::Float32,
1303 DataType::QuantisedAsymm8,
1304 DataType::QuantisedSymm16
1305 };
1306
1307 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1308 "Reference permute: input is not a supported type.");
1309
1310 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1311 "Reference permute: output is not a supported type.");
1312
1313 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1314 "Reference permute: input and output types are mismatched.");
1315
1316 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001317}
1318
1319bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1320 const TensorInfo& output,
1321 const Pooling2dDescriptor& descriptor,
1322 Optional<std::string&> reasonIfUnsupported) const
1323{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001324 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001325 bool supported = true;
1326
1327 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001328 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001329 {
1330 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001331 DataType::QuantisedAsymm8,
1332 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001333 };
1334
1335 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1336 "Reference poolind2d: input is not a supported type.");
1337
1338 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1339 "Reference poolind2d: output is not a supported type.");
1340
1341 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1342 "Reference poolind2d: input and output types are mismatched.");
1343
1344 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001345}
1346
Derek Lamberti5f400d62019-03-25 15:41:58 +00001347bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1348 const TensorInfo& output,
1349 Optional<std::string&> reasonIfUnsupported) const
1350{
1351 bool supported = true;
1352
1353 // Define supported output types.
1354 std::array<DataType,2> supportedInputTypes = {
1355 DataType::Float32,
1356 };
1357
1358 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1359 "Reference quantize: input type not supported.");
1360
1361 // Define supported output types.
1362 std::array<DataType,2> supportedOutputTypes = {
1363 DataType::QuantisedAsymm8,
1364 DataType::QuantisedSymm16
1365 };
1366 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1367 "Reference quantize: output type not supported.");
1368
1369 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1370 "Reference quantize: input and output shapes have different num total elements.");
1371
1372 return supported;
1373}
1374
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001375bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001376 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001377 Optional<std::string&> reasonIfUnsupported) const
1378{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001379 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001380 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001381 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001382 {
1383 DataType::Float32,
1384 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001385 DataType::QuantisedAsymm8,
1386 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001387 };
1388 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1389 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001390}
1391
1392bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001393 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001394 Optional<std::string&> reasonIfUnsupported) const
1395{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001396 bool supported = true;
1397 std::array<DataType,3> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001398 {
1399 DataType::Float32,
1400 DataType::QuantisedAsymm8,
1401 DataType::QuantisedSymm16
1402 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001403
1404 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1405 "Reference ResizeBilinear: input type not supported");
1406
1407 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1408 "Reference ResizeBilinear: output type not supported");
1409
1410 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1411 "Reference ResizeBilinear: input and output types not matching");
1412
1413 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001414}
1415
Teresa Charlin970f43b2019-07-01 13:51:07 +01001416bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1417 const TensorInfo& output,
1418 const ResizeDescriptor& descriptor,
1419 Optional<std::string&> reasonIfUnsupported) const
1420{
1421 bool supported = true;
1422 std::array<DataType,3> supportedTypes =
1423 {
1424 DataType::Float32,
1425 DataType::QuantisedAsymm8,
1426 DataType::QuantisedSymm16
1427 };
1428
1429 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1430 "Reference Resize: input type not supported");
1431
1432 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1433 "Reference Resize: output type not supported");
1434
1435 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1436 "Reference Resize: input and output types not matching");
1437
1438 return supported;
1439}
1440
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001441bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1442 const TensorInfo& output,
1443 Optional<std::string&> reasonIfUnsupported) const
1444{
nikraj010421e7f2019-06-14 09:40:34 +01001445 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001446 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001447 {
1448 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001449 DataType::QuantisedAsymm8,
1450 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001451 };
1452
1453 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1454 "Reference rsqrt: input type not supported");
1455
1456 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1457 "Reference rsqrt: output type not supported");
1458
1459 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1460 "Reference rsqrt: input and output types not matching");
1461
1462 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1463 "Reference Rsqrt: input and output shapes have different number of total elements");
1464
1465 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001466}
1467
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001468bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1469 const TensorInfo& output,
1470 const SoftmaxDescriptor& descriptor,
1471 Optional<std::string&> reasonIfUnsupported) const
1472{
1473 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001474 bool supported = true;
1475 std::array<DataType,3> supportedTypes =
1476 {
1477 DataType::Float32,
1478 DataType::QuantisedAsymm8,
1479 DataType::QuantisedSymm16
1480 };
1481
1482 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1483 "Reference concatenation: output type not supported");
1484
1485 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1486 "Reference concatenation: input type not supported");
1487
1488 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1489 "Reference concatenation: input type not supported");
1490
1491 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001492}
1493
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001494bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1495 const TensorInfo& output,
1496 const SpaceToBatchNdDescriptor& descriptor,
1497 Optional<std::string&> reasonIfUnsupported) const
1498{
1499 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001500 bool supported = true;
1501 std::array<DataType,3> supportedTypes =
1502 {
1503 DataType::Float32,
1504 DataType::QuantisedAsymm8,
1505 DataType::QuantisedSymm16
1506 };
1507
1508 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1509 "Reference SpaceToBatchNd: input type not supported");
1510
1511 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1512 "Reference SpaceToBatchNd: output type not supported");
1513
1514 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1515 "Reference SpaceToBatchNd: input and output types are mismatched");
1516
1517 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001518}
1519
Keith Davisa57eccb2019-06-14 17:33:22 +01001520bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001521 const TensorInfo& output,
1522 const SpaceToDepthDescriptor& descriptor,
1523 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001524{
1525
1526 ignore_unused(descriptor);
1527 bool supported = true;
1528
James Conroyd2aa85e2019-07-01 17:12:40 +01001529 std::array<DataType,3> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001530 {
1531 DataType::Float32,
1532 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001533 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001534 };
1535
1536 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1537 "Reference SpaceToDepth: input type not supported");
1538
1539 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1540 "Reference SpaceToDepth: output type not supported");
1541
1542 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1543 "Reference SpaceToDepth: input and output types are mismatched");
1544
1545 return supported;
1546}
1547
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001548bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1549 const ViewsDescriptor& descriptor,
1550 Optional<std::string&> reasonIfUnsupported) const
1551{
1552 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001553 bool supported = true;
1554 std::array<DataType,3> supportedTypes =
1555 {
1556 DataType::Float32,
1557 DataType::QuantisedAsymm8,
1558 DataType::QuantisedSymm16
1559 };
1560
1561 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1562 "Reference splitter: input type not supported");
1563
1564 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001565}
1566
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001567bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1568 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1569 const ViewsDescriptor& descriptor,
1570 Optional<std::string&> reasonIfUnsupported) const
1571{
1572 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001573 bool supported = true;
1574 std::array<DataType,3> supportedTypes =
1575 {
1576 DataType::Float32,
1577 DataType::QuantisedAsymm8,
1578 DataType::QuantisedSymm16
1579 };
1580
1581 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1582 "Reference splitter: output type not supported");
1583 for (const TensorInfo output : outputs)
1584 {
1585 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1586 "Reference splitter: input type not supported");
1587
1588 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1589 "Reference splitter: input and output types mismatched.");
1590 }
1591
1592 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001593}
1594
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001595bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1596 const TensorInfo& output,
1597 const StridedSliceDescriptor& descriptor,
1598 Optional<std::string&> reasonIfUnsupported) const
1599{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001600 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001601 bool supported = true;
1602
1603 std::array<DataType,3> supportedTypes =
1604 {
1605 DataType::Float32,
1606 DataType::QuantisedAsymm8,
1607 DataType::QuantisedSymm16
1608 };
1609
1610 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1611 "Reference StridedSlice: input type not supported");
1612
1613 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1614 "Reference StridedSlice: output type not supported");
1615
1616 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1617 "Reference StridedSlice: input and output types are mismatched");
1618
1619 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001620}
1621
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001622bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1623 const TensorInfo& input1,
1624 const TensorInfo& output,
1625 Optional<std::string&> reasonIfUnsupported) const
1626{
Sadik Armagan2999a022019-04-09 14:20:12 +01001627 bool supported = true;
1628
1629 std::array<DataType,3> supportedTypes = {
1630 DataType::Float32,
1631 DataType::QuantisedAsymm8,
1632 DataType::QuantisedSymm16
1633 };
1634
1635 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1636 "Reference subtraction: input 0 is not a supported type.");
1637
1638 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1639 "Reference subtraction: input 1 is not a supported type.");
1640
1641 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1642 "Reference subtraction: output is not a supported type.");
1643
1644 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1645 "Reference subtraction: input 0 and Input 1 types are mismatched");
1646
1647 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1648 "Reference subtraction: input and output types are mismatched");
1649
1650 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1651 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1652
1653 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001654}
1655
Matteo Martincighab9e5252019-06-13 17:27:46 +01001656bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1657 const TensorInfo& alpha,
1658 const TensorInfo& output,
1659 Optional<std::string&> reasonIfUnsupported) const
1660{
1661 bool supported = true;
1662
1663 std::array<DataType, 3> supportedTypes
1664 {
1665 DataType::Float32,
1666 DataType::QuantisedAsymm8,
1667 DataType::QuantisedSymm16
1668 };
1669
1670 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1671 "PReLU: input is not a supported type.");
1672
1673 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1674 "PReLU: alpha is not a supported type.");
1675
1676 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1677 "PReLU: output is not a supported type.");
1678
1679 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1680 "PReLU: input, alpha and output types are mismatched");
1681
1682 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1683 "PReLU: shapes are not suitable for implicit broadcast");
1684
1685 return supported;
1686}
1687
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001688bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1689 const TensorInfo& output,
1690 const TransposeConvolution2dDescriptor& descriptor,
1691 const TensorInfo& weights,
1692 const Optional<TensorInfo>& biases,
1693 Optional<std::string&> reasonIfUnsupported) const
1694{
1695 ignore_unused(descriptor);
1696
1697 bool supported = true;
1698
1699 std::array<DataType,3> supportedTypes =
1700 {
1701 DataType::Float32,
1702 DataType::QuantisedAsymm8,
1703 DataType::QuantisedSymm16
1704 };
1705
1706 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1707 "Reference TransposeConvolution2d: input is not a supported type.");
1708
1709 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1710 "Reference TransposeConvolution2d: output is not a supported type.");
1711
1712 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1713 "Reference TransposeConvolution2d: weights is not a supported type.");
1714
1715 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1716 "Reference TransposeConvolution2d: input and output types mismatched.");
1717
1718 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1719 "Reference TransposeConvolution2d: input and weights types mismatched.");
1720
1721 if (biases.has_value())
1722 {
1723 std::array<DataType,3> biasesSupportedTypes = {
1724 DataType::Float32,
1725 DataType::Signed32
1726 };
1727 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1728 "Reference TransposeConvolution2d: biases is not a supported type.");
1729 }
1730
1731 return supported;
1732}
1733
arovir011c7c81b2018-10-08 11:34:28 +01001734} // namespace armnn