blob: 429993a55f1db26b46ed7c271a0b6ca0befedfc0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <InternalTypes.hpp>
10#include <LayerSupportCommon.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000012#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
David Beck111b5d92018-11-12 14:59:37 +000014#include <backendsCommon/BackendRegistry.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010016
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
Derek Lamberti50db4e82019-03-13 14:16:15 +000019#include <vector>
20#include <algorithm>
21#include <array>
22
telsoa014fcda012018-03-09 14:13:49 +000023using namespace boost;
24
25namespace armnn
26{
27
namespace
{

// Dispatches a layer-support query to the functor matching the tensor's data
// type. Only the Float32 and Uint8 slots are caller-supplied here; every
// other slot is hard-wired to FalseFunc, so only those two data types can
// ever report support through this helper.
// NOTE(review): the functor order must match the parameter order of
// IsSupportedForDataTypeGeneric (see LayerSupportCommon.hpp) - confirm
// before adding or reordering slots.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
49
namespace
{

// Builds the standard diagnostic emitted when a tensor has the wrong number
// of dimensions, e.g.
//   "Reference batchToSpaceNd: Expected 4 dimensions but got 2 dimensions
//    instead, for the 'input' tensor."
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name, inserted after "Reference ".
// @param tensorName Name of the offending tensor.
// @return The formatted error message.
//
// Fix: the string parameters were taken by non-const reference, which
// wrongly suggested mutation and rejected temporaries/literals; they are
// read-only, so take them by const reference (backward-compatible for all
// existing callers).
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000065
namespace
{
// Evaluates a support rule; on failure appends 'reason' (plus a newline) to
// the caller-provided diagnostic string.
// NOTE(review): reasonIfUnsupported.value() is called without checking
// has_value() - confirm Optional's behaviour when empty before relying on
// failure paths with no reason sink.
template<typename F>
bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
{
    bool supported = rule();
    if (!supported && reason)
    {
        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
    }
    return supported;
}

// Base for all support rules: a callable that simply reports the result
// computed in the derived rule's constructor. Defaults to "supported".
struct Rule
{
    bool operator()() const
    {
        return m_Res;
    }

    bool m_Res = true;
};

// Recursion terminator: a single TensorInfo trivially agrees with itself.
template<typename T>
bool AllTypesAreEqualImpl(T t)
{
    return true;
}

// True when every TensorInfo in the pack has the same DataType.
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");

    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}

// Rule: all given tensors share one DataType.
struct TypesAreEqual : public Rule
{
    template<typename ... Ts>
    TypesAreEqual(const Ts&... ts)
    {
        m_Res = AllTypesAreEqualImpl(ts...);
    }
};

// Rule: two tensors have identical quantization scale and offset.
struct QuantizationParametersAreEqual : public Rule
{
    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
    }
};

// Rule: the tensor's DataType is one of the types in the container.
struct TypeAnyOf : public Rule
{
    template<typename Container>
    TypeAnyOf(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
            {
                return dt == info.GetDataType();
            });
    }
};

// Rule: the tensor's DataType is exactly the given type.
struct TypeIs : public Rule
{
    TypeIs(const TensorInfo& info, DataType dt)
    {
        m_Res = dt == info.GetDataType();
    }
};

// Rule: the bias type equals the bias type implied by the weights type
// (e.g. quantized weights imply an integer bias - see
// GetBiasTypeFromWeightsType).
struct BiasAndWeightsTypesMatch : public Rule
{
    BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
    {
        m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
    }
};

// Rule: the bias type inferred from the weights tensor is one of the types
// in the container.
struct BiasAndWeightsTypesCompatible : public Rule
{
    template<typename Container>
    BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
    {
        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
            {
                return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
            });
    }
};

// Rule: both tensors have the same number of dimensions.
struct ShapesAreSameRank : public Rule
{
    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
    }
};

// Rule: both tensors hold the same total number of elements (shape may
// differ, e.g. for reshapes).
struct ShapesAreSameTotalSize : public Rule
{
    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
    {
        m_Res = info0.GetNumElements() == info1.GetNumElements();
    }
};

// Rule: both inputs can be implicitly broadcast to the output shape - each
// right-aligned input dimension must equal the output dimension or be 1.
struct ShapesAreBroadcastCompatible : public Rule
{
    // Size of input dimension idx once the input shape is right-aligned
    // against the output shape; leading dimensions the input lacks count
    // as 1.
    // NOTE(review): assumes 'in' has no more dimensions than 'out' -
    // 'offset' underflows otherwise (unsigned arithmetic). Confirm callers
    // guarantee this.
    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
    {
        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
        return sizeIn;
    }

    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
    {
        const TensorShape& shape0 = in0.GetShape();
        const TensorShape& shape1 = in1.GetShape();
        const TensorShape& outShape = out.GetShape();

        // Stop early once a dimension fails (m_Res in the loop condition).
        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
        {
            unsigned int sizeOut = outShape[i];
            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);

            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
        }
    }
};

// Rule: the tensor has exactly the expected number of dimensions.
struct TensorNumDimensionsAreCorrect : public Rule
{
    TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
    {
        m_Res = info.GetNumDimensions() == expectedNumDimensions;
    }
};

} // namespace
213
214
arovir011c7c81b2018-10-08 11:34:28 +0100215bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
216 const TensorInfo& output,
217 const ActivationDescriptor& descriptor,
218 Optional<std::string&> reasonIfUnsupported) const
219{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000220 bool supported = true;
221
222 // Define supported types.
Teresa Charlin18515e22019-04-24 10:17:46 +0100223 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000224 DataType::Float32,
Teresa Charlin18515e22019-04-24 10:17:46 +0100225 DataType::QuantisedAsymm8,
226 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000227 };
228
229 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
230 "Reference activation: input type not supported.");
231
232 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
233 "Reference activation: output type not supported.");
234
235 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
236 "Reference activation: input and output types mismatched.");
237
238 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
239 "Reference activation: input and output shapes are of different rank.");
240
241
242 struct ActivationFunctionSupported : public Rule
243 {
244 ActivationFunctionSupported(const ActivationDescriptor& desc)
245 {
246 switch(desc.m_Function)
247 {
248 case ActivationFunction::Abs:
249 case ActivationFunction::BoundedReLu:
250 case ActivationFunction::LeakyReLu:
251 case ActivationFunction::Linear:
252 case ActivationFunction::ReLu:
253 case ActivationFunction::Sigmoid:
254 case ActivationFunction::SoftReLu:
255 case ActivationFunction::Sqrt:
256 case ActivationFunction::Square:
257 case ActivationFunction::TanH:
258 {
259 m_Res = true;
260 break;
261 }
262 default:
263 {
264 m_Res = false;
265 break;
266 }
267 }
268 }
269 };
270
271 // Function is supported
272 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
273 "Reference activation: function not supported.");
274
275 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100276}
277
278bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
279 const TensorInfo& input1,
280 const TensorInfo& output,
281 Optional<std::string&> reasonIfUnsupported) const
282{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000283 bool supported = true;
284
Sadik Armagan2999a022019-04-09 14:20:12 +0100285 std::array<DataType,3> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000286 DataType::Float32,
Sadik Armagan2999a022019-04-09 14:20:12 +0100287 DataType::QuantisedAsymm8,
288 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000289 };
290
291 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
292 "Reference addition: input 0 is not a supported type.");
293
294 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
295 "Reference addition: input 1 is not a supported type.");
296
297 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
298 "Reference addition: output is not a supported type.");
299
300 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
301 "Reference addition: input 0 and Input 1 types are mismatched");
302
303 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
304 "Reference addition: input and output types are mismatched");
305
306 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
307 "Reference addition: shapes are not suitable for implicit broadcast.");
308
309 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100310}
311
312bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
313 const TensorInfo& output,
314 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100315 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100316 const TensorInfo& beta,
317 const TensorInfo& gamma,
318 const BatchNormalizationDescriptor& descriptor,
319 Optional<std::string&> reasonIfUnsupported) const
320{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100321 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100322
Matteo Martincighf5507132019-06-04 10:59:47 +0100323 std::array<DataType, 3> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100324 {
325 DataType::Float32,
Matteo Martincighf5507132019-06-04 10:59:47 +0100326 DataType::QuantisedAsymm8,
327 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100328 };
329
330 bool supported = true;
331
332 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
333 "Reference batch normalization: input is not a supported type.");
334
335 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
336 "Reference batch normalization: output is not a supported type.");
337
338 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
339 "Reference batch normalization: input and output types are mismatched");
340
341 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
342 "Reference batch normalization: mean is not a supported type.");
343
344 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
345 "Reference batch normalization: variance is not a supported type.");
346
347 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
348 "Reference batch normalization: beta is not a supported type.");
349
350 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
351 "Reference batch normalization: gamma is not a supported type.");
352
353 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100354}
355
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000356bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
357 const TensorInfo& output,
358 const BatchToSpaceNdDescriptor& descriptor,
359 Optional<std::string&> reasonIfUnsupported) const
360{
361 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100362
363 bool supported = true;
364
365 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
366 std::string inputTensorStr = "input";
367 std::string outputTensorStr = "output";
368
369 // Define supported types.
370 std::array<DataType,3> supportedTypes =
371 {
372 DataType::Float32,
373 DataType::QuantisedAsymm8,
374 DataType::QuantisedSymm16
375 };
376
377 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
378 "Reference BatchToSpaceNd: input type not supported.");
379
380 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381 "Reference BatchToSpaceNd: output type not supported.");
382
383 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
384 "Reference BatchToSpaceNd: input and output types mismatched.");
385
386 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
387 reasonIfUnsupported,
388 CreateIncorrectDimensionsErrorMsg(4,
389 output.GetNumDimensions(),
390 batchToSpaceNdLayerStr,
391 outputTensorStr).data());
392
393 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
394 reasonIfUnsupported,
395 CreateIncorrectDimensionsErrorMsg(4,
396 input.GetNumDimensions(),
397 batchToSpaceNdLayerStr,
398 inputTensorStr).data());
399
400 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000401}
402
Jim Flynn906f9462019-05-10 13:55:21 +0100403bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
404 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100405 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100406 Optional<std::string&> reasonIfUnsupported) const
407{
Jim Flynne242f2d2019-05-22 14:24:13 +0100408 ignore_unused(descriptor);
409
410 bool supported = true;
411 std::array<DataType,3> supportedTypes =
412 {
413 DataType::Float32,
414 DataType::QuantisedAsymm8,
415 DataType::QuantisedSymm16
416 };
417
418 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
419 "Reference concatenation: output type not supported");
420 for (const TensorInfo* input : inputs)
421 {
422 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
423 "Reference concatenation: input type not supported");
424
425 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
426 "Reference concatenation: input and output types mismatched.");
427 }
428
429 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100430}
431
arovir011c7c81b2018-10-08 11:34:28 +0100432bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
433 Optional<std::string&> reasonIfUnsupported) const
434{
Jim Flynne242f2d2019-05-22 14:24:13 +0100435 std::array<DataType,4> supportedTypes =
436 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100437 DataType::Float32,
438 DataType::Signed32,
439 DataType::QuantisedAsymm8,
440 DataType::QuantisedSymm16
441 };
442
443 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
444 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100445}
446
447bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
448 const TensorInfo& output,
449 Optional<std::string&> reasonIfUnsupported) const
450{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100451 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
452 input.GetDataType(),
453 &TrueFunc<>,
454 &FalseInputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000455 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000456 &FalseFuncI32<>,
457 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100458 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
459 output.GetDataType(),
460 &FalseOutputFuncF16<>,
461 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000462 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000463 &FalseFuncI32<>,
464 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100465}
466
467bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
468 const TensorInfo& output,
469 Optional<std::string&> reasonIfUnsupported) const
470{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100471 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
472 input.GetDataType(),
473 &FalseInputFuncF16<>,
474 &TrueFunc<>,
narpra01db2b1602019-01-23 15:23:11 +0000475 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000476 &FalseFuncI32<>,
477 &FalseFuncU8<>) &&
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100478 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
479 output.GetDataType(),
480 &TrueFunc<>,
481 &FalseOutputFuncF32<>,
narpra01db2b1602019-01-23 15:23:11 +0000482 &FalseFuncU8<>,
kevmay012b4d88e2019-01-24 14:05:09 +0000483 &FalseFuncI32<>,
484 &FalseFuncU8<>));
arovir011c7c81b2018-10-08 11:34:28 +0100485}
486
487bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
488 const TensorInfo& output,
489 const Convolution2dDescriptor& descriptor,
490 const TensorInfo& weights,
491 const Optional<TensorInfo>& biases,
492 Optional<std::string&> reasonIfUnsupported) const
493{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100494 bool supported = true;
495
496 // Define supported types.
497 std::array<DataType,3> supportedTypes = {
498 DataType::Float32,
499 DataType::QuantisedAsymm8,
500 DataType::QuantisedSymm16
501 };
502
503 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100504 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100505
506 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100507 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100508
509 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100510 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100511
512 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100513 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100514
515 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100516 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100517
518 if (biases.has_value())
519 {
520 std::array<DataType,3> biasesSupportedTypes = {
521 DataType::Float32,
522 DataType::Signed32
523 };
524 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100525 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100526 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100527 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100528
529 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100530}
531
// Reports whether the reference backend supports a Debug layer.
// Only the input's data type is consulted (Float32 / QuantisedAsymm8 via
// the two TrueFunc slots); the output tensor is not checked here.
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &TrueFunc<>);
}
542
arovir011c7c81b2018-10-08 11:34:28 +0100543bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
544 const TensorInfo& output,
545 const DepthwiseConvolution2dDescriptor& descriptor,
546 const TensorInfo& weights,
547 const Optional<TensorInfo>& biases,
548 Optional<std::string&> reasonIfUnsupported) const
549{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100550 ignore_unused(output);
551 ignore_unused(descriptor);
552 ignore_unused(weights);
553 ignore_unused(biases);
554 return IsSupportedForDataTypeRef(reasonIfUnsupported,
555 input.GetDataType(),
556 &TrueFunc<>,
557 &TrueFunc<>);
arovir011c7c81b2018-10-08 11:34:28 +0100558}
559
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000560bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
561 const TensorInfo& output,
562 Optional<std::string&> reasonIfUnsupported) const
563{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100564 bool supported = true;
565
566 std::array<DataType,2> supportedInputTypes = {
567 DataType::QuantisedAsymm8,
568 DataType::QuantisedSymm16
569 };
570
571 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
572 "Reference dequantize: input type not supported.");
573
574 std::array<DataType,2> supportedOutputTypes = {
575 DataType::Float32,
576 };
577
578 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
579 "Reference dequantize: output type not supported.");
580
581 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
582 "Reference dequantize: input and output shapes have different num total elements.");
583
584 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000585}
586
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000587bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
588 const armnn::TensorInfo& input1,
589 const armnn::DetectionPostProcessDescriptor& descriptor,
590 armnn::Optional<std::string&> reasonIfUnsupported) const
591{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100592 bool supported = true;
593
594 std::vector<DataType> supportedInputTypes =
595 {
596 DataType::Float32,
597 DataType::QuantisedAsymm8,
598 DataType::QuantisedSymm16
599 };
600
601 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
602 "Reference DetectionPostProcess: input 0 is not a supported type.");
603
604 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
605 "Reference DetectionPostProcess: input 1 is not a supported type.");
606
607 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000608}
609
// Reports whether the reference backend supports a dilated
// DepthwiseConvolution2d layer. Delegates to the plain depthwise check -
// the reference backend applies the same support criteria to both.
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
619
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100620bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100621 const TensorInfo& input1,
622 const TensorInfo& output,
623 Optional<std::string&> reasonIfUnsupported) const
624{
Sadik Armagan2999a022019-04-09 14:20:12 +0100625 bool supported = true;
626
627 std::array<DataType,3> supportedTypes = {
628 DataType::Float32,
629 DataType::QuantisedAsymm8,
630 DataType::QuantisedSymm16
631 };
632
633 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
634 "Reference division: input 0 is not a supported type.");
635
636 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
637 "Reference division: input 1 is not a supported type.");
638
639 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
640 "Reference division: output is not a supported type.");
641
642 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
643 "Reference division: input 0 and Input 1 types are mismatched");
644
645 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
646 "Reference division: input and output types are mismatched");
647
648 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
649 "Reference division: shapes are not suitable for implicit broadcast.");
650
651 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100652}
653
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000654bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
655 const TensorInfo& input1,
656 const TensorInfo& output,
657 Optional<std::string&> reasonIfUnsupported) const
658{
659 ignore_unused(input0);
660 ignore_unused(input1);
661 ignore_unused(output);
662 ignore_unused(reasonIfUnsupported);
663 return IsSupportedForDataTypeRef(reasonIfUnsupported,
664 input0.GetDataType(),
665 &TrueFunc<>,
666 &TrueFunc<>);
667}
668
// Reports whether the reference backend supports a FakeQuantization layer.
// Only Float32 input is accepted (the Uint8 slot is FalseFuncU8); the
// descriptor does not affect support.
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeRef(reasonIfUnsupported,
                                     input.GetDataType(),
                                     &TrueFunc<>,
                                     &FalseFuncU8<>);
}
679
680bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
681 const TensorInfo& output,
682 Optional<std::string&> reasonIfUnsupported) const
683{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100684 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100685 bool supported = true;
686
James Conroyb40d7102019-06-04 12:32:09 +0100687 std::array<DataType,2> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100688 {
James Conroyb40d7102019-06-04 12:32:09 +0100689 DataType::Float32,
690 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100691 };
692
693 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
694 "Reference Floor: input type not supported.");
695
696 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
697 "Reference Floor: output type not supported.");
698
699 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100700}
701
702bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
703 const TensorInfo& output,
704 const TensorInfo& weights,
705 const TensorInfo& biases,
706 const FullyConnectedDescriptor& descriptor,
707 Optional<std::string&> reasonIfUnsupported) const
708{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100709 bool supported = true;
710
711 // Define supported types.
712 std::array<DataType,3> supportedTypes =
713 {
714 DataType::Float32,
715 DataType::QuantisedAsymm8,
716 DataType::QuantisedSymm16
717 };
718
719 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
720 "Reference Fully Connected: input type not supported.");
721
722 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
723 "Reference Fully Connected: output type not supported.");
724
725 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
726 "Reference Fully Connected: input and output types mismatched.");
727
728 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
729 "Reference Fully Connected: weights type not supported.");
730
731 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
732 "Reference Fully Connected: input and weight types mismatched.");
733
734 if (descriptor.m_BiasEnabled)
735 {
736 // Defined supported types for bias
737 std::array<DataType, 2>
738 supportedBiasTypes =
739 {
740 DataType::Float32,
741 DataType::Signed32
742 };
743
744 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
745 "Reference Fully Connected: bias type not supported.");
746
747 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
748 "Reference Fully Connected: bias and weight types mismatch.");
749
750 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
751 "Reference Fully Connected: bias type inferred from weights is incompatible.");
752
753 }
754
755 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100756}
757
narpra014951d842019-01-18 16:53:53 +0000758bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
759 const armnn::TensorInfo& input1,
760 const armnn::TensorInfo& output,
761 armnn::Optional<std::string&> reasonIfUnsupported) const
762{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100763 bool supported = true;
764 std::array<DataType,3> supportedTypes =
765 {
766 DataType::Float32,
767 DataType::QuantisedAsymm8,
768 DataType::QuantisedSymm16
769 };
770
771 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
772 "Reference Gather: input type not supported");
773
774 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
775 "Reference Gather: output type not supported");
776
777 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
778 "Reference Gather: indices (input1) type not supported");
779
780 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
781 "Reference Gather: input and output types not matching");
782
783 return supported;
narpra014951d842019-01-18 16:53:53 +0000784}
785
FrancisMurtagh878f0232018-12-19 10:56:15 +0000786bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
787 const TensorInfo& input1,
788 const TensorInfo& output,
789 Optional<std::string&> reasonIfUnsupported) const
790{
791 ignore_unused(input0);
792 ignore_unused(input1);
793 ignore_unused(output);
794 ignore_unused(reasonIfUnsupported);
795 return IsSupportedForDataTypeRef(reasonIfUnsupported,
796 input0.GetDataType(),
797 &TrueFunc<>,
798 &TrueFunc<>);
799}
800
arovir011c7c81b2018-10-08 11:34:28 +0100801bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
802 Optional<std::string&> reasonIfUnsupported) const
803{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100804 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100805}
806
807bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
808 const TensorInfo& output,
809 const L2NormalizationDescriptor& descriptor,
810 Optional<std::string&> reasonIfUnsupported) const
811{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100812 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100813 // Define supported types
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100814 std::array<DataType, 3> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100815 {
816 DataType::Float32,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100817 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100818 DataType::QuantisedSymm16
819 };
820
821 bool supported = true;
822
823 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
824 "Reference L2normalization: input type not supported.");
825
826 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
827 "Reference L2normalization: output type not supported.");
828
829 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
830 "Reference L2normalization: input and output types mismatched.");
831
832 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
833 "Reference L2normalization: input and output shapes have different "
834 "num total elements.");
835
836 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100837}
838
// Checks reference-backend support for the LSTM layer.
// Only the element types of the activation tensors are validated (Float32 or
// QuantisedSymm16, all matching the input); the weight/bias tensors and the
// descriptor are deliberately not inspected here.
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const TensorInfo& inputToForgetWeights,
                                      const TensorInfo& inputToCellWeights,
                                      const TensorInfo& inputToOutputWeights,
                                      const TensorInfo& recurrentToForgetWeights,
                                      const TensorInfo& recurrentToCellWeights,
                                      const TensorInfo& recurrentToOutputWeights,
                                      const TensorInfo& forgetGateBias,
                                      const TensorInfo& cellBias,
                                      const TensorInfo& outputGateBias,
                                      const TensorInfo* inputToInputWeights,
                                      const TensorInfo* recurrentToInputWeights,
                                      const TensorInfo* cellToInputWeights,
                                      const TensorInfo* inputGateBias,
                                      const TensorInfo* projectionWeights,
                                      const TensorInfo* projectionBias,
                                      const TensorInfo* cellToForgetWeights,
                                      const TensorInfo* cellToOutputWeights,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    // Weights, biases and descriptor play no part in this support decision;
    // suppress unused-parameter warnings for all of them.
    ignore_unused(descriptor);
    ignore_unused(inputToForgetWeights);
    ignore_unused(inputToCellWeights);
    ignore_unused(inputToOutputWeights);
    ignore_unused(recurrentToForgetWeights);
    ignore_unused(recurrentToCellWeights);
    ignore_unused(recurrentToOutputWeights);
    ignore_unused(forgetGateBias);
    ignore_unused(cellBias);
    ignore_unused(outputGateBias);
    ignore_unused(inputToInputWeights);
    ignore_unused(recurrentToInputWeights);
    ignore_unused(cellToInputWeights);
    ignore_unused(inputGateBias);
    ignore_unused(projectionWeights);
    ignore_unused(projectionBias);
    ignore_unused(cellToForgetWeights);
    ignore_unused(cellToOutputWeights);

    bool supported = true;

    // Element types the reference LSTM workload handles.
    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");

    // All state/output tensors must share the input's data type.
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");

    return supported;
}
915
saoste012df12b32018-11-28 16:57:20 +0000916bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
917 const TensorInfo& input1,
918 const TensorInfo& output,
919 Optional<std::string&> reasonIfUnsupported) const
920{
Sadik Armagan2999a022019-04-09 14:20:12 +0100921 bool supported = true;
922
923 std::array<DataType,3> supportedTypes = {
924 DataType::Float32,
925 DataType::QuantisedAsymm8,
926 DataType::QuantisedSymm16
927 };
928
929 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
930 "Reference maximum: input 0 is not a supported type.");
931
932 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
933 "Reference maximum: input 1 is not a supported type.");
934
935 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
936 "Reference maximum: output is not a supported type.");
937
938 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
939 "Reference maximum: input 0 and Input 1 types are mismatched");
940
941 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
942 "Reference maximum: input and output types are mismatched");
943
944 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
945 "Reference maximum: shapes are not suitable for implicit broadcast.");
946
947 return supported;
saoste012df12b32018-11-28 16:57:20 +0000948}
949
// Checks reference-backend support for the Mean (reduce-mean) layer.
// Validates the input element type, input/output type agreement, and that the
// output rank is consistent with the descriptor's axis/keep-dims settings.
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    // Strings fed to CreateIncorrectDimensionsErrorMsg for error reporting.
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        // Reduced axes are kept as size-1 dims: output rank equals input rank.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        // No axes given and dims not kept: everything is reduced to rank 1.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        // Each listed axis removes one dimension from the output rank.
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            // All dimensions reduced: the output collapses to rank 1.
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1009
// Merger is the legacy name for Concat; this overload simply forwards to the
// Concat support check with identical arguments.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
1017
// Checks reference-backend support for MemCopy. The copy is type-agnostic in
// shape terms, so only the input's data type is dispatched on; the output is
// not inspected here.
bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(output);
    // Per-type dispatch: the FalseFuncI32 slot rejects Signed32 while the
    // remaining slots accept their type. (NOTE(review): slot-to-type mapping
    // inferred from the FalseFuncI32 name -- confirm against
    // IsSupportedForDataTypeGeneric's declaration.)
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &TrueFunc<>,
                                         &FalseFuncI32<>,
                                         &TrueFunc<>);
}
1031
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001032bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1033 const TensorInfo& input1,
1034 const TensorInfo& output,
1035 Optional<std::string&> reasonIfUnsupported) const
1036{
Sadik Armagan2999a022019-04-09 14:20:12 +01001037 bool supported = true;
1038
1039 std::array<DataType,3> supportedTypes = {
1040 DataType::Float32,
1041 DataType::QuantisedAsymm8,
1042 DataType::QuantisedSymm16
1043 };
1044
1045 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1046 "Reference minimum: input 0 is not a supported type.");
1047
1048 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1049 "Reference minimum: input 1 is not a supported type.");
1050
1051 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1052 "Reference minimum: output is not a supported type.");
1053
1054 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1055 "Reference minimum: input 0 and Input 1 types are mismatched");
1056
1057 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1058 "Reference minimum: input and output types are mismatched");
1059
1060 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1061 "Reference minimum: shapes are not suitable for implicit broadcast.");
1062
1063 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001064}
1065
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001066bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1067 const TensorInfo& input1,
1068 const TensorInfo& output,
1069 Optional<std::string&> reasonIfUnsupported) const
1070{
Sadik Armagan2999a022019-04-09 14:20:12 +01001071 bool supported = true;
1072
1073 std::array<DataType,3> supportedTypes = {
1074 DataType::Float32,
1075 DataType::QuantisedAsymm8,
1076 DataType::QuantisedSymm16
1077 };
1078
1079 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1080 "Reference multiplication: input 0 is not a supported type.");
1081
1082 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1083 "Reference multiplication: input 1 is not a supported type.");
1084
1085 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1086 "Reference multiplication: output is not a supported type.");
1087
1088 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1089 "Reference multiplication: input 0 and Input 1 types are mismatched");
1090
1091 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1092 "Reference multiplication: input and output types are mismatched");
1093
1094 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1095 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1096
1097 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001098}
1099
1100bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1101 const TensorInfo& output,
1102 const NormalizationDescriptor& descriptor,
1103 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001104{
Nina Drozd661dfa72018-10-02 11:14:17 +01001105 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001106
1107 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001108 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001109 {
1110 DataType::Float16,
1111 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001112 DataType::QuantisedAsymm8,
1113 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001114 };
1115
1116 bool supported = true;
1117
1118 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1119 "Reference normalization: input type not supported.");
1120
1121 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1122 "Reference normalization: output type not supported.");
1123
1124 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1125 "Reference normalization: input and output shapes have different "
1126 "num total elements.");
1127
1128 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001129}
1130
1131bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1132 Optional<std::string&> reasonIfUnsupported) const
1133{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001134 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001135}
1136
1137bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1138 const TensorInfo& output,
1139 const PadDescriptor& descriptor,
1140 Optional<std::string&> reasonIfUnsupported) const
1141{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001142 ignore_unused(output);
1143 ignore_unused(descriptor);
jimfly01f6ba7472018-12-04 10:09:52 +00001144 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1145 input.GetDataType(),
1146 &TrueFunc<>,
1147 &TrueFunc<>);
Nina Drozd661dfa72018-10-02 11:14:17 +01001148}
1149
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001150bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1151 const TensorInfo& output,
1152 const PermuteDescriptor& descriptor,
1153 Optional<std::string&> reasonIfUnsupported) const
1154{
1155 ignore_unused(output);
1156 ignore_unused(descriptor);
1157 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1158 input.GetDataType(),
1159 &TrueFunc<>,
1160 &TrueFunc<>);
1161}
1162
1163bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1164 const TensorInfo& output,
1165 const Pooling2dDescriptor& descriptor,
1166 Optional<std::string&> reasonIfUnsupported) const
1167{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001168 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001169 bool supported = true;
1170
1171 // Define supported output and inputs types.
Teresa Charlin0434df62019-06-06 13:40:35 +01001172 std::array<DataType,3> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001173 {
1174 DataType::Float32,
Teresa Charlin0434df62019-06-06 13:40:35 +01001175 DataType::QuantisedAsymm8,
1176 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001177 };
1178
1179 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1180 "Reference poolind2d: input is not a supported type.");
1181
1182 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1183 "Reference poolind2d: output is not a supported type.");
1184
1185 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1186 "Reference poolind2d: input and output types are mismatched.");
1187
1188 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001189}
1190
Derek Lamberti5f400d62019-03-25 15:41:58 +00001191bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1192 const TensorInfo& output,
1193 Optional<std::string&> reasonIfUnsupported) const
1194{
1195 bool supported = true;
1196
1197 // Define supported output types.
1198 std::array<DataType,2> supportedInputTypes = {
1199 DataType::Float32,
1200 };
1201
1202 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1203 "Reference quantize: input type not supported.");
1204
1205 // Define supported output types.
1206 std::array<DataType,2> supportedOutputTypes = {
1207 DataType::QuantisedAsymm8,
1208 DataType::QuantisedSymm16
1209 };
1210 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1211 "Reference quantize: output type not supported.");
1212
1213 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1214 "Reference quantize: input and output shapes have different num total elements.");
1215
1216 return supported;
1217}
1218
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001219bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001220 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001221 Optional<std::string&> reasonIfUnsupported) const
1222{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001223 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001224 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001225 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001226 {
1227 DataType::Float32,
1228 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001229 DataType::QuantisedAsymm8,
1230 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001231 };
1232 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1233 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001234}
1235
1236bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001237 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001238 Optional<std::string&> reasonIfUnsupported) const
1239{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001240 bool supported = true;
1241 std::array<DataType,3> supportedTypes =
1242 {
1243 DataType::Float32,
1244 DataType::QuantisedAsymm8,
1245 DataType::QuantisedSymm16
1246 };
1247
1248 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1249 "Reference ResizeBilinear: input type not supported");
1250
1251 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1252 "Reference ResizeBilinear: output type not supported");
1253
1254 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1255 "Reference ResizeBilinear: input and output types not matching");
1256
1257 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001258}
1259
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001260bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1261 const TensorInfo& output,
1262 Optional<std::string&> reasonIfUnsupported) const
1263{
nikraj010421e7f2019-06-14 09:40:34 +01001264 bool supported = true;
nikraj0124d73212019-06-14 14:20:40 +01001265 std::array<DataType,3> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001266 {
1267 DataType::Float32,
nikraj0124d73212019-06-14 14:20:40 +01001268 DataType::QuantisedAsymm8,
1269 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001270 };
1271
1272 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1273 "Reference rsqrt: input type not supported");
1274
1275 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1276 "Reference rsqrt: output type not supported");
1277
1278 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1279 "Reference rsqrt: input and output types not matching");
1280
1281 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1282 "Reference Rsqrt: input and output shapes have different number of total elements");
1283
1284 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001285}
1286
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001287bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1288 const TensorInfo& output,
1289 const SoftmaxDescriptor& descriptor,
1290 Optional<std::string&> reasonIfUnsupported) const
1291{
1292 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001293 bool supported = true;
1294 std::array<DataType,3> supportedTypes =
1295 {
1296 DataType::Float32,
1297 DataType::QuantisedAsymm8,
1298 DataType::QuantisedSymm16
1299 };
1300
1301 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1302 "Reference concatenation: output type not supported");
1303
1304 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1305 "Reference concatenation: input type not supported");
1306
1307 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1308 "Reference concatenation: input type not supported");
1309
1310 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001311}
1312
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001313bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1314 const TensorInfo& output,
1315 const SpaceToBatchNdDescriptor& descriptor,
1316 Optional<std::string&> reasonIfUnsupported) const
1317{
1318 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001319 bool supported = true;
1320 std::array<DataType,3> supportedTypes =
1321 {
1322 DataType::Float32,
1323 DataType::QuantisedAsymm8,
1324 DataType::QuantisedSymm16
1325 };
1326
1327 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1328 "Reference SpaceToBatchNd: input type not supported");
1329
1330 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1331 "Reference SpaceToBatchNd: output type not supported");
1332
1333 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1334 "Reference SpaceToBatchNd: input and output types are mismatched");
1335
1336 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001337}
1338
Keith Davisa57eccb2019-06-14 17:33:22 +01001339bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001340 const TensorInfo& output,
1341 const SpaceToDepthDescriptor& descriptor,
1342 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001343{
1344
1345 ignore_unused(descriptor);
1346 bool supported = true;
1347
1348 std::array<DataType,2> supportedTypes =
1349 {
1350 DataType::Float32,
1351 DataType::QuantisedAsymm8,
1352 };
1353
1354 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1355 "Reference SpaceToDepth: input type not supported");
1356
1357 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1358 "Reference SpaceToDepth: output type not supported");
1359
1360 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1361 "Reference SpaceToDepth: input and output types are mismatched");
1362
1363 return supported;
1364}
1365
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001366bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1367 const ViewsDescriptor& descriptor,
1368 Optional<std::string&> reasonIfUnsupported) const
1369{
1370 ignore_unused(descriptor);
1371 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1372 input.GetDataType(),
1373 &TrueFunc<>,
1374 &TrueFunc<>);
1375}
1376
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001377bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1378 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1379 const ViewsDescriptor& descriptor,
1380 Optional<std::string&> reasonIfUnsupported) const
1381{
1382 ignore_unused(descriptor);
1383 ignore_unused(outputs);
1384 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1385 input.GetDataType(),
1386 &TrueFunc<>,
1387 &TrueFunc<>);
1388}
1389
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001390bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1391 const TensorInfo& output,
1392 const StridedSliceDescriptor& descriptor,
1393 Optional<std::string&> reasonIfUnsupported) const
1394{
1395 ignore_unused(output);
1396 ignore_unused(descriptor);
1397 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1398 input.GetDataType(),
1399 &TrueFunc<>,
1400 &TrueFunc<>);
1401}
1402
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001403bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1404 const TensorInfo& input1,
1405 const TensorInfo& output,
1406 Optional<std::string&> reasonIfUnsupported) const
1407{
Sadik Armagan2999a022019-04-09 14:20:12 +01001408 bool supported = true;
1409
1410 std::array<DataType,3> supportedTypes = {
1411 DataType::Float32,
1412 DataType::QuantisedAsymm8,
1413 DataType::QuantisedSymm16
1414 };
1415
1416 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1417 "Reference subtraction: input 0 is not a supported type.");
1418
1419 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1420 "Reference subtraction: input 1 is not a supported type.");
1421
1422 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1423 "Reference subtraction: output is not a supported type.");
1424
1425 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1426 "Reference subtraction: input 0 and Input 1 types are mismatched");
1427
1428 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1429 "Reference subtraction: input and output types are mismatched");
1430
1431 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1432 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1433
1434 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001435}
1436
Matteo Martincighab9e5252019-06-13 17:27:46 +01001437bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1438 const TensorInfo& alpha,
1439 const TensorInfo& output,
1440 Optional<std::string&> reasonIfUnsupported) const
1441{
1442 bool supported = true;
1443
1444 std::array<DataType, 3> supportedTypes
1445 {
1446 DataType::Float32,
1447 DataType::QuantisedAsymm8,
1448 DataType::QuantisedSymm16
1449 };
1450
1451 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1452 "PReLU: input is not a supported type.");
1453
1454 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1455 "PReLU: alpha is not a supported type.");
1456
1457 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1458 "PReLU: output is not a supported type.");
1459
1460 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1461 "PReLU: input, alpha and output types are mismatched");
1462
1463 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1464 "PReLU: shapes are not suitable for implicit broadcast");
1465
1466 return supported;
1467}
1468
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001469bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1470 const TensorInfo& output,
1471 const TransposeConvolution2dDescriptor& descriptor,
1472 const TensorInfo& weights,
1473 const Optional<TensorInfo>& biases,
1474 Optional<std::string&> reasonIfUnsupported) const
1475{
1476 ignore_unused(descriptor);
1477
1478 bool supported = true;
1479
1480 std::array<DataType,3> supportedTypes =
1481 {
1482 DataType::Float32,
1483 DataType::QuantisedAsymm8,
1484 DataType::QuantisedSymm16
1485 };
1486
1487 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1488 "Reference TransposeConvolution2d: input is not a supported type.");
1489
1490 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1491 "Reference TransposeConvolution2d: output is not a supported type.");
1492
1493 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1494 "Reference TransposeConvolution2d: weights is not a supported type.");
1495
1496 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1497 "Reference TransposeConvolution2d: input and output types mismatched.");
1498
1499 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1500 "Reference TransposeConvolution2d: input and weights types mismatched.");
1501
1502 if (biases.has_value())
1503 {
1504 std::array<DataType,3> biasesSupportedTypes = {
1505 DataType::Float32,
1506 DataType::Signed32
1507 };
1508 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1509 "Reference TransposeConvolution2d: biases is not a supported type.");
1510 }
1511
1512 return supported;
1513}
1514
arovir011c7c81b2018-10-08 11:34:28 +01001515} // namespace armnn