//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <DataLayoutIndexed.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <backendsCommon/LayerSupportRules.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

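// Helper that dispatches a support query on the tensor DataType. Only the Float32 and
// QuantisedAsymm8 function objects supplied by the caller are forwarded; the remaining slots of
// the generic helper (which, by its parameter order, cover the other data types such as Float16,
// Signed32 and Boolean) are filled with FalseFunc, so those types are reported as unsupported here.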
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

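// Builds the error string used by the dimension-count checks below. For example, with
// expected = 4, actual = 2, layerStr = "batchToSpaceNd" and tensorName = "output", it yields:
// "Reference batchToSpaceNd: Expected 4 dimensions but got 2 dimensions instead, for the 'output' tensor."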
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

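// Each IsXxxSupported query below follows the same pattern: a list of DataTypes that the reference
// backend accepts for the layer, plus a set of rules from LayerSupportRules.hpp combined with
// supported &= CheckSupportRule(...). Because &= does not short-circuit, every rule is evaluated
// even after one fails, and each failing check receives reasonIfUnsupported to record its message.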
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

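    // Ad hoc rule: Rule subclasses store their verdict in m_Res, which CheckSupportRule reads.
    // This one accepts only the activation functions listed below; anything else (including any
    // function added to the enum later) is rejected until it is explicitly handled here.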
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

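// ArgMinMax produces element indices, so the output tensor is validated as Signed32 regardless of
// the input data type.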
bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

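// Shared by the element-wise comparison layers (IsEqualSupported and IsGreaterSupported below
// forward here). The inputs may be any of the supported numeric types, but the output must be
// Boolean.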
bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

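// The two ConvertFpXxToFpYy queries apply the generic per-type dispatch twice: the first call
// accepts only the expected input type (TrueFunc) and rejects everything else, and the second does
// the same for the output type, so the pair only passes for the intended conversion direction.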
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference convolution2d: input and weights types mismatched.");

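    // The bias is optional and is only validated when present. Float networks carry Float32 or
    // Float16 biases, while quantized networks typically carry Signed32 biases; the depthwise and
    // fully connected checks below apply the same pattern.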
    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes = {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

831 const TensorInfo& output,
832 const InstanceNormalizationDescriptor& descriptor,
833 Optional<std::string&> reasonIfUnsupported) const
834{
835 ignore_unused(descriptor);
836 // Define supported types
837 std::array<DataType, 4> supportedTypes =
838 {
839 DataType::Float32,
840 DataType::Float16
841 };
842
843 bool supported = true;
844
845 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
846 "Reference Instance Normalization: input type not supported.");
847
848 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
849 "Reference Instance Normalization: output type not supported.");
850
851 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
852 "Reference Instance Normalization: input and output types mismatched.");
853
854 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
855 "Reference Instance Normalization: input and output shapes have different "
856 "num total elements.");
857
858 return supported;
859}
860
bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

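// LSTM support is checked in two stages: the data tensors (input, state, scratch buffer and
// outputs must all share one of the supported types), followed by the weight and bias tensors in
// paramsInfo. Optional parameter groups (CIFG, peephole, projection, layer normalisation) are only
// validated when the corresponding descriptor flag enables them.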
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

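    // The expected output rank depends on the descriptor: with m_KeepDims the rank is unchanged,
    // with no reduction axes the result collapses to a single dimension, and otherwise one
    // dimension is removed per reduced axis. For example, a 4D input reduced over two axes without
    // keepDims must produce a 2D output.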
1088 if (descriptor.m_KeepDims)
1089 {
1090 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1091 reasonIfUnsupported,
1092 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1093 output.GetNumDimensions(),
1094 meanLayerStr, outputTensorStr).data());
1095 }
1096 else if (descriptor.m_Axis.empty())
1097 {
1098 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1099 reasonIfUnsupported,
1100 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1101 meanLayerStr, outputTensorStr).data());
1102 }
1103 else
1104 {
1105 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1106
1107 if (outputDim > 0)
1108 {
1109 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1110 reasonIfUnsupported,
1111 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1112 meanLayerStr, outputTensorStr).data());
1113 }
1114 else
1115 {
1116 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1117 reasonIfUnsupported,
1118 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1119 meanLayerStr, outputTensorStr).data());
1120 }
1121 }
1122
1123 return supported;
narpra0132b90462018-09-13 11:07:48 +01001124}
1125
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001126bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001127 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001128 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001129 Optional<std::string&> reasonIfUnsupported) const
1130{
Jim Flynne242f2d2019-05-22 14:24:13 +01001131 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001132}
1133
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001134bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1135 const TensorInfo &output,
1136 Optional<std::string &> reasonIfUnsupported) const
1137{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001138 bool supported = true;
1139
1140 std::array<DataType,5> supportedTypes =
1141 {
1142 DataType::Float32,
1143 DataType::Float16,
1144 DataType::QuantisedAsymm8,
1145 DataType::QuantisedSymm16,
1146 DataType::Boolean
1147 };
1148
1149 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1150 "Reference MemCopy: input type not supported");
1151
1152 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1153 "Reference MemCopy: output type not supported");
1154
1155 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1156 "Reference MemCopy: input and output types are mismatched");
1157
1158 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001159}
1160
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001161bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1162 const TensorInfo& input1,
1163 const TensorInfo& output,
1164 Optional<std::string&> reasonIfUnsupported) const
1165{
Sadik Armagan2999a022019-04-09 14:20:12 +01001166 bool supported = true;
1167
Matthew Jackson9bff1442019-09-12 09:08:23 +01001168 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001169 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001170 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001171 DataType::QuantisedAsymm8,
1172 DataType::QuantisedSymm16
1173 };
1174
1175 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1176 "Reference minimum: input 0 is not a supported type.");
1177
1178 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1179 "Reference minimum: input 1 is not a supported type.");
1180
1181 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1182 "Reference minimum: output is not a supported type.");
1183
1184 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1185 "Reference minimum: input 0 and Input 1 types are mismatched");
1186
1187 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1188 "Reference minimum: input and output types are mismatched");
1189
1190 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1191 "Reference minimum: shapes are not suitable for implicit broadcast.");
1192
1193 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001194}
1195
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001196bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1197 const TensorInfo& input1,
1198 const TensorInfo& output,
1199 Optional<std::string&> reasonIfUnsupported) const
1200{
Sadik Armagan2999a022019-04-09 14:20:12 +01001201 bool supported = true;
1202
Matthew Jackson252df3a2019-09-11 09:19:18 +01001203 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001204 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001205 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001206 DataType::QuantisedAsymm8,
1207 DataType::QuantisedSymm16
1208 };
1209
1210 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1211 "Reference multiplication: input 0 is not a supported type.");
1212
1213 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1214 "Reference multiplication: input 1 is not a supported type.");
1215
1216 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1217 "Reference multiplication: output is not a supported type.");
1218
1219 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1220 "Reference multiplication: input 0 and input 1 types are mismatched");
1221
1222 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1223 "Reference multiplication: input and output types are mismatched");
1224
1225 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1226 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1227
1228 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001229}
1230
1231bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1232 const TensorInfo& output,
1233 const NormalizationDescriptor& descriptor,
1234 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001235{
Nina Drozd661dfa72018-10-02 11:14:17 +01001236 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001237
1238 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001239 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001240 {
1241 DataType::Float16,
1242 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001243 DataType::QuantisedAsymm8,
1244 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001245 };
1246
1247 bool supported = true;
1248
1249 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1250 "Reference normalization: input type not supported.");
1251
1252 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1253 "Reference normalization: output type not supported.");
1254
1255 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1256 "Reference normalization: input and output shapes have different "
1257 "num total elements.");
1258
1259 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001260}
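// ShapesAreSameTotalSize (used above) only compares the total element counts of the two
// tensors, so shapes such as { 2, 6 } and { 3, 4 } would still satisfy this particular rule.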
1261
1262bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1263 Optional<std::string&> reasonIfUnsupported) const
1264{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001265 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001266}
1267
1268bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1269 const TensorInfo& output,
1270 const PadDescriptor& descriptor,
1271 Optional<std::string&> reasonIfUnsupported) const
1272{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001273 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001274 bool supported = true;
1275
1276 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001277 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001278 {
1279 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001280 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001281 DataType::QuantisedAsymm8,
1282 DataType::QuantisedSymm16
1283 };
1284
1285 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1286 "Reference pad: input is not a supported type.");
1287
1288 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1289 "Reference pad: output is not a supported type.");
1290
1291 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1292 "Reference pad: input and output types are mismatched.");
1293
1294 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001295}
1296
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001297bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1298 const TensorInfo& output,
1299 const PermuteDescriptor& descriptor,
1300 Optional<std::string&> reasonIfUnsupported) const
1301{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001302 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001303 bool supported = true;
1304
1305 // Define supported output and inputs types.
1306 std::array<DataType,3> supportedTypes =
1307 {
1308 DataType::Float32,
1309 DataType::QuantisedAsymm8,
1310 DataType::QuantisedSymm16
1311 };
1312
1313 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1314 "Reference permute: input is not a supported type.");
1315
1316 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1317 "Reference permute: output is not a supported type.");
1318
1319 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1320 "Reference permute: input and output types are mismatched.");
1321
1322 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001323}
1324
1325bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1326 const TensorInfo& output,
1327 const Pooling2dDescriptor& descriptor,
1328 Optional<std::string&> reasonIfUnsupported) const
1329{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001330 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001331 bool supported = true;
1332
1333 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001334 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001335 {
1336 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001337 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001338 DataType::QuantisedAsymm8,
1339 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001340 };
1341
1342 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1343 "Reference pooling2d: input is not a supported type.");
1344
1345 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1346 "Reference pooling2d: output is not a supported type.");
1347
1348 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1349 "Reference pooling2d: input and output types are mismatched.");
1350
1351 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001352}
1353
Derek Lamberti5f400d62019-03-25 15:41:58 +00001354bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1355 const TensorInfo& output,
1356 Optional<std::string&> reasonIfUnsupported) const
1357{
1358 bool supported = true;
1359
1360 // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001361 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001362 DataType::Float32,
1363 };
1364
1365 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1366 "Reference quantize: input type not supported.");
1367
1368 // Define supported output types.
1369 std::array<DataType,2> supportedOutputTypes = {
1370 DataType::QuantisedAsymm8,
1371 DataType::QuantisedSymm16
1372 };
1373 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1374 "Reference quantize: output type not supported.");
1375
1376 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1377 "Reference quantize: input and output shapes have different num total elements.");
1378
1379 return supported;
1380}
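// Illustrative sketch (not part of this file): a Float32 -> QuantisedAsymm8 query that the
// rules above would accept; the shape, scale and offset values are made up:
//   armnn::TensorInfo in({ 1, 4 }, armnn::DataType::Float32);
//   armnn::TensorInfo out({ 1, 4 }, armnn::DataType::QuantisedAsymm8, 0.25f, 128);
//   std::string reason;
//   bool ok = armnn::RefLayerSupport().IsQuantizeSupported(in, out, armnn::Optional<std::string&>(reason));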
1381
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001382bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001383 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001384 Optional<std::string&> reasonIfUnsupported) const
1385{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001386 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001387 // Define supported input types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001388 std::array<DataType,5> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001389 {
1390 DataType::Float32,
1391 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001392 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001393 DataType::QuantisedAsymm8,
1394 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001395 };
1396 return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1397 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001398}
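// Note: the reshape check above only validates the input type; the output tensor is assumed
// to carry the same data type and total element count, since reshape changes neither.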
1399
1400bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001401 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001402 Optional<std::string&> reasonIfUnsupported) const
1403{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001404 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001405 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001406 {
1407 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001408 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001409 DataType::QuantisedAsymm8,
1410 DataType::QuantisedSymm16
1411 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001412
1413 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1414 "Reference ResizeBilinear: input type not supported");
1415
1416 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1417 "Reference ResizeBilinear: output type not supported");
1418
1419 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1420 "Reference ResizeBilinear: input and output types not matching");
1421
1422 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001423}
1424
Teresa Charlin970f43b2019-07-01 13:51:07 +01001425bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1426 const TensorInfo& output,
1427 const ResizeDescriptor& descriptor,
1428 Optional<std::string&> reasonIfUnsupported) const
1429{
1430 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001431 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001432 {
1433 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001434 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001435 DataType::QuantisedAsymm8,
1436 DataType::QuantisedSymm16
1437 };
1438
1439 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1440 "Reference Resize: input type not supported");
1441
1442 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1443 "Reference Resize: output type not supported");
1444
1445 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1446 "Reference Resize: input and output types not matching");
1447
1448 return supported;
1449}
1450
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001451bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1452 const TensorInfo& output,
1453 Optional<std::string&> reasonIfUnsupported) const
1454{
nikraj010421e7f2019-06-14 09:40:34 +01001455 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001456 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001457 {
1458 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001459 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001460 DataType::QuantisedAsymm8,
1461 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001462 };
1463
1464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1465 "Reference rsqrt: input type not supported");
1466
1467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1468 "Reference rsqrt: output type not supported");
1469
1470 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1471 "Reference rsqrt: input and output types not matching");
1472
1473 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1474 "Reference rsqrt: input and output shapes have different number of total elements");
1475
1476 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001477}
1478
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001479bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1480 const TensorInfo& output,
1481 const SliceDescriptor& descriptor,
1482 Optional<std::string&> reasonIfUnsupported) const
1483{
1484 ignore_unused(descriptor);
1485 bool supported = true;
1486
1487 std::array<DataType, 3> supportedTypes =
1488 {
1489 DataType::Float32,
1490 DataType::QuantisedAsymm8,
1491 DataType::QuantisedSymm16
1492 };
1493
1494 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1495 "Reference Slice: input type not supported");
1496
1497 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1498 "Reference Slice: output type not supported");
1499
1500 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1501 "Reference Slice: input and output types are mismatched");
1502
1503 return supported;
1504}
1505
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001506bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1507 const TensorInfo& output,
1508 const SoftmaxDescriptor& descriptor,
1509 Optional<std::string&> reasonIfUnsupported) const
1510{
1511 ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001512 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001513 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001514 {
1515 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001516 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001517 DataType::QuantisedAsymm8,
1518 DataType::QuantisedSymm16
1519 };
1520
1521 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001522 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001523
1524 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001525 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001526
1527 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001528 "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001529
1530 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001531}
1532
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001533bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1534 const TensorInfo& output,
1535 const SpaceToBatchNdDescriptor& descriptor,
1536 Optional<std::string&> reasonIfUnsupported) const
1537{
1538 ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001539 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001540 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001541 {
1542 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001543 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001544 DataType::QuantisedAsymm8,
1545 DataType::QuantisedSymm16
1546 };
1547
1548 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1549 "Reference SpaceToBatchNd: input type not supported");
1550
1551 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1552 "Reference SpaceToBatchNd: output type not supported");
1553
1554 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1555 "Reference SpaceToBatchNd: input and output types are mismatched");
1556
1557 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001558}
1559
Keith Davisa57eccb2019-06-14 17:33:22 +01001560bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001561 const TensorInfo& output,
1562 const SpaceToDepthDescriptor& descriptor,
1563 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001564{
1565
1566 ignore_unused(descriptor);
1567 bool supported = true;
1568
Matthew Jackson9bff1442019-09-12 09:08:23 +01001569 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001570 {
1571 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001572 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001573 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001574 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001575 };
1576
1577 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1578 "Reference SpaceToDepth: input type not supported");
1579
1580 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1581 "Reference SpaceToDepth: output type not supported");
1582
1583 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1584 "Reference SpaceToDepth: input and output types are mismatched");
1585
1586 return supported;
1587}
1588
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001589bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1590 const ViewsDescriptor& descriptor,
1591 Optional<std::string&> reasonIfUnsupported) const
1592{
1593 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001594 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001595 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001596 {
1597 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001598 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001599 DataType::QuantisedAsymm8,
1600 DataType::QuantisedSymm16
1601 };
1602
1603 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1604 "Reference splitter: input type not supported");
1605
1606 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001607}
1608
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001609bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1610 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1611 const ViewsDescriptor& descriptor,
1612 Optional<std::string&> reasonIfUnsupported) const
1613{
1614 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001615 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001616 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001617 {
1618 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001619 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001620 DataType::QuantisedAsymm8,
1621 DataType::QuantisedSymm16
1622 };
1623
1624 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1625 "Reference splitter: input type not supported");
1626 for (const TensorInfo& output : outputs)
1627 {
1628 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1629 "Reference splitter: output type not supported");
1630
1631 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1632 "Reference splitter: input and output types mismatched.");
1633 }
1634
1635 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001636}
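// Illustrative sketch (not part of this file) of how the multi-output overload above might be
// called; the shapes and descriptor values are made up:
//   armnn::TensorInfo in({ 4, 2 }, armnn::DataType::Float32);
//   armnn::TensorInfo out0({ 2, 2 }, armnn::DataType::Float32);
//   armnn::TensorInfo out1({ 2, 2 }, armnn::DataType::Float32);
//   std::vector<std::reference_wrapper<armnn::TensorInfo>> outs{ out0, out1 };
//   armnn::ViewsDescriptor viewsDesc(2, 2); // two views over a 2-dimensional input
//   bool ok = armnn::RefLayerSupport().IsSplitterSupported(in, outs, viewsDesc, armnn::EmptyOptional());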
1637
Matthew Jackson81e601c2019-07-11 12:07:09 +01001638bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1639 const TensorInfo& output,
1640 const StackDescriptor& descriptor,
1641 Optional<std::string&> reasonIfUnsupported) const
1642{
1643 ignore_unused(descriptor);
1644
1645 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001646 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001647 {
1648 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001649 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001650 DataType::QuantisedAsymm8,
1651 DataType::QuantisedSymm16
1652 };
1653
1654 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1655 "Reference stack: output type not supported");
1656 for (const TensorInfo* input : inputs)
1657 {
1658 BOOST_ASSERT(input != nullptr);
1659 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1660 "Reference stack: input type not supported");
1661
1662 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1663 "Reference stack: input and output types mismatched.");
1664 }
1665
1666 return supported;
1667}
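// Illustrative sketch (not part of this file): stacking two Float32 tensors along axis 0;
// the shapes and the StackDescriptor values are made up:
//   armnn::TensorInfo in0({ 3, 4 }, armnn::DataType::Float32);
//   armnn::TensorInfo in1({ 3, 4 }, armnn::DataType::Float32);
//   armnn::TensorInfo out({ 2, 3, 4 }, armnn::DataType::Float32);
//   std::vector<const armnn::TensorInfo*> ins{ &in0, &in1 };
//   armnn::StackDescriptor stackDesc(0, 2, { 3, 4 }); // axis 0, 2 inputs, input shape { 3, 4 }
//   bool ok = armnn::RefLayerSupport().IsStackSupported(ins, out, stackDesc, armnn::EmptyOptional());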
1668
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001669bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1670 const TensorInfo& output,
1671 const StridedSliceDescriptor& descriptor,
1672 Optional<std::string&> reasonIfUnsupported) const
1673{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001674 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001675 bool supported = true;
1676
1677 std::array<DataType,3> supportedTypes =
1678 {
1679 DataType::Float32,
1680 DataType::QuantisedAsymm8,
1681 DataType::QuantisedSymm16
1682 };
1683
1684 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1685 "Reference StridedSlice: input type not supported");
1686
1687 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1688 "Reference StridedSlice: output type not supported");
1689
1690 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1691 "Reference StridedSlice: input and output types are mismatched");
1692
1693 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001694}
1695
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001696bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1697 const TensorInfo& input1,
1698 const TensorInfo& output,
1699 Optional<std::string&> reasonIfUnsupported) const
1700{
Sadik Armagan2999a022019-04-09 14:20:12 +01001701 bool supported = true;
1702
Matthew Jackson9bff1442019-09-12 09:08:23 +01001703 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001704 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001705 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001706 DataType::QuantisedAsymm8,
1707 DataType::QuantisedSymm16
1708 };
1709
1710 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1711 "Reference subtraction: input 0 is not a supported type.");
1712
1713 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1714 "Reference subtraction: input 1 is not a supported type.");
1715
1716 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1717 "Reference subtraction: output is not a supported type.");
1718
1719 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1720 "Reference subtraction: input 0 and input 1 types are mismatched");
1721
1722 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1723 "Reference subtraction: input and output types are mismatched");
1724
1725 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1726 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1727
1728 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001729}
1730
Matteo Martincighab9e5252019-06-13 17:27:46 +01001731bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1732 const TensorInfo& alpha,
1733 const TensorInfo& output,
1734 Optional<std::string&> reasonIfUnsupported) const
1735{
1736 bool supported = true;
1737
Matthew Jackson9bff1442019-09-12 09:08:23 +01001738 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001739 {
1740 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001741 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001742 DataType::QuantisedAsymm8,
1743 DataType::QuantisedSymm16
1744 };
1745
1746 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1747 "Reference PReLU: input is not a supported type.");
1748
1749 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1750 "Reference PReLU: alpha is not a supported type.");
1751
1752 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1753 "Reference PReLU: output is not a supported type.");
1754
1755 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1756 "Reference PReLU: input, alpha and output types are mismatched");
1757
1758 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1759 "Reference PReLU: shapes are not suitable for implicit broadcast");
1760
1761 return supported;
1762}
1763
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001764bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1765 const TensorInfo& output,
1766 const TransposeConvolution2dDescriptor& descriptor,
1767 const TensorInfo& weights,
1768 const Optional<TensorInfo>& biases,
1769 Optional<std::string&> reasonIfUnsupported) const
1770{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001771 bool supported = true;
1772
Matthew Jackson252df3a2019-09-11 09:19:18 +01001773 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001774 {
1775 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001776 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001777 DataType::QuantisedAsymm8,
1778 DataType::QuantisedSymm16
1779 };
1780
1781 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1782 "Reference TransposeConvolution2d: input is not a supported type.");
1783
1784 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1785 "Reference TransposeConvolution2d: output is not a supported type.");
1786
1787 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1788 "Reference TransposeConvolution2d: weights is not a supported type.");
1789
1790 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1791 "Reference TransposeConvolution2d: input and output types mismatched.");
1792
1793 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1794 "Reference TransposeConvolution2d: input and weights types mismatched.");
1795
1796 if (biases.has_value())
1797 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001798 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001799 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001800 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001801 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001802 DataType::Signed32
1803 };
1804 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1805 "Reference TransposeConvolution2d: biases is not a supported type.");
1806 }
1807
1808 return supported;
1809}
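// Note: when a bias tensor is supplied to the transpose convolution check above it must be
// Float32, Float16 or Signed32; Signed32 is the bias type typically paired with the quantised
// (QuantisedAsymm8 / QuantisedSymm16) input paths.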
1810
arovir011c7c81b2018-10-08 11:34:28 +01001811} // namespace armnn