//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <DataLayoutIndexed.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <backendsCommon/LayerSupportRules.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

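// Every Is<Layer>Supported query below follows the same pattern: list the data
// types the reference backend implements for that layer, then AND together a
// series of CheckSupportRule(...) checks (TypeAnyOf, TypesAreEqual,
// ShapesAreSameTotalSize, ...) from backendsCommon/LayerSupportRules.hpp.
// Because the results are combined with &= rather than short-circuited, every
// rule is evaluated and each failing rule records its message in
// reasonIfUnsupported, so the caller sees all of the reasons a layer was
// rejected rather than just the first one.
//
// Hypothetical caller sketch (variable names are illustrative only):
//
//     RefLayerSupport layerSupport;
//     std::string reason;
//     bool ok = layerSupport.IsAbsSupported(inputInfo, outputInfo,
//                                           Optional<std::string&>(reason));
//     // if (!ok) 'reason' describes every failed rule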
bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

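// The two FP16 <-> FP32 conversion queries below predate the CheckSupportRule
// helpers and instead use the IsSupportedForDataTypeGeneric dispatch from
// LayerSupportCommon.hpp: the first call constrains the input tensor's data
// type and the second constrains the output's, with the TrueFunc / False*Func
// template helpers returning true or false for the type slot that is selected.
// (This description of the helper semantics is a reading of
// LayerSupportCommon.hpp, not something stated in this file.)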
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

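// Summary of the weight/bias rules enforced by IsConvolution2dSupported below
// (restating the checks in the code): float inputs require weights of the same
// data type as the input; QuantisedAsymm8 inputs accept either QuantisedAsymm8
// or per-channel QuantizedSymm8PerAxis weights; when a bias is present it must
// be Float32, Float16 or Signed32.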
bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Define supported types for bias.
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

1084 const TensorInfo& output,
1085 const MeanDescriptor& descriptor,
1086 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +01001087{
James Conroy4d1ff582019-06-10 17:06:39 +01001088 bool supported = true;
1089 std::string meanLayerStr = "Mean";
1090 std::string outputTensorStr = "output";
1091
Matthew Jackson252df3a2019-09-11 09:19:18 +01001092 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +01001093 {
1094 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001095 DataType::Float16,
James Conroyb80775f2019-06-11 11:25:30 +01001096 DataType::QuantisedAsymm8,
1097 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +01001098 };
1099
1100 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1101 "Reference Mean: input type not supported.");
1102
1103 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1104 "Reference Mean: input and output types are mismatched");
1105
1106 if (descriptor.m_KeepDims)
1107 {
1108 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
1109 reasonIfUnsupported,
1110 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
1111 output.GetNumDimensions(),
1112 meanLayerStr, outputTensorStr).data());
1113 }
1114 else if (descriptor.m_Axis.empty())
1115 {
1116 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1117 reasonIfUnsupported,
1118 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1119 meanLayerStr, outputTensorStr).data());
1120 }
1121 else
1122 {
1123 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1124
1125 if (outputDim > 0)
1126 {
1127 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1128 reasonIfUnsupported,
1129 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1130 meanLayerStr, outputTensorStr).data());
1131 }
1132 else
1133 {
1134 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1135 reasonIfUnsupported,
1136 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1137 meanLayerStr, outputTensorStr).data());
1138 }
1139 }
1140
1141 return supported;
narpra0132b90462018-09-13 11:07:48 +01001142}
1143
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001144bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001145 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001146 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001147 Optional<std::string&> reasonIfUnsupported) const
1148{
Jim Flynne242f2d2019-05-22 14:24:13 +01001149 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001150}
1151
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001152bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1153 const TensorInfo &output,
1154 Optional<std::string &> reasonIfUnsupported) const
1155{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001156 bool supported = true;
1157
1158 std::array<DataType,5> supportedTypes =
1159 {
1160 DataType::Float32,
1161 DataType::Float16,
1162 DataType::QuantisedAsymm8,
1163 DataType::QuantisedSymm16,
1164 DataType::Boolean
1165 };
1166
1167 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1168 "Reference MemCopy: input type not supported");
1169
1170 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1171 "Reference MemCopy: output type not supported");
1172
1173 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1174 "Reference MemCopy: input and output types are mismatched");
1175
1176 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001177}
1178
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001179bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1180 const TensorInfo& input1,
1181 const TensorInfo& output,
1182 Optional<std::string&> reasonIfUnsupported) const
1183{
Sadik Armagan2999a022019-04-09 14:20:12 +01001184 bool supported = true;
1185
Matthew Jackson9bff1442019-09-12 09:08:23 +01001186 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001187 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001188 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001189 DataType::QuantisedAsymm8,
1190 DataType::QuantisedSymm16
1191 };
1192
1193 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1194 "Reference minimum: input 0 is not a supported type.");
1195
1196 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1197 "Reference minimum: input 1 is not a supported type.");
1198
1199 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1200 "Reference minimum: output is not a supported type.");
1201
1202 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1203 "Reference minimum: input 0 and Input 1 types are mismatched");
1204
1205 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1206 "Reference minimum: input and output types are mismatched");
1207
1208 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1209 "Reference minimum: shapes are not suitable for implicit broadcast.");
1210
1211 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001212}
1213
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001214bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1215 const TensorInfo& input1,
1216 const TensorInfo& output,
1217 Optional<std::string&> reasonIfUnsupported) const
1218{
Sadik Armagan2999a022019-04-09 14:20:12 +01001219 bool supported = true;
1220
Matthew Jackson252df3a2019-09-11 09:19:18 +01001221 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001222 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001223 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001224 DataType::QuantisedAsymm8,
1225 DataType::QuantisedSymm16
1226 };
1227
1228 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1229 "Reference multiplication: input 0 is not a supported type.");
1230
1231 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1232 "Reference multiplication: input 1 is not a supported type.");
1233
1234 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1235 "Reference multiplication: output is not a supported type.");
1236
1237 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1238                                  "Reference multiplication: input 0 and input 1 types are mismatched");
1239
1240 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1241 "Reference multiplication: input and output types are mismatched");
1242
1243 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1244 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1245
1246 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001247}
1248
1249bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1250 const TensorInfo& output,
1251 const NormalizationDescriptor& descriptor,
1252 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001253{
Nina Drozd661dfa72018-10-02 11:14:17 +01001254 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001255
1256 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001257 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001258 {
1259 DataType::Float16,
1260 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001261 DataType::QuantisedAsymm8,
1262 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001263 };
1264
1265 bool supported = true;
1266
1267 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1268 "Reference normalization: input type not supported.");
1269
1270 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1271 "Reference normalization: output type not supported.");
1272
1273 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1274 "Reference normalization: input and output shapes have different "
1275 "num total elements.");
1276
1277 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001278}
1279
1280bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1281 Optional<std::string&> reasonIfUnsupported) const
1282{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001283 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001284}
1285
1286bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1287 const TensorInfo& output,
1288 const PadDescriptor& descriptor,
1289 Optional<std::string&> reasonIfUnsupported) const
1290{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001291 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001292 bool supported = true;
1293
1294    // Define supported input and output types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001295 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001296 {
1297 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001298 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001299 DataType::QuantisedAsymm8,
1300 DataType::QuantisedSymm16
1301 };
1302
1303 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1304 "Reference pad: input is not a supported type.");
1305
1306 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1307 "Reference pad: output is not a supported type.");
1308
1309 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1310 "Reference pad: input and output types are mismatched.");
1311
1312 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001313}
1314
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001315bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1316 const TensorInfo& output,
1317 const PermuteDescriptor& descriptor,
1318 Optional<std::string&> reasonIfUnsupported) const
1319{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001320 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001321 bool supported = true;
1322
1323    // Define supported input and output types.
1324 std::array<DataType,3> supportedTypes =
1325 {
1326 DataType::Float32,
1327 DataType::QuantisedAsymm8,
1328 DataType::QuantisedSymm16
1329 };
1330
1331 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1332 "Reference permute: input is not a supported type.");
1333
1334 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1335 "Reference permute: output is not a supported type.");
1336
1337 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1338 "Reference permute: input and output types are mismatched.");
1339
1340 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001341}
1342
1343bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1344 const TensorInfo& output,
1345 const Pooling2dDescriptor& descriptor,
1346 Optional<std::string&> reasonIfUnsupported) const
1347{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001348 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001349 bool supported = true;
1350
1351    // Define supported input and output types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001352 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001353 {
1354 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001355 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001356 DataType::QuantisedAsymm8,
1357 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001358 };
1359
1360 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1361                                  "Reference pooling2d: input is not a supported type.");
1362
1363 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1364                                  "Reference pooling2d: output is not a supported type.");
1365
1366 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1367                                  "Reference pooling2d: input and output types are mismatched.");
1368
1369 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001370}
1371
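// Quantize converts a Float32 tensor into a quantized representation, so the input and
// output are validated against different type sets rather than being required to match.
// A minimal caller-side sketch (hypothetical shapes and quantization parameters):
//     TensorInfo src(TensorShape({1, 4}), DataType::Float32);
//     TensorInfo dst(TensorShape({1, 4}), DataType::QuantisedAsymm8, 1.0f, 0);
//     std::string reason;
//     bool ok = RefLayerSupport().IsQuantizeSupported(src, dst, Optional<std::string&>(reason));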
Derek Lamberti5f400d62019-03-25 15:41:58 +00001372bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1373 const TensorInfo& output,
1374 Optional<std::string&> reasonIfUnsupported) const
1375{
1376 bool supported = true;
1377
1378    // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001379 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001380 DataType::Float32,
1381 };
1382
1383 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1384 "Reference quantize: input type not supported.");
1385
1386 // Define supported output types.
1387 std::array<DataType,2> supportedOutputTypes = {
1388 DataType::QuantisedAsymm8,
1389 DataType::QuantisedSymm16
1390 };
1391 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1392 "Reference quantize: output type not supported.");
1393
1394 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1395 "Reference quantize: input and output shapes have different num total elements.");
1396
1397 return supported;
1398}
1399
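// Reshape only reinterprets the tensor's dimensions, so only the input type is validated;
// this overload has no output TensorInfo parameter.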
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001400bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001401 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001402 Optional<std::string&> reasonIfUnsupported) const
1403{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001404 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001405    // Define supported types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001406 std::array<DataType,5> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001407 {
1408 DataType::Float32,
1409 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001410 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001411 DataType::QuantisedAsymm8,
1412 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001413 };
1414 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1415 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001416}
1417
1418bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001419 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001420 Optional<std::string&> reasonIfUnsupported) const
1421{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001422 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001423 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001424 {
1425 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001426 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001427 DataType::QuantisedAsymm8,
1428 DataType::QuantisedSymm16
1429 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001430
1431 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1432 "Reference ResizeBilinear: input type not supported");
1433
1434 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1435 "Reference ResizeBilinear: output type not supported");
1436
1437 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1438 "Reference ResizeBilinear: input and output types not matching");
1439
1440 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001441}
1442
Teresa Charlin970f43b2019-07-01 13:51:07 +01001443bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1444 const TensorInfo& output,
1445 const ResizeDescriptor& descriptor,
1446 Optional<std::string&> reasonIfUnsupported) const
1447{
1448 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001449 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001450 {
1451 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001452 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001453 DataType::QuantisedAsymm8,
1454 DataType::QuantisedSymm16
1455 };
1456
1457 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1458 "Reference Resize: input type not supported");
1459
1460 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1461 "Reference Resize: output type not supported");
1462
1463 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1464 "Reference Resize: input and output types not matching");
1465
1466 return supported;
1467}
1468
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001469bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1470 const TensorInfo& output,
1471 Optional<std::string&> reasonIfUnsupported) const
1472{
nikraj010421e7f2019-06-14 09:40:34 +01001473 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001474 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001475 {
1476 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001477 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001478 DataType::QuantisedAsymm8,
1479 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001480 };
1481
1482 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1483 "Reference rsqrt: input type not supported");
1484
1485 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1486 "Reference rsqrt: output type not supported");
1487
1488 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1489 "Reference rsqrt: input and output types not matching");
1490
1491 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1492                                  "Reference rsqrt: input and output shapes have different number of total elements");
1493
1494 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001495}
1496
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001497bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1498 const TensorInfo& output,
1499 const SliceDescriptor& descriptor,
1500 Optional<std::string&> reasonIfUnsupported) const
1501{
1502 ignore_unused(descriptor);
1503 bool supported = true;
1504
1505 std::array<DataType, 3> supportedTypes =
1506 {
1507 DataType::Float32,
1508 DataType::QuantisedAsymm8,
1509 DataType::QuantisedSymm16
1510 };
1511
1512 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1513 "Reference Slice: input type not supported");
1514
1515 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1516 "Reference Slice: output type not supported");
1517
1518 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1519 "Reference Slice: input and output types are mismatched");
1520
1521 return supported;
1522}
1523
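// A minimal sketch of querying Softmax support (hypothetical tensor shape, beta left at its
// default value in the descriptor):
//     TensorInfo info(TensorShape({1, 10}), DataType::Float32);
//     SoftmaxDescriptor softmaxDesc;
//     std::string reason;
//     bool ok = RefLayerSupport().IsSoftmaxSupported(info, info, softmaxDesc, Optional<std::string&>(reason));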
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001524bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1525 const TensorInfo& output,
1526 const SoftmaxDescriptor& descriptor,
1527 Optional<std::string&> reasonIfUnsupported) const
1528{
1529 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001530 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001531 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001532 {
1533 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001534 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001535 DataType::QuantisedAsymm8,
1536 DataType::QuantisedSymm16
1537 };
1538
1539 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001540                                  "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001541
1542 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001543                                  "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001544
1545 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001546                                  "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001547
1548 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001549}
1550
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001551bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1552 const TensorInfo& output,
1553 const SpaceToBatchNdDescriptor& descriptor,
1554 Optional<std::string&> reasonIfUnsupported) const
1555{
1556 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001557 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001558 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001559 {
1560 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001561 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001562 DataType::QuantisedAsymm8,
1563 DataType::QuantisedSymm16
1564 };
1565
1566 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1567 "Reference SpaceToBatchNd: input type not supported");
1568
1569 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1570 "Reference SpaceToBatchNd: output type not supported");
1571
1572 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1573 "Reference SpaceToBatchNd: input and output types are mismatched");
1574
1575 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001576}
1577
Keith Davisa57eccb2019-06-14 17:33:22 +01001578bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001579 const TensorInfo& output,
1580 const SpaceToDepthDescriptor& descriptor,
1581 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001582{
1583
1584 ignore_unused(descriptor);
1585 bool supported = true;
1586
Matthew Jackson9bff1442019-09-12 09:08:23 +01001587 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001588 {
1589 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001590 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001591 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001592 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001593 };
1594
1595 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1596 "Reference SpaceToDepth: input type not supported");
1597
1598 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1599 "Reference SpaceToDepth: output type not supported");
1600
1601 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1602 "Reference SpaceToDepth: input and output types are mismatched");
1603
1604 return supported;
1605}
1606
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001607bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1608 const ViewsDescriptor& descriptor,
1609 Optional<std::string&> reasonIfUnsupported) const
1610{
1611 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001612 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001613 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001614 {
1615 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001616 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001617 DataType::QuantisedAsymm8,
1618 DataType::QuantisedSymm16
1619 };
1620
1621 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1622 "Reference splitter: input type not supported");
1623
1624 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001625}
1626
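// This overload validates every output view produced by the split: each one must have a
// supported type and must match the input's type.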
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001627bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1628 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1629 const ViewsDescriptor& descriptor,
1630 Optional<std::string&> reasonIfUnsupported) const
1631{
1632 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001633 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001634 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001635 {
1636 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001637 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001638 DataType::QuantisedAsymm8,
1639 DataType::QuantisedSymm16
1640 };
1641
1642 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1643                                  "Reference splitter: input type not supported");
1644    for (const TensorInfo& output : outputs)
1645    {
1646        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1647                                      "Reference splitter: output type not supported");
1648
1649 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1650 "Reference splitter: input and output types mismatched.");
1651 }
1652
1653 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001654}
1655
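// Stack joins several inputs along a new axis; every input pointer must be non-null and
// every input must share the output's data type.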
Matthew Jackson81e601c2019-07-11 12:07:09 +01001656bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1657 const TensorInfo& output,
1658 const StackDescriptor& descriptor,
1659 Optional<std::string&> reasonIfUnsupported) const
1660{
1661 ignore_unused(descriptor);
1662
1663 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001664 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001665 {
1666 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001667 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001668 DataType::QuantisedAsymm8,
1669 DataType::QuantisedSymm16
1670 };
1671
1672 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1673 "Reference stack: output type not supported");
1674 for (const TensorInfo* input : inputs)
1675 {
1676 BOOST_ASSERT(input != nullptr);
1677 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1678 "Reference stack: input type not supported");
1679
1680 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1681 "Reference stack: input and output types mismatched.");
1682 }
1683
1684 return supported;
1685}
1686
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001687bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1688 const TensorInfo& output,
1689 const StridedSliceDescriptor& descriptor,
1690 Optional<std::string&> reasonIfUnsupported) const
1691{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001692 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001693 bool supported = true;
1694
1695 std::array<DataType,3> supportedTypes =
1696 {
1697 DataType::Float32,
1698 DataType::QuantisedAsymm8,
1699 DataType::QuantisedSymm16
1700 };
1701
1702 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1703 "Reference StridedSlice: input type not supported");
1704
1705 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1706 "Reference StridedSlice: output type not supported");
1707
1708 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1709 "Reference StridedSlice: input and output types are mismatched");
1710
1711 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001712}
1713
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001714bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1715 const TensorInfo& input1,
1716 const TensorInfo& output,
1717 Optional<std::string&> reasonIfUnsupported) const
1718{
Sadik Armagan2999a022019-04-09 14:20:12 +01001719 bool supported = true;
1720
Matthew Jackson9bff1442019-09-12 09:08:23 +01001721 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001722 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001723 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001724 DataType::QuantisedAsymm8,
1725 DataType::QuantisedSymm16
1726 };
1727
1728 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1729 "Reference subtraction: input 0 is not a supported type.");
1730
1731 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1732 "Reference subtraction: input 1 is not a supported type.");
1733
1734 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1735 "Reference subtraction: output is not a supported type.");
1736
1737 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1738                                  "Reference subtraction: input 0 and input 1 types are mismatched");
1739
1740 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1741 "Reference subtraction: input and output types are mismatched");
1742
1743 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1744 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1745
1746 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001747}
1748
Matteo Martincighab9e5252019-06-13 17:27:46 +01001749bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1750 const TensorInfo& alpha,
1751 const TensorInfo& output,
1752 Optional<std::string&> reasonIfUnsupported) const
1753{
1754 bool supported = true;
1755
Matthew Jackson9bff1442019-09-12 09:08:23 +01001756 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001757 {
1758 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001759 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001760 DataType::QuantisedAsymm8,
1761 DataType::QuantisedSymm16
1762 };
1763
1764 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1765 "PReLU: input is not a supported type.");
1766
1767 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1768 "PReLU: alpha is not a supported type.");
1769
1770 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1771 "PReLU: output is not a supported type.");
1772
1773 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1774 "PReLU: input, alpha and output types are mismatched");
1775
1776 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1777 "PReLU: shapes are not suitable for implicit broadcast");
1778
1779 return supported;
1780}
1781
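// For QuantisedAsymm8 inputs the weights may be either QuantisedAsymm8 or per-axis
// quantized (QuantizedSymm8PerAxis); for all other input types the weights must match
// the input's data type.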
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001782bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1783 const TensorInfo& output,
1784 const TransposeConvolution2dDescriptor& descriptor,
1785 const TensorInfo& weights,
1786 const Optional<TensorInfo>& biases,
1787 Optional<std::string&> reasonIfUnsupported) const
1788{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001789 bool supported = true;
1790
Matthew Jackson252df3a2019-09-11 09:19:18 +01001791 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001792 {
1793 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001794 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001795 DataType::QuantisedAsymm8,
1796 DataType::QuantisedSymm16
1797 };
1798
1799 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1800 "Reference TransposeConvolution2d: input is not a supported type.");
1801
1802 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1803 "Reference TransposeConvolution2d: output is not a supported type.");
1804
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001805 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1806 "Reference TransposeConvolution2d: input and output types mismatched.");
1807
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001808
1809 const DataType inputType = input.GetDataType();
1810 if (inputType == DataType::QuantisedAsymm8)
1811 {
1812 std::array<DataType, 2> supportedWeightTypes =
1813 {
1814 DataType::QuantisedAsymm8,
1815 DataType::QuantizedSymm8PerAxis
1816 };
1817
1818 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1819 "Reference TransposeConvolution2d: weights type not supported for "
1820 "quantized input.");
1821 }
1822 else
1823 {
1824 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1825 "Reference TransposeConvolution2d: weights is not a supported type.");
1826
1827 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1828 "Reference TransposeConvolution2d: input and weights types mismatched.");
1829 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001830
1831 if (biases.has_value())
1832 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001833 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001834 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001835 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001836 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001837 DataType::Signed32
1838 };
1839 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1840 "Reference TransposeConvolution2d: biases is not a supported type.");
1841 }
1842
1843 return supported;
1844}
1845
arovir011c7c81b2018-10-08 11:34:28 +01001846} // namespace armnn