blob: 26070a532801daa63c970a44ba8e6391f420d61a [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Convenience wrapper around IsSupportedForDataTypeGeneric for checks that
// only distinguish Float32 and QuantisedAsymm8: all other data-type slots are
// filled with FalseFunc (i.e. unconditionally unsupported).
// NOTE(review): the slot order (Float16, Float32, Uint8, Signed32, Boolean is
// the usual armnn ordering) is defined by LayerSupportCommon.hpp — confirm
// which FalseFunc corresponds to which data type against that header.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
namespace
{

// Builds the error message reported when a tensor has the wrong number of
// dimensions for a reference layer.
// @param expected   number of dimensions the layer requires
// @param actual     number of dimensions the tensor actually has
// @param layerStr   human-readable layer name (e.g. "batchToSpaceNd")
// @param tensorName role of the offending tensor (e.g. "input", "output")
// @return the composed diagnostic string
// Parameters are now taken by const reference: they are read-only, and the
// previous non-const references needlessly forbade passing temporaries.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
66namespace
67{
68template<typename F>
69bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
70{
71 bool supported = rule();
72 if (!supported && reason)
73 {
74 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
75 }
76 return supported;
77}
78
79struct Rule
80{
81 bool operator()() const
82 {
83 return m_Res;
84 }
85
86 bool m_Res = true;
87};
88
// Recursion base case: a single TensorInfo is trivially "all equal".
template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

// Recursive case: true iff every adjacent pair of TensorInfos shares the
// same DataType. The static_assert pins T to TensorInfo; note it is only
// checked here, not in the single-argument base case above.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}

// Rule: all given TensorInfos have the same DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};
111
// Rule: the two tensors have identical quantization scale and offset.
// NOTE(review): scale comparison uses float operator== — exact equality is
// presumably intended here for "same quantization parameters".
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};

// Rule: the tensor's DataType appears in the given container of DataTypes.
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
        {
            return dt == info.GetDataType();
        });
    }
};
132
// Rule: the tensor's DataType is exactly the given type.
struct TypeIs : public Rule
{
    TypeIs(const TensorInfo& info, DataType dt)
    {
        m_Res = dt == info.GetDataType();
    }
};

// Rule: the bias tensor's type matches the bias type implied by the weights
// type (e.g. quantized weights imply Signed32 biases).
// NOTE(review): GetBiasTypeFromWeightsType(...).value() is dereferenced
// unconditionally — an unsupported weights type would fail here rather than
// report a reason; confirm callers pre-validate the weights type.
struct BiasAndWeightsTypesMatch : public Rule
{
    BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
    {
        m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
    }
};

// Rule: the bias type inferred from the given weights tensor appears in the
// container of allowed bias DataTypes.
struct BiasAndWeightsTypesCompatible : public Rule
{
    template<typename Container>
    BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
        {
            return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
        });
    }
};
160
// Rule: both tensors have the same number of dimensions (rank).
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};

// Rule: both tensors contain the same total number of elements, regardless
// of shape (used e.g. for reshape-like layers such as dequantize).
struct ShapesAreSameTotalSize : public Rule
{
    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetNumElements() == info1.GetNumElements();
    }
};
176
// Rule: the two input shapes can be implicitly broadcast to the output
// shape. For each output dimension, each input's corresponding dimension
// (right-aligned; missing leading dimensions count as 1) must equal the
// output dimension or be 1.
struct ShapesAreBroadcastCompatible : public Rule
{
    // Returns the input's extent along output dimension idx, treating
    // dimensions the input lacks (leading 'offset' dims) as size 1.
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Stop early (m_Res in the loop condition) once a dimension fails.
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};

// Rule: the tensor has exactly the expected number of dimensions.
struct TensorNumDimensionsAreCorrect : public Rule
{
    TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
    {
        m_Res = info.GetNumDimensions() == expectedNumDimensions;
    }
};

} // namespace
213
214
// Checks reference backend support for Activation layers: input and output
// must share a supported data type and rank, and the requested activation
// function must be one the reference workload implements.
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    // Local rule: true only for the activation functions the reference
    // backend implements; anything else falls to the default case.
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}
277
278bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
279 const TensorInfo& input1,
280 const TensorInfo& output,
281 Optional<std::string&> reasonIfUnsupported) const
282{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000283 bool supported = true;
284
Sadik Armagan2999a022019-04-09 14:20:12 +0100285 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000286 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100287 DataType::QuantisedAsymm8,
288 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000289 };
290
291 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
292 "Reference addition: input 0 is not a supported type.");
293
294 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
295 "Reference addition: input 1 is not a supported type.");
296
297 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
298 "Reference addition: output is not a supported type.");
299
300 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
301 "Reference addition: input 0 and Input 1 types are mismatched");
302
303 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
304 "Reference addition: input and output types are mismatched");
305
306 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
307 "Reference addition: shapes are not suitable for implicit broadcast.");
308
309 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100310}
311
// Checks reference backend support for BatchNormalization: input, output and
// all four parameter tensors (mean, variance, beta, gamma) must be of a
// supported type, and input/output types must match. Shapes and descriptor
// attributes are not validated here.
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
355
// Checks reference backend support for BatchToSpaceNd: matching supported
// input/output types, and both tensors must be 4-dimensional.
bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    // Names used only to compose the dimension-mismatch error messages below.
    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    // The reference workload only handles rank-4 tensors; .data() is safe
    // here because the temporary string outlives the full expression.
    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}
402
Jim Flynn906f9462019-05-10 13:55:21 +0100403bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
404 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100405 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100406 Optional<std::string&> reasonIfUnsupported) const
407{
Jim Flynne242f2d2019-05-22 14:24:13 +0100408 ignore_unused(descriptor);
409
410 bool supported = true;
411 std::array<DataType,3> supportedTypes =
412 {
413 DataType::Float32,
414 DataType::QuantisedAsymm8,
415 DataType::QuantisedSymm16
416 };
417
418 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
419 "Reference concatenation: output type not supported");
420 for (const TensorInfo* input : inputs)
421 {
422 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
423 "Reference concatenation: input type not supported");
424
425 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
426 "Reference concatenation: input and output types mismatched.");
427 }
428
429 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100430}
431
arovir011c7c81b2018-10-08 11:34:28 +0100432bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
433 Optional<std::string&> reasonIfUnsupported) const
434{
Jim Flynne242f2d2019-05-22 14:24:13 +0100435 std::array<DataType,4> supportedTypes =
436 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100437 DataType::Float32,
438 DataType::Signed32,
439 DataType::QuantisedAsymm8,
440 DataType::QuantisedSymm16
441 };
442
443 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
444 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100445}
446
// Checks reference backend support for ConvertFp16ToFp32: the input must be
// Float16 and the output Float32; every other combination is rejected with a
// type-specific reason (FalseInputFuncF32 / FalseOutputFuncF16 produce the
// directional error messages).
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
466
// Checks reference backend support for ConvertFp32ToFp16: the mirror of
// IsConvertFp16ToFp32Supported — input must be Float32 and output Float16.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
486
487bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
488 const TensorInfo& output,
489 const Convolution2dDescriptor& descriptor,
490 const TensorInfo& weights,
491 const Optional<TensorInfo>& biases,
492 Optional<std::string&> reasonIfUnsupported) const
493{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100494 bool supported = true;
495
496 // Define supported types.
497 std::array<DataType,3> supportedTypes = {
498 DataType::Float32,
499 DataType::QuantisedAsymm8,
500 DataType::QuantisedSymm16
501 };
502
503 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100504 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100505
506 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100507 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100508
509 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100510 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100511
512 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100513 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100514
515 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100516 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100517
518 if (biases.has_value())
519 {
520 std::array<DataType,3> biasesSupportedTypes = {
521 DataType::Float32,
522 DataType::Signed32
523 };
524 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100525 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100526 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100527 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100528
529 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100530}
531
// Checks reference backend support for the Debug layer. Only the input's
// data type is inspected; the output tensor is deliberately ignored.
// NOTE(review): via IsSupportedForDataTypeRef only the Float32 and
// QuantisedAsymm8 slots return true — confirm this matches the data types
// the reference debug workload actually implements.
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
542
// Checks reference backend support for DepthwiseConvolution2d:
// input/output/weights must share a supported type, and an optional bias
// must be Float32 or Signed32. Descriptor attributes are not validated.
bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        // Bias types allowed for float and quantized weights respectively.
        std::array<DataType,2> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;

}
590
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000591bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
592 const TensorInfo& output,
593 Optional<std::string&> reasonIfUnsupported) const
594{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100595 bool supported = true;
596
597 std::array<DataType,2> supportedInputTypes = {
598 DataType::QuantisedAsymm8,
599 DataType::QuantisedSymm16
600 };
601
602 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
603 "Reference dequantize: input type not supported.");
604
605 std::array<DataType,2> supportedOutputTypes = {
606 DataType::Float32,
607 };
608
609 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
610 "Reference dequantize: output type not supported.");
611
612 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
613 "Reference dequantize: input and output shapes have different num total elements.");
614
615 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000616}
617
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000618bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
619 const armnn::TensorInfo& input1,
620 const armnn::DetectionPostProcessDescriptor& descriptor,
621 armnn::Optional<std::string&> reasonIfUnsupported) const
622{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100623 bool supported = true;
624
625 std::vector<DataType> supportedInputTypes =
626 {
627 DataType::Float32,
628 DataType::QuantisedAsymm8,
629 DataType::QuantisedSymm16
630 };
631
632 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
633 "Reference DetectionPostProcess: input 0 is not a supported type.");
634
635 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
636 "Reference DetectionPostProcess: input 1 is not a supported type.");
637
638 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000639}
640
// Dilated depthwise convolution shares the reference implementation with
// regular depthwise convolution, so the support check is delegated verbatim.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
650
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100651bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100652 const TensorInfo& input1,
653 const TensorInfo& output,
654 Optional<std::string&> reasonIfUnsupported) const
655{
Sadik Armagan2999a022019-04-09 14:20:12 +0100656 bool supported = true;
657
658 std::array<DataType,3> supportedTypes = {
659 DataType::Float32,
660 DataType::QuantisedAsymm8,
661 DataType::QuantisedSymm16
662 };
663
664 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
665 "Reference division: input 0 is not a supported type.");
666
667 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
668 "Reference division: input 1 is not a supported type.");
669
670 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
671 "Reference division: output is not a supported type.");
672
673 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
674 "Reference division: input 0 and Input 1 types are mismatched");
675
676 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
677 "Reference division: input and output types are mismatched");
678
679 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
680 "Reference division: shapes are not suitable for implicit broadcast.");
681
682 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100683}
684
// Checks reference backend support for element-wise Equal: both inputs must
// be of a supported matching type and broadcast-compatible with the output.
// NOTE(review): the output's data type is not validated here (Equal produces
// a boolean-like result) — confirm this is intentional.
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference equal: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference equal: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference equal: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference equal: shapes are not suitable for implicit broadcast.");

    return supported;
}
713
arovir011c7c81b2018-10-08 11:34:28 +0100714bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
715 const FakeQuantizationDescriptor& descriptor,
716 Optional<std::string&> reasonIfUnsupported) const
717{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100718 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100719 bool supported = true;
720
721 std::array<DataType,1> supportedTypes =
722 {
723 DataType::Float32
724 };
725
726 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
727 "Reference fake quantization: input type not supported.");
728
729 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100730}
731
// Checks reference backend support for Floor: input and output must each be
// a supported type. The output tensor's shape is not inspected, and no
// input/output type-equality rule is applied here.
bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,2> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}
753
// Checks reference backend support for FullyConnected: input/output/weights
// must share a supported type; when the descriptor enables a bias, the bias
// type must be allowed, match the type implied by the weights, and be
// derivable from the weights type.
bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 2>
        supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}
809
narpra014951d842019-01-18 16:53:53 +0000810bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
811 const armnn::TensorInfo& input1,
812 const armnn::TensorInfo& output,
813 armnn::Optional<std::string&> reasonIfUnsupported) const
814{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100815 bool supported = true;
816 std::array<DataType,3> supportedTypes =
817 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100818 DataType::Float32,
819 DataType::QuantisedAsymm8,
820 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100821 };
822
823 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
824 "Reference Gather: input type not supported");
825
826 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
827 "Reference Gather: output type not supported");
828
829 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
830 "Reference Gather: indices (input1) type not supported");
831
832 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
833 "Reference Gather: input and output types not matching");
834
835 return supported;
narpra014951d842019-01-18 16:53:53 +0000836}
837
FrancisMurtagh878f0232018-12-19 10:56:15 +0000838bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
839 const TensorInfo& input1,
840 const TensorInfo& output,
841 Optional<std::string&> reasonIfUnsupported) const
842{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100843 bool supported = true;
844
845 std::array<DataType,3> supportedTypes =
846 {
847 DataType::Float32,
848 DataType::QuantisedAsymm8,
849 DataType::QuantisedSymm16
850 };
851
852 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
853 "Reference greater: input 0 is not a supported type.");
854
855 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
856 "Reference greater: input 1 is not a supported type.");
857
858 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
859 "Reference greater: input 0 and Input 1 types are mismatched");
860
861 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
862 "Reference greater: shapes are not suitable for implicit broadcast.");
863
864 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000865}
866
arovir011c7c81b2018-10-08 11:34:28 +0100867bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
868 Optional<std::string&> reasonIfUnsupported) const
869{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100870 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100871}
872
873bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
874 const TensorInfo& output,
875 const L2NormalizationDescriptor& descriptor,
876 Optional<std::string&> reasonIfUnsupported) const
877{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100878 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100879 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100880 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100881 {
882 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100883 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100884 DataType::QuantisedSymm16
885 };
886
887 bool supported = true;
888
889 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
890 "Reference L2normalization: input type not supported.");
891
892 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
893 "Reference L2normalization: output type not supported.");
894
895 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
896 "Reference L2normalization: input and output types mismatched.");
897
898 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
899 "Reference L2normalization: input and output shapes have different "
900 "num total elements.");
901
902 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100903}
904
905bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
906 const TensorInfo& outputStateIn,
907 const TensorInfo& cellStateIn,
908 const TensorInfo& scratchBuffer,
909 const TensorInfo& outputStateOut,
910 const TensorInfo& cellStateOut,
911 const TensorInfo& output,
912 const LstmDescriptor& descriptor,
913 const TensorInfo& inputToForgetWeights,
914 const TensorInfo& inputToCellWeights,
915 const TensorInfo& inputToOutputWeights,
916 const TensorInfo& recurrentToForgetWeights,
917 const TensorInfo& recurrentToCellWeights,
918 const TensorInfo& recurrentToOutputWeights,
919 const TensorInfo& forgetGateBias,
920 const TensorInfo& cellBias,
921 const TensorInfo& outputGateBias,
922 const TensorInfo* inputToInputWeights,
923 const TensorInfo* recurrentToInputWeights,
924 const TensorInfo* cellToInputWeights,
925 const TensorInfo* inputGateBias,
926 const TensorInfo* projectionWeights,
927 const TensorInfo* projectionBias,
928 const TensorInfo* cellToForgetWeights,
929 const TensorInfo* cellToOutputWeights,
Jan Eilers38e05bd2019-06-26 13:10:09 +0100930 Optional<std::string&> reasonIfUnsupported,
931 const TensorInfo* inputLayerNormWeights,
932 const TensorInfo* forgetLayerNormWeights,
933 const TensorInfo* cellLayerNormWeights,
934 const TensorInfo* outputLayerNormWeights) const
arovir011c7c81b2018-10-08 11:34:28 +0100935{
telsoa01c577f2c2018-08-31 09:22:23 +0100936 ignore_unused(descriptor);
937 ignore_unused(inputToForgetWeights);
938 ignore_unused(inputToCellWeights);
939 ignore_unused(inputToOutputWeights);
940 ignore_unused(recurrentToForgetWeights);
941 ignore_unused(recurrentToCellWeights);
942 ignore_unused(recurrentToOutputWeights);
943 ignore_unused(forgetGateBias);
944 ignore_unused(cellBias);
945 ignore_unused(outputGateBias);
946 ignore_unused(inputToInputWeights);
947 ignore_unused(recurrentToInputWeights);
948 ignore_unused(cellToInputWeights);
949 ignore_unused(inputGateBias);
950 ignore_unused(projectionWeights);
951 ignore_unused(projectionBias);
952 ignore_unused(cellToForgetWeights);
953 ignore_unused(cellToOutputWeights);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100954 ignore_unused(inputLayerNormWeights);
955 ignore_unused(forgetLayerNormWeights);
956 ignore_unused(cellLayerNormWeights);
957 ignore_unused(outputLayerNormWeights);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100958
959 bool supported = true;
960
961 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100962 DataType::Float32,
963 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100964 };
965
966 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
967 "Reference Lstm: input is not a supported type.");
968
969 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
970 "Reference Lstm: input and outputStateIn types are mismatched");
971
972 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
973 "Reference Lstm: input and cellStateIn types are mismatched");
974
975 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
976 "Reference Lstm: input and scratchBuffer types are mismatched");
977
978 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
979 "Reference Lstm: input and outputStateOut types are mismatched");
980
981 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
982 "Reference Lstm: input and cellStateOut types are mismatched");
983
984 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
985 "Reference Lstm: input and output types are mismatched");
986
987 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100988}
989
saoste012df12b32018-11-28 16:57:20 +0000990bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
991 const TensorInfo& input1,
992 const TensorInfo& output,
993 Optional<std::string&> reasonIfUnsupported) const
994{
Sadik Armagan2999a022019-04-09 14:20:12 +0100995 bool supported = true;
996
997 std::array<DataType,3> supportedTypes = {
998 DataType::Float32,
999 DataType::QuantisedAsymm8,
1000 DataType::QuantisedSymm16
1001 };
1002
1003 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1004 "Reference maximum: input 0 is not a supported type.");
1005
1006 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1007 "Reference maximum: input 1 is not a supported type.");
1008
1009 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1010 "Reference maximum: output is not a supported type.");
1011
1012 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1013 "Reference maximum: input 0 and Input 1 types are mismatched");
1014
1015 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1016 "Reference maximum: input and output types are mismatched");
1017
1018 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1019 "Reference maximum: shapes are not suitable for implicit broadcast.");
1020
1021 return supported;
saoste012df12b32018-11-28 16:57:20 +00001022}
1023
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // Checks both the data types and the expected output rank for a Mean
    // (reduction) layer, since the output rank depends on the descriptor.
    bool supported = true;
    // Strings fed into CreateIncorrectDimensionsErrorMsg to build the
    // dimension-mismatch failure messages below.
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    // Supported tensor data types for the reference Mean workload.
    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        // Reduced axes are kept as size-1 dimensions, so the output rank must
        // equal the input rank.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        // No axes given means reduce over everything: the output collapses to
        // a rank-1 tensor.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        // Each reduced axis removes one dimension from the output.
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            // All dimensions were reduced away; the output is still expected
            // to be a rank-1 tensor (rank 0 is not used).
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1083
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    // Merger is the legacy name for Concat; delegate to the Concat check so
    // both entry points share one set of rules.
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
1091
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001092bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1093 const TensorInfo &output,
1094 Optional<std::string &> reasonIfUnsupported) const
1095{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001096 bool supported = true;
1097
1098 std::array<DataType,5> supportedTypes =
1099 {
1100 DataType::Float32,
1101 DataType::Float16,
1102 DataType::QuantisedAsymm8,
1103 DataType::QuantisedSymm16,
1104 DataType::Boolean
1105 };
1106
1107 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1108 "Reference MemCopy: input type not supported");
1109
1110 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1111 "Reference MemCopy: output type not supported");
1112
1113 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1114 "Reference MemCopy: input and output types are mismatched");
1115
1116 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001117}
1118
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001119bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1120 const TensorInfo& input1,
1121 const TensorInfo& output,
1122 Optional<std::string&> reasonIfUnsupported) const
1123{
Sadik Armagan2999a022019-04-09 14:20:12 +01001124 bool supported = true;
1125
1126 std::array<DataType,3> supportedTypes = {
1127 DataType::Float32,
1128 DataType::QuantisedAsymm8,
1129 DataType::QuantisedSymm16
1130 };
1131
1132 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1133 "Reference minimum: input 0 is not a supported type.");
1134
1135 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1136 "Reference minimum: input 1 is not a supported type.");
1137
1138 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1139 "Reference minimum: output is not a supported type.");
1140
1141 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1142 "Reference minimum: input 0 and Input 1 types are mismatched");
1143
1144 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1145 "Reference minimum: input and output types are mismatched");
1146
1147 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1148 "Reference minimum: shapes are not suitable for implicit broadcast.");
1149
1150 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001151}
1152
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001153bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1154 const TensorInfo& input1,
1155 const TensorInfo& output,
1156 Optional<std::string&> reasonIfUnsupported) const
1157{
Sadik Armagan2999a022019-04-09 14:20:12 +01001158 bool supported = true;
1159
1160 std::array<DataType,3> supportedTypes = {
1161 DataType::Float32,
1162 DataType::QuantisedAsymm8,
1163 DataType::QuantisedSymm16
1164 };
1165
1166 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1167 "Reference multiplication: input 0 is not a supported type.");
1168
1169 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1170 "Reference multiplication: input 1 is not a supported type.");
1171
1172 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1173 "Reference multiplication: output is not a supported type.");
1174
1175 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1176 "Reference multiplication: input 0 and Input 1 types are mismatched");
1177
1178 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1179 "Reference multiplication: input and output types are mismatched");
1180
1181 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1182 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1183
1184 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001185}
1186
1187bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1188 const TensorInfo& output,
1189 const NormalizationDescriptor& descriptor,
1190 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001191{
Nina Drozd661dfa72018-10-02 11:14:17 +01001192 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001193
1194 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001195 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001196 {
1197 DataType::Float16,
1198 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001199 DataType::QuantisedAsymm8,
1200 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001201 };
1202
1203 bool supported = true;
1204
1205 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1206 "Reference normalization: input type not supported.");
1207
1208 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1209 "Reference normalization: output type not supported.");
1210
1211 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1212 "Reference normalization: input and output shapes have different "
1213 "num total elements.");
1214
1215 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001216}
1217
1218bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1219 Optional<std::string&> reasonIfUnsupported) const
1220{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001221 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001222}
1223
bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const PadDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    // Only the input data type is inspected; output and descriptor play no
    // part in the decision.
    ignore_unused(output);
    ignore_unused(descriptor);
    // Legacy type check: the two TrueFunc slots accept the float32 and uint8
    // cases, and IsSupportedForDataTypeRef (top of file) rejects everything
    // else via FalseFunc.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1236
bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
                                         const TensorInfo& output,
                                         const PermuteDescriptor& descriptor,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    // Only the input data type is inspected; output and descriptor play no
    // part in the decision.
    ignore_unused(output);
    ignore_unused(descriptor);
    // Legacy type check: the two TrueFunc slots accept the float32 and uint8
    // cases, and IsSupportedForDataTypeRef (top of file) rejects everything
    // else via FalseFunc.
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
1249
1250bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1251 const TensorInfo& output,
1252 const Pooling2dDescriptor& descriptor,
1253 Optional<std::string&> reasonIfUnsupported) const
1254{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001255 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001256 bool supported = true;
1257
1258 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001259 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001260 {
1261 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001262 DataType::QuantisedAsymm8,
1263 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001264 };
1265
1266 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1267 "Reference poolind2d: input is not a supported type.");
1268
1269 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1270 "Reference poolind2d: output is not a supported type.");
1271
1272 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1273 "Reference poolind2d: input and output types are mismatched.");
1274
1275 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001276}
1277
Derek Lamberti5f400d62019-03-25 15:41:58 +00001278bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1279 const TensorInfo& output,
1280 Optional<std::string&> reasonIfUnsupported) const
1281{
1282 bool supported = true;
1283
1284 // Define supported output types.
1285 std::array<DataType,2> supportedInputTypes = {
1286 DataType::Float32,
1287 };
1288
1289 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1290 "Reference quantize: input type not supported.");
1291
1292 // Define supported output types.
1293 std::array<DataType,2> supportedOutputTypes = {
1294 DataType::QuantisedAsymm8,
1295 DataType::QuantisedSymm16
1296 };
1297 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1298 "Reference quantize: output type not supported.");
1299
1300 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1301 "Reference quantize: input and output shapes have different num total elements.");
1302
1303 return supported;
1304}
1305
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001306bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001307 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001308 Optional<std::string&> reasonIfUnsupported) const
1309{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001310 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001311 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001312 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001313 {
1314 DataType::Float32,
1315 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001316 DataType::QuantisedAsymm8,
1317 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001318 };
1319 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1320 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001321}
1322
1323bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001324 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001325 Optional<std::string&> reasonIfUnsupported) const
1326{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001327 bool supported = true;
1328 std::array<DataType,3> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001329 {
1330 DataType::Float32,
1331 DataType::QuantisedAsymm8,
1332 DataType::QuantisedSymm16
1333 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001334
1335 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1336 "Reference ResizeBilinear: input type not supported");
1337
1338 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1339 "Reference ResizeBilinear: output type not supported");
1340
1341 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1342 "Reference ResizeBilinear: input and output types not matching");
1343
1344 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001345}
1346
Teresa Charlin970f43b2019-07-01 13:51:07 +01001347bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1348 const TensorInfo& output,
1349 const ResizeDescriptor& descriptor,
1350 Optional<std::string&> reasonIfUnsupported) const
1351{
1352 bool supported = true;
1353 std::array<DataType,3> supportedTypes =
1354 {
1355 DataType::Float32,
1356 DataType::QuantisedAsymm8,
1357 DataType::QuantisedSymm16
1358 };
1359
1360 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1361 "Reference Resize: input type not supported");
1362
1363 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1364 "Reference Resize: output type not supported");
1365
1366 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1367 "Reference Resize: input and output types not matching");
1368
1369 return supported;
1370}
1371
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001372bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1373 const TensorInfo& output,
1374 Optional<std::string&> reasonIfUnsupported) const
1375{
nikraj010421e7f2019-06-14 09:40:34 +01001376 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001377 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001378 {
1379 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001380 DataType::QuantisedAsymm8,
1381 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001382 };
1383
1384 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1385 "Reference rsqrt: input type not supported");
1386
1387 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1388 "Reference rsqrt: output type not supported");
1389
1390 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1391 "Reference rsqrt: input and output types not matching");
1392
1393 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1394 "Reference Rsqrt: input and output shapes have different number of total elements");
1395
1396 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001397}
1398
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001399bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1400 const TensorInfo& output,
1401 const SoftmaxDescriptor& descriptor,
1402 Optional<std::string&> reasonIfUnsupported) const
1403{
1404 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001405 bool supported = true;
1406 std::array<DataType,3> supportedTypes =
1407 {
1408 DataType::Float32,
1409 DataType::QuantisedAsymm8,
1410 DataType::QuantisedSymm16
1411 };
1412
1413 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1414 "Reference concatenation: output type not supported");
1415
1416 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1417 "Reference concatenation: input type not supported");
1418
1419 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1420 "Reference concatenation: input type not supported");
1421
1422 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001423}
1424
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001425bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1426 const TensorInfo& output,
1427 const SpaceToBatchNdDescriptor& descriptor,
1428 Optional<std::string&> reasonIfUnsupported) const
1429{
1430 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001431 bool supported = true;
1432 std::array<DataType,3> supportedTypes =
1433 {
1434 DataType::Float32,
1435 DataType::QuantisedAsymm8,
1436 DataType::QuantisedSymm16
1437 };
1438
1439 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440 "Reference SpaceToBatchNd: input type not supported");
1441
1442 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1443 "Reference SpaceToBatchNd: output type not supported");
1444
1445 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1446 "Reference SpaceToBatchNd: input and output types are mismatched");
1447
1448 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001449}
1450
Keith Davisa57eccb2019-06-14 17:33:22 +01001451bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001452 const TensorInfo& output,
1453 const SpaceToDepthDescriptor& descriptor,
1454 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001455{
1456
1457 ignore_unused(descriptor);
1458 bool supported = true;
1459
1460 std::array<DataType,2> supportedTypes =
1461 {
1462 DataType::Float32,
1463 DataType::QuantisedAsymm8,
1464 };
1465
1466 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1467 "Reference SpaceToDepth: input type not supported");
1468
1469 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1470 "Reference SpaceToDepth: output type not supported");
1471
1472 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1473 "Reference SpaceToDepth: input and output types are mismatched");
1474
1475 return supported;
1476}
1477
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001478bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1479 const ViewsDescriptor& descriptor,
1480 Optional<std::string&> reasonIfUnsupported) const
1481{
1482 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001483 bool supported = true;
1484 std::array<DataType,3> supportedTypes =
1485 {
1486 DataType::Float32,
1487 DataType::QuantisedAsymm8,
1488 DataType::QuantisedSymm16
1489 };
1490
1491 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1492 "Reference splitter: input type not supported");
1493
1494 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001495}
1496
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001497bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1498 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1499 const ViewsDescriptor& descriptor,
1500 Optional<std::string&> reasonIfUnsupported) const
1501{
1502 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001503 bool supported = true;
1504 std::array<DataType,3> supportedTypes =
1505 {
1506 DataType::Float32,
1507 DataType::QuantisedAsymm8,
1508 DataType::QuantisedSymm16
1509 };
1510
1511 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1512 "Reference splitter: output type not supported");
1513 for (const TensorInfo output : outputs)
1514 {
1515 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1516 "Reference splitter: input type not supported");
1517
1518 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1519 "Reference splitter: input and output types mismatched.");
1520 }
1521
1522 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001523}
1524
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001525bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1526 const TensorInfo& output,
1527 const StridedSliceDescriptor& descriptor,
1528 Optional<std::string&> reasonIfUnsupported) const
1529{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001530 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001531 bool supported = true;
1532
1533 std::array<DataType,3> supportedTypes =
1534 {
1535 DataType::Float32,
1536 DataType::QuantisedAsymm8,
1537 DataType::QuantisedSymm16
1538 };
1539
1540 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1541 "Reference StridedSlice: input type not supported");
1542
1543 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1544 "Reference StridedSlice: output type not supported");
1545
1546 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1547 "Reference StridedSlice: input and output types are mismatched");
1548
1549 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001550}
1551
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001552bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1553 const TensorInfo& input1,
1554 const TensorInfo& output,
1555 Optional<std::string&> reasonIfUnsupported) const
1556{
Sadik Armagan2999a022019-04-09 14:20:12 +01001557 bool supported = true;
1558
1559 std::array<DataType,3> supportedTypes = {
1560 DataType::Float32,
1561 DataType::QuantisedAsymm8,
1562 DataType::QuantisedSymm16
1563 };
1564
1565 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1566 "Reference subtraction: input 0 is not a supported type.");
1567
1568 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1569 "Reference subtraction: input 1 is not a supported type.");
1570
1571 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1572 "Reference subtraction: output is not a supported type.");
1573
1574 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1575 "Reference subtraction: input 0 and Input 1 types are mismatched");
1576
1577 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1578 "Reference subtraction: input and output types are mismatched");
1579
1580 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1581 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1582
1583 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001584}
1585
Matteo Martincighab9e5252019-06-13 17:27:46 +01001586bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1587 const TensorInfo& alpha,
1588 const TensorInfo& output,
1589 Optional<std::string&> reasonIfUnsupported) const
1590{
1591 bool supported = true;
1592
1593 std::array<DataType, 3> supportedTypes
1594 {
1595 DataType::Float32,
1596 DataType::QuantisedAsymm8,
1597 DataType::QuantisedSymm16
1598 };
1599
1600 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1601 "PReLU: input is not a supported type.");
1602
1603 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1604 "PReLU: alpha is not a supported type.");
1605
1606 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1607 "PReLU: output is not a supported type.");
1608
1609 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1610 "PReLU: input, alpha and output types are mismatched");
1611
1612 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1613 "PReLU: shapes are not suitable for implicit broadcast");
1614
1615 return supported;
1616}
1617
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001618bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1619 const TensorInfo& output,
1620 const TransposeConvolution2dDescriptor& descriptor,
1621 const TensorInfo& weights,
1622 const Optional<TensorInfo>& biases,
1623 Optional<std::string&> reasonIfUnsupported) const
1624{
1625 ignore_unused(descriptor);
1626
1627 bool supported = true;
1628
1629 std::array<DataType,3> supportedTypes =
1630 {
1631 DataType::Float32,
1632 DataType::QuantisedAsymm8,
1633 DataType::QuantisedSymm16
1634 };
1635
1636 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1637 "Reference TransposeConvolution2d: input is not a supported type.");
1638
1639 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1640 "Reference TransposeConvolution2d: output is not a supported type.");
1641
1642 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1643 "Reference TransposeConvolution2d: weights is not a supported type.");
1644
1645 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1646 "Reference TransposeConvolution2d: input and output types mismatched.");
1647
1648 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1649 "Reference TransposeConvolution2d: input and weights types mismatched.");
1650
1651 if (biases.has_value())
1652 {
1653 std::array<DataType,3> biasesSupportedTypes = {
1654 DataType::Float32,
1655 DataType::Signed32
1656 };
1657 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1658 "Reference TransposeConvolution2d: biases is not a supported type.");
1659 }
1660
1661 return supported;
1662}
1663
arovir011c7c81b2018-10-08 11:34:28 +01001664} // namespace armnn