blob: b9aa126a8c50126c2a779fa89bb95d69f241831c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Dispatches a Float32 or QAsymm8 support predicate based on the layer's data
// type; every other data type is rejected via FalseFunc.
// NOTE(review): the positional predicate slots appear to follow the DataType
// enum order declared in LayerSupportCommon.hpp (Float16, Float32, QAsymm8,
// then two later-added types) — confirm against IsSupportedForDataTypeGeneric.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // first slot: unsupported
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>, // unsupported
                                         &FalseFunc<Params...>, // unsupported
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
namespace
{

/// Builds the standard diagnostic used when a tensor's rank does not match
/// what the reference backend expects.
/// @param expected   Number of dimensions the backend requires.
/// @param actual     Number of dimensions the tensor actually has.
/// @param layerStr   Human-readable layer name inserted into the message.
/// @param tensorName Name of the offending tensor ("input", "output", ...).
/// @return The composed error message.
/// Parameters are taken by const reference: the function only reads them
/// (previously they were non-const references for no reason).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
66namespace
67{
68template<typename F>
69bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
70{
71 bool supported = rule();
72 if (!supported && reason)
73 {
74 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
75 }
76 return supported;
77}
78
79struct Rule
80{
81 bool operator()() const
82 {
83 return m_Res;
84 }
85
86 bool m_Res = true;
87};
88
Derek Lamberti2a434a82019-03-20 13:07:57 +000089template<typename T>
90bool AllTypesAreEqualImpl(T t)
Derek Lamberti50db4e82019-03-13 14:16:15 +000091{
92 return true;
93}
94
95template<typename T, typename... Rest>
96bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
97{
98 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
99
Derek Lamberti2a434a82019-03-20 13:07:57 +0000100 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
Derek Lamberti50db4e82019-03-13 14:16:15 +0000101}
102
103struct TypesAreEqual : public Rule
104{
105 template<typename ... Ts>
106 TypesAreEqual(const Ts&... ts)
107 {
108 m_Res = AllTypesAreEqualImpl(ts...);
109 }
110};
111
112struct QuantizationParametersAreEqual : public Rule
113{
114 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
115 {
116 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
117 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
118 }
119};
120
121struct TypeAnyOf : public Rule
122{
123 template<typename Container>
124 TypeAnyOf(const TensorInfo& info, const Container& c)
125 {
126 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
Francis Murtagh46c09d02019-05-28 08:15:28 +0100127 {
128 return dt == info.GetDataType();
129 });
130 }
131};
132
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100133struct TypeIs : public Rule
134{
135 TypeIs(const TensorInfo& info, DataType dt)
136 {
137 m_Res = dt == info.GetDataType();
138 }
139};
140
Francis Murtagh46c09d02019-05-28 08:15:28 +0100141struct BiasAndWeightsTypesMatch : public Rule
142{
143 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
144 {
145 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
146 }
147};
148
149struct BiasAndWeightsTypesCompatible : public Rule
150{
151 template<typename Container>
152 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
153 {
154 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
155 {
156 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
157 });
Derek Lamberti50db4e82019-03-13 14:16:15 +0000158 }
159};
160
161struct ShapesAreSameRank : public Rule
162{
163 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
164 {
165 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
166 }
167};
168
Derek Lamberti5f400d62019-03-25 15:41:58 +0000169struct ShapesAreSameTotalSize : public Rule
170{
171 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
172 {
173 m_Res = info0.GetNumElements() == info1.GetNumElements();
174 }
175};
176
Derek Lamberti50db4e82019-03-13 14:16:15 +0000177struct ShapesAreBroadcastCompatible : public Rule
178{
179 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
180 {
181 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
182 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
183 return sizeIn;
184 }
185
186 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
187 {
188 const TensorShape& shape0 = in0.GetShape();
189 const TensorShape& shape1 = in1.GetShape();
190 const TensorShape& outShape = out.GetShape();
191
192 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
193 {
194 unsigned int sizeOut = outShape[i];
195 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
196 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
197
198 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
199 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
200 }
201 }
202};
James Conroy4d1ff582019-06-10 17:06:39 +0100203
204struct TensorNumDimensionsAreCorrect : public Rule
205{
206 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
207 {
208 m_Res = info.GetNumDimensions() == expectedNumDimensions;
209 }
210};
211
Derek Lamberti50db4e82019-03-13 14:16:15 +0000212} // namespace
213
214
arovir011c7c81b2018-10-08 11:34:28 +0100215bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
216 const TensorInfo& output,
217 const ActivationDescriptor& descriptor,
218 Optional<std::string&> reasonIfUnsupported) const
219{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000220 bool supported = true;
221
222 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100223 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000224 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100225 DataType::QuantisedAsymm8,
226 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000227 };
228
229 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
230 "Reference activation: input type not supported.");
231
232 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
233 "Reference activation: output type not supported.");
234
235 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
236 "Reference activation: input and output types mismatched.");
237
238 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
239 "Reference activation: input and output shapes are of different rank.");
240
241
242 struct ActivationFunctionSupported : public Rule
243 {
244 ActivationFunctionSupported(const ActivationDescriptor& desc)
245 {
246 switch(desc.m_Function)
247 {
248 case ActivationFunction::Abs:
249 case ActivationFunction::BoundedReLu:
250 case ActivationFunction::LeakyReLu:
251 case ActivationFunction::Linear:
252 case ActivationFunction::ReLu:
253 case ActivationFunction::Sigmoid:
254 case ActivationFunction::SoftReLu:
255 case ActivationFunction::Sqrt:
256 case ActivationFunction::Square:
257 case ActivationFunction::TanH:
258 {
259 m_Res = true;
260 break;
261 }
262 default:
263 {
264 m_Res = false;
265 break;
266 }
267 }
268 }
269 };
270
271 // Function is supported
272 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
273 "Reference activation: function not supported.");
274
275 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100276}
277
278bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
279 const TensorInfo& input1,
280 const TensorInfo& output,
281 Optional<std::string&> reasonIfUnsupported) const
282{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000283 bool supported = true;
284
Sadik Armagan2999a022019-04-09 14:20:12 +0100285 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000286 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100287 DataType::QuantisedAsymm8,
288 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000289 };
290
291 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
292 "Reference addition: input 0 is not a supported type.");
293
294 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
295 "Reference addition: input 1 is not a supported type.");
296
297 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
298 "Reference addition: output is not a supported type.");
299
300 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
301 "Reference addition: input 0 and Input 1 types are mismatched");
302
303 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
304 "Reference addition: input and output types are mismatched");
305
306 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
307 "Reference addition: shapes are not suitable for implicit broadcast.");
308
309 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100310}
311
312bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
313 const TensorInfo& output,
314 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100315 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100316 const TensorInfo& beta,
317 const TensorInfo& gamma,
318 const BatchNormalizationDescriptor& descriptor,
319 Optional<std::string&> reasonIfUnsupported) const
320{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100321 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100322
Matteo Martincighf5507132019-06-04 10:59:47 +0100323 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100324 {
325 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100326 DataType::QuantisedAsymm8,
327 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100328 };
329
330 bool supported = true;
331
332 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
333 "Reference batch normalization: input is not a supported type.");
334
335 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
336 "Reference batch normalization: output is not a supported type.");
337
338 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
339 "Reference batch normalization: input and output types are mismatched");
340
341 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
342 "Reference batch normalization: mean is not a supported type.");
343
344 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
345 "Reference batch normalization: variance is not a supported type.");
346
347 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
348 "Reference batch normalization: beta is not a supported type.");
349
350 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
351 "Reference batch normalization: gamma is not a supported type.");
352
353 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100354}
355
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000356bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
357 const TensorInfo& output,
358 const BatchToSpaceNdDescriptor& descriptor,
359 Optional<std::string&> reasonIfUnsupported) const
360{
361 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100362
363 bool supported = true;
364
365 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
366 std::string inputTensorStr = "input";
367 std::string outputTensorStr = "output";
368
369 // Define supported types.
370 std::array<DataType,3> supportedTypes =
371 {
372 DataType::Float32,
373 DataType::QuantisedAsymm8,
374 DataType::QuantisedSymm16
375 };
376
377 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
378 "Reference BatchToSpaceNd: input type not supported.");
379
380 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381 "Reference BatchToSpaceNd: output type not supported.");
382
383 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
384 "Reference BatchToSpaceNd: input and output types mismatched.");
385
386 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
387 reasonIfUnsupported,
388 CreateIncorrectDimensionsErrorMsg(4,
389 output.GetNumDimensions(),
390 batchToSpaceNdLayerStr,
391 outputTensorStr).data());
392
393 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
394 reasonIfUnsupported,
395 CreateIncorrectDimensionsErrorMsg(4,
396 input.GetNumDimensions(),
397 batchToSpaceNdLayerStr,
398 inputTensorStr).data());
399
400 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000401}
402
Jim Flynn906f9462019-05-10 13:55:21 +0100403bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
404 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100405 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100406 Optional<std::string&> reasonIfUnsupported) const
407{
Jim Flynne242f2d2019-05-22 14:24:13 +0100408 ignore_unused(descriptor);
409
410 bool supported = true;
411 std::array<DataType,3> supportedTypes =
412 {
413 DataType::Float32,
414 DataType::QuantisedAsymm8,
415 DataType::QuantisedSymm16
416 };
417
418 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
419 "Reference concatenation: output type not supported");
420 for (const TensorInfo* input : inputs)
421 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100422 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100423 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
424 "Reference concatenation: input type not supported");
425
426 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
427 "Reference concatenation: input and output types mismatched.");
428 }
429
430 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100431}
432
arovir011c7c81b2018-10-08 11:34:28 +0100433bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
434 Optional<std::string&> reasonIfUnsupported) const
435{
Jim Flynne242f2d2019-05-22 14:24:13 +0100436 std::array<DataType,4> supportedTypes =
437 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100438 DataType::Float32,
439 DataType::Signed32,
440 DataType::QuantisedAsymm8,
441 DataType::QuantisedSymm16
442 };
443
444 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
445 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100446}
447
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only when the input is Float16 and the output Float32; the
    // generic helper takes one predicate per data type, positionally.
    // NOTE(review): slot order appears to be Float16, Float32, QAsymm8,
    // Signed32, Boolean — confirm against IsSupportedForDataTypeGeneric in
    // LayerSupportCommon.hpp before relying on it.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
467
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Mirror image of IsConvertFp16ToFp32Supported: input must be Float32 and
    // output Float16; every other combination is rejected with a reason.
    // NOTE(review): predicate slot order assumed to follow the DataType enum —
    // confirm against IsSupportedForDataTypeGeneric in LayerSupportCommon.hpp.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
487
488bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
489 const TensorInfo& output,
490 const Convolution2dDescriptor& descriptor,
491 const TensorInfo& weights,
492 const Optional<TensorInfo>& biases,
493 Optional<std::string&> reasonIfUnsupported) const
494{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100495 bool supported = true;
496
497 // Define supported types.
498 std::array<DataType,3> supportedTypes = {
499 DataType::Float32,
500 DataType::QuantisedAsymm8,
501 DataType::QuantisedSymm16
502 };
503
504 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100505 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100506
507 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100508 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100509
510 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100511 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100512
513 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100514 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100515
516 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100517 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100518
519 if (biases.has_value())
520 {
521 std::array<DataType,3> biasesSupportedTypes = {
522 DataType::Float32,
523 DataType::Signed32
524 };
525 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100526 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100527 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100528 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100529
530 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100531}
532
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000533bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
534 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000535 Optional<std::string&> reasonIfUnsupported) const
536{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100537 bool supported = true;
538
539 std::array<DataType,3> supportedTypes =
540 {
541 DataType::Float32,
542 DataType::QuantisedAsymm8,
543 DataType::QuantisedSymm16
544 };
545
546 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
547 "Reference debug: input type not supported");
548
549 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
550 "Reference debug: output type not supported");
551
552 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
553 "Reference debug: input and output types are mismatched");
554
555 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000556}
557
arovir011c7c81b2018-10-08 11:34:28 +0100558bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
559 const TensorInfo& output,
560 const DepthwiseConvolution2dDescriptor& descriptor,
561 const TensorInfo& weights,
562 const Optional<TensorInfo>& biases,
563 Optional<std::string&> reasonIfUnsupported) const
564{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100565 bool supported = true;
566
567 // Define supported types.
568 std::array<DataType,3> supportedTypes =
569 {
570 DataType::Float32,
571 DataType::QuantisedAsymm8,
572 DataType::QuantisedSymm16
573 };
574
575 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
576 "Reference DepthwiseConvolution2d: input is not a supported type.");
577
578 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
579 "Reference DepthwiseConvolution2d: output is not a supported type.");
580
581 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
582 "Reference DepthwiseConvolution2d: weights is not a supported type.");
583
584 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
585 "Reference DepthwiseConvolution2d: input and output types mismatched.");
586
587 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
588 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
589
590 if (biases.has_value())
591 {
592 std::array<DataType,2> biasesSupportedTypes =
593 {
594 DataType::Float32,
595 DataType::Signed32
596 };
597 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
598 "Reference DepthwiseConvolution2d: biases is not a supported type.");
599 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100600 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100601
602 return supported;
603
arovir011c7c81b2018-10-08 11:34:28 +0100604}
605
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000606bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
607 const TensorInfo& output,
608 Optional<std::string&> reasonIfUnsupported) const
609{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100610 bool supported = true;
611
612 std::array<DataType,2> supportedInputTypes = {
613 DataType::QuantisedAsymm8,
614 DataType::QuantisedSymm16
615 };
616
617 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
618 "Reference dequantize: input type not supported.");
619
620 std::array<DataType,2> supportedOutputTypes = {
621 DataType::Float32,
622 };
623
624 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
625 "Reference dequantize: output type not supported.");
626
627 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
628 "Reference dequantize: input and output shapes have different num total elements.");
629
630 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000631}
632
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000633bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
634 const armnn::TensorInfo& input1,
635 const armnn::DetectionPostProcessDescriptor& descriptor,
636 armnn::Optional<std::string&> reasonIfUnsupported) const
637{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100638 bool supported = true;
639
640 std::vector<DataType> supportedInputTypes =
641 {
642 DataType::Float32,
643 DataType::QuantisedAsymm8,
644 DataType::QuantisedSymm16
645 };
646
647 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
648 "Reference DetectionPostProcess: input 0 is not a supported type.");
649
650 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
651 "Reference DetectionPostProcess: input 1 is not a supported type.");
652
653 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000654}
655
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // The dilated variant shares the plain depthwise support criteria: the
    // dilation fields live in the same DepthwiseConvolution2dDescriptor, so
    // simply delegate.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
665
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100666bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100667 const TensorInfo& input1,
668 const TensorInfo& output,
669 Optional<std::string&> reasonIfUnsupported) const
670{
Sadik Armagan2999a022019-04-09 14:20:12 +0100671 bool supported = true;
672
673 std::array<DataType,3> supportedTypes = {
674 DataType::Float32,
675 DataType::QuantisedAsymm8,
676 DataType::QuantisedSymm16
677 };
678
679 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
680 "Reference division: input 0 is not a supported type.");
681
682 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
683 "Reference division: input 1 is not a supported type.");
684
685 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
686 "Reference division: output is not a supported type.");
687
688 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
689 "Reference division: input 0 and Input 1 types are mismatched");
690
691 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
692 "Reference division: input and output types are mismatched");
693
694 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
695 "Reference division: shapes are not suitable for implicit broadcast.");
696
697 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100698}
699
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000700bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
701 const TensorInfo& input1,
702 const TensorInfo& output,
703 Optional<std::string&> reasonIfUnsupported) const
704{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100705 bool supported = true;
706
707 std::array<DataType,3> supportedTypes =
708 {
709 DataType::Float32,
710 DataType::QuantisedAsymm8,
711 DataType::QuantisedSymm16
712 };
713
714 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
715 "Reference equal: input 0 is not a supported type.");
716
717 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
718 "Reference equal: input 1 is not a supported type.");
719
720 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
721 "Reference equal: input 0 and Input 1 types are mismatched");
722
723 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
724 "Reference equal: shapes are not suitable for implicit broadcast.");
725
726 return supported;
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000727}
728
arovir011c7c81b2018-10-08 11:34:28 +0100729bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
730 const FakeQuantizationDescriptor& descriptor,
731 Optional<std::string&> reasonIfUnsupported) const
732{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100733 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100734 bool supported = true;
735
736 std::array<DataType,1> supportedTypes =
737 {
738 DataType::Float32
739 };
740
741 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742 "Reference fake quantization: input type not supported.");
743
744 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100745}
746
747bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
748 const TensorInfo& output,
749 Optional<std::string&> reasonIfUnsupported) const
750{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100751 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100752 bool supported = true;
753
James Conroyb40d7102019-06-04 12:32:09 +0100754 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100755 {
James Conroyb40d7102019-06-04 12:32:09 +0100756 DataType::Float32,
757 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100758 };
759
760 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
761 "Reference Floor: input type not supported.");
762
763 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
764 "Reference Floor: output type not supported.");
765
766 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100767}
768
769bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
770 const TensorInfo& output,
771 const TensorInfo& weights,
772 const TensorInfo& biases,
773 const FullyConnectedDescriptor& descriptor,
774 Optional<std::string&> reasonIfUnsupported) const
775{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100776 bool supported = true;
777
778 // Define supported types.
779 std::array<DataType,3> supportedTypes =
780 {
781 DataType::Float32,
782 DataType::QuantisedAsymm8,
783 DataType::QuantisedSymm16
784 };
785
786 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
787 "Reference Fully Connected: input type not supported.");
788
789 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
790 "Reference Fully Connected: output type not supported.");
791
792 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
793 "Reference Fully Connected: input and output types mismatched.");
794
795 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
796 "Reference Fully Connected: weights type not supported.");
797
798 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
799 "Reference Fully Connected: input and weight types mismatched.");
800
801 if (descriptor.m_BiasEnabled)
802 {
803 // Defined supported types for bias
804 std::array<DataType, 2>
805 supportedBiasTypes =
806 {
807 DataType::Float32,
808 DataType::Signed32
809 };
810
811 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
812 "Reference Fully Connected: bias type not supported.");
813
814 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
815 "Reference Fully Connected: bias and weight types mismatch.");
816
817 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
818 "Reference Fully Connected: bias type inferred from weights is incompatible.");
819
820 }
821
822 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100823}
824
narpra014951d842019-01-18 16:53:53 +0000825bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
826 const armnn::TensorInfo& input1,
827 const armnn::TensorInfo& output,
828 armnn::Optional<std::string&> reasonIfUnsupported) const
829{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100830 bool supported = true;
831 std::array<DataType,3> supportedTypes =
832 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100833 DataType::Float32,
834 DataType::QuantisedAsymm8,
835 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100836 };
837
838 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
839 "Reference Gather: input type not supported");
840
841 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
842 "Reference Gather: output type not supported");
843
844 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
845 "Reference Gather: indices (input1) type not supported");
846
847 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
848 "Reference Gather: input and output types not matching");
849
850 return supported;
narpra014951d842019-01-18 16:53:53 +0000851}
852
FrancisMurtagh878f0232018-12-19 10:56:15 +0000853bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
854 const TensorInfo& input1,
855 const TensorInfo& output,
856 Optional<std::string&> reasonIfUnsupported) const
857{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100858 bool supported = true;
859
860 std::array<DataType,3> supportedTypes =
861 {
862 DataType::Float32,
863 DataType::QuantisedAsymm8,
864 DataType::QuantisedSymm16
865 };
866
867 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
868 "Reference greater: input 0 is not a supported type.");
869
870 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
871 "Reference greater: input 1 is not a supported type.");
872
873 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
874 "Reference greater: input 0 and Input 1 types are mismatched");
875
876 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
877 "Reference greater: shapes are not suitable for implicit broadcast.");
878
879 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000880}
881
arovir011c7c81b2018-10-08 11:34:28 +0100882bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
883 Optional<std::string&> reasonIfUnsupported) const
884{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100885 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100886}
887
888bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
889 const TensorInfo& output,
890 const L2NormalizationDescriptor& descriptor,
891 Optional<std::string&> reasonIfUnsupported) const
892{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100893 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100894 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100895 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100896 {
897 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100898 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100899 DataType::QuantisedSymm16
900 };
901
902 bool supported = true;
903
904 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
905 "Reference L2normalization: input type not supported.");
906
907 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
908 "Reference L2normalization: output type not supported.");
909
910 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
911 "Reference L2normalization: input and output types mismatched.");
912
913 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
914 "Reference L2normalization: input and output shapes have different "
915 "num total elements.");
916
917 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100918}
919
920bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
921 const TensorInfo& outputStateIn,
922 const TensorInfo& cellStateIn,
923 const TensorInfo& scratchBuffer,
924 const TensorInfo& outputStateOut,
925 const TensorInfo& cellStateOut,
926 const TensorInfo& output,
927 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100928 const LstmInputParamsInfo& paramsInfo,
929 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +0100930{
telsoa01c577f2c2018-08-31 09:22:23 +0100931 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100932 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100933
934 bool supported = true;
935
936 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100937 DataType::Float32,
938 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100939 };
940
Jan Eilersd01a83c2019-07-03 18:20:40 +0100941 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100942 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
943 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100944 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
945 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100946 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
947 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100948 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
949 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100950 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
951 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100952 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
953 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100954 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
955 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +0100956 // check layer parameters
957 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
958 "Reference Lstm: input and InputToForgetWeights types are mismatched");
959 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
960 "Reference Lstm: input and InputToCellWeights types are mismatched");
961 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
962 "Reference Lstm: input and InputToOutputWeights types are mismatched");
963 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
964 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
965 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
966 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
967 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
968 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
969 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
970 "Reference Lstm: input and ForgetGateBias types are mismatched");
971 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
972 "Reference Lstm: input and CellBias types are mismatched");
973 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
974 "Reference Lstm: input and OutputGateBias types are mismatched");
975 if (!descriptor.m_CifgEnabled)
976 {
977 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
978 "Reference Lstm: input and InputToInputWeights types are mismatched");
979 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
980 reasonIfUnsupported,
981 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
982 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
983 "Reference Lstm: input and InputGateBias types are mismatched");
984 if (descriptor.m_PeepholeEnabled)
985 {
986 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
987 reasonIfUnsupported,
988 "Reference Lstm: input and CellToInputWeights types are mismatched");
989 }
990 }
991 if (descriptor.m_PeepholeEnabled)
992 {
993 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
994 "Reference Lstm: input and CellToForgetWeights types are mismatched");
995 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
996 "Reference Lstm: input and CellToOutputWeights types are mismatched");
997 }
998 if (descriptor.m_ProjectionEnabled)
999 {
1000 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
1001 "Reference Lstm: input and mProjectionWeights types are mismatched");
1002 if (paramsInfo.m_ProjectionBias != nullptr)
1003 {
1004 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
1005 "Reference Lstm: input and ProjectionBias types are mismatched");
1006 }
1007 }
1008 if (descriptor.m_LayerNormEnabled)
1009 {
1010 if (!descriptor.m_CifgEnabled)
1011 {
1012 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
1013 reasonIfUnsupported,
1014 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1015 }
1016 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
1017 reasonIfUnsupported,
1018 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1019 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
1020 reasonIfUnsupported,
1021 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1022 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
1023 reasonIfUnsupported,
1024 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1025 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001026
1027 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001028}
1029
saoste012df12b32018-11-28 16:57:20 +00001030bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1031 const TensorInfo& input1,
1032 const TensorInfo& output,
1033 Optional<std::string&> reasonIfUnsupported) const
1034{
Sadik Armagan2999a022019-04-09 14:20:12 +01001035 bool supported = true;
1036
1037 std::array<DataType,3> supportedTypes = {
1038 DataType::Float32,
1039 DataType::QuantisedAsymm8,
1040 DataType::QuantisedSymm16
1041 };
1042
1043 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1044 "Reference maximum: input 0 is not a supported type.");
1045
1046 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1047 "Reference maximum: input 1 is not a supported type.");
1048
1049 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1050 "Reference maximum: output is not a supported type.");
1051
1052 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1053 "Reference maximum: input 0 and Input 1 types are mismatched");
1054
1055 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1056 "Reference maximum: input and output types are mismatched");
1057
1058 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1059 "Reference maximum: shapes are not suitable for implicit broadcast.");
1060
1061 return supported;
saoste012df12b32018-11-28 16:57:20 +00001062}
1063
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001064bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
1065 const TensorInfo& output,
1066 const MeanDescriptor& descriptor,
1067 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001068{
James Conroy4d1ff582019-06-10 17:06:39 +01001069 bool supported = true;
1070 std::string meanLayerStr = "Mean";
1071 std::string outputTensorStr = "output";
1072
James Conroyb80775f2019-06-11 11:25:30 +01001073 std::array<DataType,3> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001074 {
1075 DataType::Float32,
James Conroyb80775f2019-06-11 11:25:30 +01001076 DataType::QuantisedAsymm8,
1077 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +01001078 };
1079
1080 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1081 "Reference Mean: input type not supported.");
1082
1083 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1084 "Reference Mean: input and output types are mismatched");
1085
1086 if (descriptor.m_KeepDims)
1087 {
1088 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1089 reasonIfUnsupported,
1090 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1091 output.GetNumDimensions(),
1092 meanLayerStr, outputTensorStr).data());
1093 }
1094 else if (descriptor.m_Axis.empty())
1095 {
1096 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1097 reasonIfUnsupported,
1098 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1099 meanLayerStr, outputTensorStr).data());
1100 }
1101 else
1102 {
1103 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1104
1105 if (outputDim > 0)
1106 {
1107 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1108 reasonIfUnsupported,
1109 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1110 meanLayerStr, outputTensorStr).data());
1111 }
1112 else
1113 {
1114 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1115 reasonIfUnsupported,
1116 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1117 meanLayerStr, outputTensorStr).data());
1118 }
1119 }
1120
1121 return supported;
narpra0132b90462018-09-13 11:07:48 +01001122}
1123
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001124bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001125 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001126 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001127 Optional<std::string&> reasonIfUnsupported) const
1128{
Jim Flynne242f2d2019-05-22 14:24:13 +01001129 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001130}
1131
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001132bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1133 const TensorInfo &output,
1134 Optional<std::string &> reasonIfUnsupported) const
1135{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001136 bool supported = true;
1137
1138 std::array<DataType,5> supportedTypes =
1139 {
1140 DataType::Float32,
1141 DataType::Float16,
1142 DataType::QuantisedAsymm8,
1143 DataType::QuantisedSymm16,
1144 DataType::Boolean
1145 };
1146
1147 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1148 "Reference MemCopy: input type not supported");
1149
1150 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1151 "Reference MemCopy: output type not supported");
1152
1153 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1154 "Reference MemCopy: input and output types are mismatched");
1155
1156 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001157}
1158
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001159bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1160 const TensorInfo& input1,
1161 const TensorInfo& output,
1162 Optional<std::string&> reasonIfUnsupported) const
1163{
Sadik Armagan2999a022019-04-09 14:20:12 +01001164 bool supported = true;
1165
1166 std::array<DataType,3> supportedTypes = {
1167 DataType::Float32,
1168 DataType::QuantisedAsymm8,
1169 DataType::QuantisedSymm16
1170 };
1171
1172 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1173 "Reference minimum: input 0 is not a supported type.");
1174
1175 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1176 "Reference minimum: input 1 is not a supported type.");
1177
1178 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1179 "Reference minimum: output is not a supported type.");
1180
1181 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1182 "Reference minimum: input 0 and Input 1 types are mismatched");
1183
1184 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1185 "Reference minimum: input and output types are mismatched");
1186
1187 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1188 "Reference minimum: shapes are not suitable for implicit broadcast.");
1189
1190 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001191}
1192
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001193bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1194 const TensorInfo& input1,
1195 const TensorInfo& output,
1196 Optional<std::string&> reasonIfUnsupported) const
1197{
Sadik Armagan2999a022019-04-09 14:20:12 +01001198 bool supported = true;
1199
1200 std::array<DataType,3> supportedTypes = {
1201 DataType::Float32,
1202 DataType::QuantisedAsymm8,
1203 DataType::QuantisedSymm16
1204 };
1205
1206 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1207 "Reference multiplication: input 0 is not a supported type.");
1208
1209 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1210 "Reference multiplication: input 1 is not a supported type.");
1211
1212 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1213 "Reference multiplication: output is not a supported type.");
1214
1215 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1216 "Reference multiplication: input 0 and Input 1 types are mismatched");
1217
1218 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1219 "Reference multiplication: input and output types are mismatched");
1220
1221 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1222 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1223
1224 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001225}
1226
1227bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1228 const TensorInfo& output,
1229 const NormalizationDescriptor& descriptor,
1230 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001231{
Nina Drozd661dfa72018-10-02 11:14:17 +01001232 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001233
1234 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001235 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001236 {
1237 DataType::Float16,
1238 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001239 DataType::QuantisedAsymm8,
1240 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001241 };
1242
1243 bool supported = true;
1244
1245 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1246 "Reference normalization: input type not supported.");
1247
1248 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1249 "Reference normalization: output type not supported.");
1250
1251 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1252 "Reference normalization: input and output shapes have different "
1253 "num total elements.");
1254
1255 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001256}
1257
1258bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1259 Optional<std::string&> reasonIfUnsupported) const
1260{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001261 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001262}
1263
1264bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1265 const TensorInfo& output,
1266 const PadDescriptor& descriptor,
1267 Optional<std::string&> reasonIfUnsupported) const
1268{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001269 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001270 bool supported = true;
1271
1272 // Define supported output and inputs types.
1273 std::array<DataType,3> supportedTypes =
1274 {
1275 DataType::Float32,
1276 DataType::QuantisedAsymm8,
1277 DataType::QuantisedSymm16
1278 };
1279
1280 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1281 "Reference pad: input is not a supported type.");
1282
1283 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1284 "Reference pad: output is not a supported type.");
1285
1286 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1287 "Reference pad: input and output types are mismatched.");
1288
1289 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001290}
1291
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001292bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1293 const TensorInfo& output,
1294 const PermuteDescriptor& descriptor,
1295 Optional<std::string&> reasonIfUnsupported) const
1296{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001297 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001298 bool supported = true;
1299
1300 // Define supported output and inputs types.
1301 std::array<DataType,3> supportedTypes =
1302 {
1303 DataType::Float32,
1304 DataType::QuantisedAsymm8,
1305 DataType::QuantisedSymm16
1306 };
1307
1308 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1309 "Reference permute: input is not a supported type.");
1310
1311 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1312 "Reference permute: output is not a supported type.");
1313
1314 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1315 "Reference permute: input and output types are mismatched.");
1316
1317 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001318}
1319
1320bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1321 const TensorInfo& output,
1322 const Pooling2dDescriptor& descriptor,
1323 Optional<std::string&> reasonIfUnsupported) const
1324{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001325 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001326 bool supported = true;
1327
1328 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001329 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001330 {
1331 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001332 DataType::QuantisedAsymm8,
1333 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001334 };
1335
1336 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1337 "Reference poolind2d: input is not a supported type.");
1338
1339 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1340 "Reference poolind2d: output is not a supported type.");
1341
1342 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1343 "Reference poolind2d: input and output types are mismatched.");
1344
1345 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001346}
1347
Derek Lamberti5f400d62019-03-25 15:41:58 +00001348bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1349 const TensorInfo& output,
1350 Optional<std::string&> reasonIfUnsupported) const
1351{
1352 bool supported = true;
1353
1354 // Define supported output types.
1355 std::array<DataType,2> supportedInputTypes = {
1356 DataType::Float32,
1357 };
1358
1359 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1360 "Reference quantize: input type not supported.");
1361
1362 // Define supported output types.
1363 std::array<DataType,2> supportedOutputTypes = {
1364 DataType::QuantisedAsymm8,
1365 DataType::QuantisedSymm16
1366 };
1367 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1368 "Reference quantize: output type not supported.");
1369
1370 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1371 "Reference quantize: input and output shapes have different num total elements.");
1372
1373 return supported;
1374}
1375
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001376bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001377 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001378 Optional<std::string&> reasonIfUnsupported) const
1379{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001380 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001381 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001382 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001383 {
1384 DataType::Float32,
1385 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001386 DataType::QuantisedAsymm8,
1387 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001388 };
1389 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1390 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001391}
1392
1393bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001394 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001395 Optional<std::string&> reasonIfUnsupported) const
1396{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001397 bool supported = true;
1398 std::array<DataType,3> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001399 {
1400 DataType::Float32,
1401 DataType::QuantisedAsymm8,
1402 DataType::QuantisedSymm16
1403 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001404
1405 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1406 "Reference ResizeBilinear: input type not supported");
1407
1408 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1409 "Reference ResizeBilinear: output type not supported");
1410
1411 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1412 "Reference ResizeBilinear: input and output types not matching");
1413
1414 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001415}
1416
Teresa Charlin970f43b2019-07-01 13:51:07 +01001417bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1418 const TensorInfo& output,
1419 const ResizeDescriptor& descriptor,
1420 Optional<std::string&> reasonIfUnsupported) const
1421{
1422 bool supported = true;
1423 std::array<DataType,3> supportedTypes =
1424 {
1425 DataType::Float32,
1426 DataType::QuantisedAsymm8,
1427 DataType::QuantisedSymm16
1428 };
1429
1430 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431 "Reference Resize: input type not supported");
1432
1433 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1434 "Reference Resize: output type not supported");
1435
1436 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1437 "Reference Resize: input and output types not matching");
1438
1439 return supported;
1440}
1441
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001442bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1443 const TensorInfo& output,
1444 Optional<std::string&> reasonIfUnsupported) const
1445{
nikraj010421e7f2019-06-14 09:40:34 +01001446 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001447 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001448 {
1449 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001450 DataType::QuantisedAsymm8,
1451 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001452 };
1453
1454 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1455 "Reference rsqrt: input type not supported");
1456
1457 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1458 "Reference rsqrt: output type not supported");
1459
1460 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1461 "Reference rsqrt: input and output types not matching");
1462
1463 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1464 "Reference Rsqrt: input and output shapes have different number of total elements");
1465
1466 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001467}
1468
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001469bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1470 const TensorInfo& output,
1471 const SoftmaxDescriptor& descriptor,
1472 Optional<std::string&> reasonIfUnsupported) const
1473{
1474 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001475 bool supported = true;
1476 std::array<DataType,3> supportedTypes =
1477 {
1478 DataType::Float32,
1479 DataType::QuantisedAsymm8,
1480 DataType::QuantisedSymm16
1481 };
1482
1483 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484 "Reference concatenation: output type not supported");
1485
1486 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487 "Reference concatenation: input type not supported");
1488
1489 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1490 "Reference concatenation: input type not supported");
1491
1492 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001493}
1494
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001495bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1496 const TensorInfo& output,
1497 const SpaceToBatchNdDescriptor& descriptor,
1498 Optional<std::string&> reasonIfUnsupported) const
1499{
1500 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001501 bool supported = true;
1502 std::array<DataType,3> supportedTypes =
1503 {
1504 DataType::Float32,
1505 DataType::QuantisedAsymm8,
1506 DataType::QuantisedSymm16
1507 };
1508
1509 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1510 "Reference SpaceToBatchNd: input type not supported");
1511
1512 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1513 "Reference SpaceToBatchNd: output type not supported");
1514
1515 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1516 "Reference SpaceToBatchNd: input and output types are mismatched");
1517
1518 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001519}
1520
Keith Davisa57eccb2019-06-14 17:33:22 +01001521bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001522 const TensorInfo& output,
1523 const SpaceToDepthDescriptor& descriptor,
1524 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001525{
1526
1527 ignore_unused(descriptor);
1528 bool supported = true;
1529
James Conroyd2aa85e2019-07-01 17:12:40 +01001530 std::array<DataType,3> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001531 {
1532 DataType::Float32,
1533 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001534 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001535 };
1536
1537 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1538 "Reference SpaceToDepth: input type not supported");
1539
1540 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1541 "Reference SpaceToDepth: output type not supported");
1542
1543 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1544 "Reference SpaceToDepth: input and output types are mismatched");
1545
1546 return supported;
1547}
1548
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001549bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1550 const ViewsDescriptor& descriptor,
1551 Optional<std::string&> reasonIfUnsupported) const
1552{
1553 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001554 bool supported = true;
1555 std::array<DataType,3> supportedTypes =
1556 {
1557 DataType::Float32,
1558 DataType::QuantisedAsymm8,
1559 DataType::QuantisedSymm16
1560 };
1561
1562 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1563 "Reference splitter: input type not supported");
1564
1565 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001566}
1567
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001568bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1569 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1570 const ViewsDescriptor& descriptor,
1571 Optional<std::string&> reasonIfUnsupported) const
1572{
1573 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001574 bool supported = true;
1575 std::array<DataType,3> supportedTypes =
1576 {
1577 DataType::Float32,
1578 DataType::QuantisedAsymm8,
1579 DataType::QuantisedSymm16
1580 };
1581
1582 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1583 "Reference splitter: output type not supported");
1584 for (const TensorInfo output : outputs)
1585 {
1586 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1587 "Reference splitter: input type not supported");
1588
1589 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1590 "Reference splitter: input and output types mismatched.");
1591 }
1592
1593 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001594}
1595
Matthew Jackson81e601c2019-07-11 12:07:09 +01001596bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1597 const TensorInfo& output,
1598 const StackDescriptor& descriptor,
1599 Optional<std::string&> reasonIfUnsupported) const
1600{
1601 ignore_unused(descriptor);
1602
1603 bool supported = true;
1604 std::array<DataType,3> supportedTypes =
1605 {
1606 DataType::Float32,
1607 DataType::QuantisedAsymm8,
1608 DataType::QuantisedSymm16
1609 };
1610
1611 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1612 "Reference stack: output type not supported");
1613 for (const TensorInfo* input : inputs)
1614 {
1615 BOOST_ASSERT(input != nullptr);
1616 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1617 "Reference stack: input type not supported");
1618
1619 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1620 "Reference stack: input and output types mismatched.");
1621 }
1622
1623 return supported;
1624}
1625
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001626bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1627 const TensorInfo& output,
1628 const StridedSliceDescriptor& descriptor,
1629 Optional<std::string&> reasonIfUnsupported) const
1630{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001631 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001632 bool supported = true;
1633
1634 std::array<DataType,3> supportedTypes =
1635 {
1636 DataType::Float32,
1637 DataType::QuantisedAsymm8,
1638 DataType::QuantisedSymm16
1639 };
1640
1641 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1642 "Reference StridedSlice: input type not supported");
1643
1644 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1645 "Reference StridedSlice: output type not supported");
1646
1647 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1648 "Reference StridedSlice: input and output types are mismatched");
1649
1650 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001651}
1652
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001653bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1654 const TensorInfo& input1,
1655 const TensorInfo& output,
1656 Optional<std::string&> reasonIfUnsupported) const
1657{
Sadik Armagan2999a022019-04-09 14:20:12 +01001658 bool supported = true;
1659
1660 std::array<DataType,3> supportedTypes = {
1661 DataType::Float32,
1662 DataType::QuantisedAsymm8,
1663 DataType::QuantisedSymm16
1664 };
1665
1666 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1667 "Reference subtraction: input 0 is not a supported type.");
1668
1669 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1670 "Reference subtraction: input 1 is not a supported type.");
1671
1672 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1673 "Reference subtraction: output is not a supported type.");
1674
1675 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1676 "Reference subtraction: input 0 and Input 1 types are mismatched");
1677
1678 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1679 "Reference subtraction: input and output types are mismatched");
1680
1681 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1682 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1683
1684 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001685}
1686
Matteo Martincighab9e5252019-06-13 17:27:46 +01001687bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1688 const TensorInfo& alpha,
1689 const TensorInfo& output,
1690 Optional<std::string&> reasonIfUnsupported) const
1691{
1692 bool supported = true;
1693
1694 std::array<DataType, 3> supportedTypes
1695 {
1696 DataType::Float32,
1697 DataType::QuantisedAsymm8,
1698 DataType::QuantisedSymm16
1699 };
1700
1701 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1702 "PReLU: input is not a supported type.");
1703
1704 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1705 "PReLU: alpha is not a supported type.");
1706
1707 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1708 "PReLU: output is not a supported type.");
1709
1710 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1711 "PReLU: input, alpha and output types are mismatched");
1712
1713 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1714 "PReLU: shapes are not suitable for implicit broadcast");
1715
1716 return supported;
1717}
1718
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001719bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1720 const TensorInfo& output,
1721 const TransposeConvolution2dDescriptor& descriptor,
1722 const TensorInfo& weights,
1723 const Optional<TensorInfo>& biases,
1724 Optional<std::string&> reasonIfUnsupported) const
1725{
1726 ignore_unused(descriptor);
1727
1728 bool supported = true;
1729
1730 std::array<DataType,3> supportedTypes =
1731 {
1732 DataType::Float32,
1733 DataType::QuantisedAsymm8,
1734 DataType::QuantisedSymm16
1735 };
1736
1737 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1738 "Reference TransposeConvolution2d: input is not a supported type.");
1739
1740 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1741 "Reference TransposeConvolution2d: output is not a supported type.");
1742
1743 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1744 "Reference TransposeConvolution2d: weights is not a supported type.");
1745
1746 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1747 "Reference TransposeConvolution2d: input and output types mismatched.");
1748
1749 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1750 "Reference TransposeConvolution2d: input and weights types mismatched.");
1751
1752 if (biases.has_value())
1753 {
1754 std::array<DataType,3> biasesSupportedTypes = {
1755 DataType::Float32,
1756 DataType::Signed32
1757 };
1758 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1759 "Reference TransposeConvolution2d: biases is not a supported type.");
1760 }
1761
1762 return supported;
1763}
1764
arovir011c7c81b2018-10-08 11:34:28 +01001765} // namespace armnn