//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <armnnUtils/DataLayoutIndexed.hpp>

#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <backendsCommon/LayerSupportRules.hpp>

#include <boost/cast.hpp>
#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
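
// Note: the five function arguments forwarded to IsSupportedForDataTypeGeneric above appear to
// correspond, in order, to the Float16, Float32, QuantisedAsymm8, Signed32 and Boolean overloads
// (see the IsConvertFp16ToFp32Supported / IsConvertFp32ToFp16Supported checks further down, which
// pass FalseInputFuncF16/F32, FalseFuncU8 and FalseFuncI32 in the same positions). This reading is
// an assumption based on this file alone; LayerSupportCommon.hpp is the authoritative definition.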

namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
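
// For illustration only: with expected = 4, actual = 2, layerStr = "Mean" and tensorName = "output",
// the helper above produces
//     "Reference Mean: Expected 4 dimensions but got 2 dimensions instead, for the 'output' tensor."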

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}
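
// A minimal usage sketch (hypothetical shapes and values, for illustration only):
//     RefLayerSupport layerSupport;
//     TensorInfo info(TensorShape({ 1, 2, 2, 3 }), DataType::Float32);
//     std::string reason;
//     bool isSupported = layerSupport.IsAbsSupported(info, info, Optional<std::string&>(reason));
// When a rule fails, CheckSupportRule reports the corresponding message through 'reason'.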

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };
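
    // Assumption for readability: m_Res is the result member inherited from the Rule base class
    // declared in backendsCommon/LayerSupportRules.hpp, so this local rule simply records whether
    // the requested activation function is one the reference workload implements.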

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}
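
// For illustration: ShapesAreBroadcastCompatible is assumed to follow the usual elementwise
// broadcasting convention, e.g. adding a {1, 2, 2, 3} tensor to a {1, 1, 1, 3} tensor would be
// accepted with a {1, 2, 2, 3} output, whereas {1, 2, 2, 3} against {1, 2, 2, 2} would be rejected.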

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}
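
// ArgMinMax produces element indices rather than values, which is why the output is required to be
// Signed32 regardless of the input data type.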

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}
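
// For reference: batch normalization conceptually computes
//     output = gamma * (input - mean) / sqrt(variance + epsilon) + beta
// per channel, with epsilon taken from BatchNormalizationDescriptor::m_Eps. Only the tensor types,
// not the parameter shapes, are validated here.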

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
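
// The two generic dispatches above encode "input must be Float16 and output must be Float32": the
// first call only returns true for a Float16 input, the second only for a Float32 output.
// IsConvertFp32ToFp16Supported below applies the same pattern with the roles reversed.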

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}
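
// For a QuantisedAsymm8 input, the weights may be either QuantisedAsymm8 or per-axis quantized
// (QuantizedSymm8PerAxis); in practice quantized convolutions pair these with Signed32 biases,
// although only the type lists above are enforced here.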

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float16,
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QSymmS8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
                                                      const TensorInfo& scores,
                                                      const TensorInfo& anchors,
                                                      const TensorInfo& detectionBoxes,
                                                      const TensorInfo& detectionClasses,
                                                      const TensorInfo& detectionScores,
                                                      const TensorInfo& numDetections,
                                                      const DetectionPostProcessDescriptor& descriptor,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(anchors, detectionBoxes, detectionClasses, detectionScores, numDetections, descriptor);

    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(boxEncodings, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(scores, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}
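
// Reading of the bias rules above (see backendsCommon/LayerSupportRules.hpp for the definitions):
// BiasAndWeightsTypesMatch is assumed to compare the given bias type against the bias type implied
// by the weights (typically Signed32 for quantized weights), while BiasAndWeightsTypesCompatible
// checks that this implied bias type is itself one of the supported bias types.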

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& /*input*/,
                                       Optional<std::string&> /*reasonIfUnsupported*/) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}
arovir011c7c81b2018-10-08 11:34:28 +0100964bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
965 const TensorInfo& outputStateIn,
966 const TensorInfo& cellStateIn,
967 const TensorInfo& scratchBuffer,
968 const TensorInfo& outputStateOut,
969 const TensorInfo& cellStateOut,
970 const TensorInfo& output,
971 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100972 const LstmInputParamsInfo& paramsInfo,
973 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +0100974{
telsoa01c577f2c2018-08-31 09:22:23 +0100975 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100976 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100977
978 bool supported = true;
979
980 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100981 DataType::Float32,
982 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100983 };
984
Jan Eilersd01a83c2019-07-03 18:20:40 +0100985 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100986 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
987 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100988 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
989 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100990 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
991 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100992 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
993 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100994 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
995 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100996 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
997 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100998 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
999 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +01001000 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +01001001 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001002 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001003 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001004 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001005 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001006 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001007 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001008 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001009 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001010 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001011 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001012 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001013 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001014 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001015 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001016 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001017 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001018 "Reference Lstm: input and OutputGateBias types are mismatched");
1019 if (!descriptor.m_CifgEnabled)
1020 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001021 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001022 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001023 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001024 reasonIfUnsupported,
1025 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001026 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001027 "Reference Lstm: input and InputGateBias types are mismatched");
1028 if (descriptor.m_PeepholeEnabled)
1029 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001030 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001031 reasonIfUnsupported,
1032 "Reference Lstm: input and CellToInputWeights types are mismatched");
1033 }
1034 }
1035 if (descriptor.m_PeepholeEnabled)
1036 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001037 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001038 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001039 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001040 "Reference Lstm: input and CellToOutputWeights types are mismatched");
1041 }
1042 if (descriptor.m_ProjectionEnabled)
1043 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001044 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001045 "Reference Lstm: input and mProjectionWeights types are mismatched");
1046 if (paramsInfo.m_ProjectionBias != nullptr)
1047 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001048 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +01001049 "Reference Lstm: input and ProjectionBias types are mismatched");
1050 }
1051 }
1052 if (descriptor.m_LayerNormEnabled)
1053 {
1054 if (!descriptor.m_CifgEnabled)
1055 {
Francis Murtaghbb590b42019-08-14 09:51:36 +01001056 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001057 reasonIfUnsupported,
1058 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1059 }
Francis Murtaghbb590b42019-08-14 09:51:36 +01001060 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001061 reasonIfUnsupported,
1062 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001063 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001064 reasonIfUnsupported,
1065 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +01001066 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +01001067 reasonIfUnsupported,
1068 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1069 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001070
1071 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +01001072}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
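
// Expected output rank for Mean, as encoded above: with m_KeepDims the output keeps the input's rank;
// otherwise the rank drops by the number of reduced axes (e.g. reducing a 4-D input over two axes
// should give a 2-D output), and reducing over all axes (or an empty axis list) yields a 1-D output.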

bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}

bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
                                         const TensorInfo &output,
                                         Optional<std::string &> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,5> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16,
        DataType::Boolean
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference MemCopy: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference MemCopy: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference MemCopy: input and output types are mismatched");

    return supported;
}
1204
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001205bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1206 const TensorInfo& input1,
1207 const TensorInfo& output,
1208 Optional<std::string&> reasonIfUnsupported) const
1209{
Sadik Armagan2999a022019-04-09 14:20:12 +01001210 bool supported = true;
1211
Matthew Jackson9bff1442019-09-12 09:08:23 +01001212 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001213 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001214 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001215 DataType::QuantisedAsymm8,
1216 DataType::QuantisedSymm16
1217 };
1218
1219 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1220 "Reference minimum: input 0 is not a supported type.");
1221
1222 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1223 "Reference minimum: input 1 is not a supported type.");
1224
1225 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1226 "Reference minimum: output is not a supported type.");
1227
1228 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1229                                  "Reference minimum: input 0 and input 1 types are mismatched");
1230
1231 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1232 "Reference minimum: input and output types are mismatched");
1233
1234 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1235 "Reference minimum: shapes are not suitable for implicit broadcast.");
1236
1237 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001238}
1239
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001240bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1241 const TensorInfo& input1,
1242 const TensorInfo& output,
1243 Optional<std::string&> reasonIfUnsupported) const
1244{
Sadik Armagan2999a022019-04-09 14:20:12 +01001245 bool supported = true;
1246
Matthew Jackson252df3a2019-09-11 09:19:18 +01001247 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001248 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001249 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001250 DataType::QuantisedAsymm8,
1251 DataType::QuantisedSymm16
1252 };
1253
1254 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1255 "Reference multiplication: input 0 is not a supported type.");
1256
1257 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1258 "Reference multiplication: input 1 is not a supported type.");
1259
1260 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1261 "Reference multiplication: output is not a supported type.");
1262
1263 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1264                                  "Reference multiplication: input 0 and input 1 types are mismatched");
1265
1266 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1267 "Reference multiplication: input and output types are mismatched");
1268
1269 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1270 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1271
1272 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001273}
1274
1275bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1276 const TensorInfo& output,
1277 const NormalizationDescriptor& descriptor,
1278 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001279{
Nina Drozd661dfa72018-10-02 11:14:17 +01001280 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001281
1282 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001283 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001284 {
1285 DataType::Float16,
1286 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001287 DataType::QuantisedAsymm8,
1288 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001289 };
1290
1291 bool supported = true;
1292
1293 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1294 "Reference normalization: input type not supported.");
1295
1296 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1297 "Reference normalization: output type not supported.");
1298
1299 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1300 "Reference normalization: input and output shapes have different "
1301 "num total elements.");
1302
1303 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001304}
1305
Derek Lamberti901ea112019-12-10 22:07:09 +00001306bool RefLayerSupport::IsOutputSupported(const TensorInfo& /*output*/,
1307 Optional<std::string&> /*reasonIfUnsupported*/) const
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001308{
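    // Output layers simply expose a tensor to the caller, so any tensor is accepted.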
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001309 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001310}
1311
1312bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1313 const TensorInfo& output,
1314 const PadDescriptor& descriptor,
1315 Optional<std::string&> reasonIfUnsupported) const
1316{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001317 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001318 bool supported = true;
1319
1320 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001321 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001322 {
1323 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001324 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001325 DataType::QuantisedAsymm8,
1326 DataType::QuantisedSymm16
1327 };
1328
1329 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1330 "Reference pad: input is not a supported type.");
1331
1332 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1333 "Reference pad: output is not a supported type.");
1334
1335 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1336 "Reference pad: input and output types are mismatched.");
1337
1338 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001339}
1340
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001341bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1342 const TensorInfo& output,
1343 const PermuteDescriptor& descriptor,
1344 Optional<std::string&> reasonIfUnsupported) const
1345{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001346 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001347 bool supported = true;
1348
1349 // Define supported output and inputs types.
1350 std::array<DataType,3> supportedTypes =
1351 {
1352 DataType::Float32,
1353 DataType::QuantisedAsymm8,
1354 DataType::QuantisedSymm16
1355 };
1356
1357 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1358 "Reference permute: input is not a supported type.");
1359
1360 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1361 "Reference permute: output is not a supported type.");
1362
1363 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1364 "Reference permute: input and output types are mismatched.");
1365
1366 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001367}
1368
1369bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1370 const TensorInfo& output,
1371 const Pooling2dDescriptor& descriptor,
1372 Optional<std::string&> reasonIfUnsupported) const
1373{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001374 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001375 bool supported = true;
1376
1377 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001378 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001379 {
1380 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001381 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001382 DataType::QuantisedAsymm8,
1383 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001384 };
1385
1386 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1387                                  "Reference pooling2d: input is not a supported type.");
1388
1389 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1390                                  "Reference pooling2d: output is not a supported type.");
1391
1392 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1393                                  "Reference pooling2d: input and output types are mismatched.");
1394
1395 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001396}
1397
Derek Lamberti5f400d62019-03-25 15:41:58 +00001398bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1399 const TensorInfo& output,
1400 Optional<std::string&> reasonIfUnsupported) const
1401{
1402 bool supported = true;
1403
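    // Quantize converts a full-precision tensor into a quantized representation:
    // only a Float32 input and a quantized output type are accepted.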
Finn Williamsfd271062019-12-04 14:27:27 +00001404 // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001405 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001406 DataType::Float32,
1407 };
1408
1409 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1410 "Reference quantize: input type not supported.");
1411
1412 // Define supported output types.
Finn Williamsfd271062019-12-04 14:27:27 +00001413 std::array<DataType,3> supportedOutputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001414 DataType::QuantisedAsymm8,
Finn Williamsfd271062019-12-04 14:27:27 +00001415 DataType::QSymmS8,
Derek Lamberti5f400d62019-03-25 15:41:58 +00001416 DataType::QuantisedSymm16
1417 };
1418 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1419 "Reference quantize: output type not supported.");
1420
1421 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1422 "Reference quantize: input and output shapes have different num total elements.");
1423
1424 return supported;
1425}
1426
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001427bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001428 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001429 Optional<std::string&> reasonIfUnsupported) const
1430{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001431 ignore_unused(descriptor);
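    // Reshape only reinterprets the tensor's shape; the output carries the same data,
    // so checking the input type is sufficient here.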
Nina Drozd2f2778f2019-05-27 10:37:05 +01001432 // Define supported output types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001433 std::array<DataType,5> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001434 {
1435 DataType::Float32,
1436 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001437 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001438 DataType::QuantisedAsymm8,
1439 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001440 };
1441 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1442 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001443}
1444
1445bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001446 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001447 Optional<std::string&> reasonIfUnsupported) const
1448{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001449 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001450 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001451 {
1452 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001453 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001454 DataType::QuantisedAsymm8,
1455 DataType::QuantisedSymm16
1456 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001457
1458 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1459 "Reference ResizeBilinear: input type not supported");
1460
1461 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1462 "Reference ResizeBilinear: output type not supported");
1463
1464 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1465 "Reference ResizeBilinear: input and output types not matching");
1466
1467 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001468}
1469
Teresa Charlin970f43b2019-07-01 13:51:07 +01001470bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1471 const TensorInfo& output,
1472 const ResizeDescriptor& descriptor,
1473 Optional<std::string&> reasonIfUnsupported) const
1474{
Derek Lamberti901ea112019-12-10 22:07:09 +00001475 boost::ignore_unused(descriptor);
Teresa Charlin970f43b2019-07-01 13:51:07 +01001476 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001477 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001478 {
1479 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001480 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001481 DataType::QuantisedAsymm8,
1482 DataType::QuantisedSymm16
1483 };
1484
1485 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1486 "Reference Resize: input type not supported");
1487
1488 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1489 "Reference Resize: output type not supported");
1490
1491 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1492 "Reference Resize: input and output types not matching");
1493
1494 return supported;
1495}
1496
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001497bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1498 const TensorInfo& output,
1499 Optional<std::string&> reasonIfUnsupported) const
1500{
nikraj010421e7f2019-06-14 09:40:34 +01001501 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001502 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001503 {
1504 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001505 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001506 DataType::QuantisedAsymm8,
1507 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001508 };
1509
1510 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1511 "Reference rsqrt: input type not supported");
1512
1513 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1514 "Reference rsqrt: output type not supported");
1515
1516 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1517 "Reference rsqrt: input and output types not matching");
1518
1519 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1520 "Reference Rsqrt: input and output shapes have different number of total elements");
1521
1522 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001523}
1524
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001525bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1526 const TensorInfo& output,
1527 const SliceDescriptor& descriptor,
1528 Optional<std::string&> reasonIfUnsupported) const
1529{
Derek Lamberti901ea112019-12-10 22:07:09 +00001530 boost::ignore_unused(descriptor);
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001531 bool supported = true;
1532
1533 std::array<DataType, 3> supportedTypes =
1534 {
1535 DataType::Float32,
1536 DataType::QuantisedAsymm8,
1537 DataType::QuantisedSymm16
1538 };
1539
1540 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1541 "Reference Slice: input type not supported");
1542
1543 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1544 "Reference Slice: output type not supported");
1545
1546 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1547 "Reference Slice: input and output types are mismatched");
1548
1549 return supported;
1550}
1551
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001552bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1553 const TensorInfo& output,
1554 const SoftmaxDescriptor& descriptor,
1555 Optional<std::string&> reasonIfUnsupported) const
1556{
Derek Lamberti901ea112019-12-10 22:07:09 +00001557 boost::ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001558 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001559 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001560 {
1561 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001562 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001563 DataType::QuantisedAsymm8,
1564 DataType::QuantisedSymm16
1565 };
1566
1567 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001568                                  "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001569
1570 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001571                                  "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001572
1573 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001574                                  "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001575
1576 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001577}
1578
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001579bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1580 const TensorInfo& output,
1581 const SpaceToBatchNdDescriptor& descriptor,
1582 Optional<std::string&> reasonIfUnsupported) const
1583{
Derek Lamberti901ea112019-12-10 22:07:09 +00001584 boost::ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001585 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001586 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001587 {
1588 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001589 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001590 DataType::QuantisedAsymm8,
1591 DataType::QuantisedSymm16
1592 };
1593
1594 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1595 "Reference SpaceToBatchNd: input type not supported");
1596
1597 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1598 "Reference SpaceToBatchNd: output type not supported");
1599
1600 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1601 "Reference SpaceToBatchNd: input and output types are mismatched");
1602
1603 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001604}
1605
Keith Davisa57eccb2019-06-14 17:33:22 +01001606bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001607 const TensorInfo& output,
1608 const SpaceToDepthDescriptor& descriptor,
1609 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001610{
1611
1612 ignore_unused(descriptor);
1613 bool supported = true;
1614
Matthew Jackson9bff1442019-09-12 09:08:23 +01001615 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001616 {
1617 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001618 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001619 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001620 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001621 };
1622
1623 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1624 "Reference SpaceToDepth: input type not supported");
1625
1626 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1627 "Reference SpaceToDepth: output type not supported");
1628
1629 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1630 "Reference SpaceToDepth: input and output types are mismatched");
1631
1632 return supported;
1633}
1634
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001635bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1636 const ViewsDescriptor& descriptor,
1637 Optional<std::string&> reasonIfUnsupported) const
1638{
1639 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001640 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001641 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001642 {
1643 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001644 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001645 DataType::QuantisedAsymm8,
1646 DataType::QuantisedSymm16
1647 };
1648
1649 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1650 "Reference splitter: input type not supported");
1651
1652 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001653}
1654
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001655bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1656 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1657 const ViewsDescriptor& descriptor,
1658 Optional<std::string&> reasonIfUnsupported) const
1659{
1660 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001661 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001662 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001663 {
1664 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001665 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001666 DataType::QuantisedAsymm8,
1667 DataType::QuantisedSymm16
1668 };
1669
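    // The overall input must be a supported type and every output view must match it.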
1670    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1671                                  "Reference splitter: input type not supported");
1672    for (const TensorInfo& output : outputs)
1673    {
1674        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1675                                      "Reference splitter: output type not supported");
1676
1677 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1678 "Reference splitter: input and output types mismatched.");
1679 }
1680
1681 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001682}
1683
Matthew Jackson81e601c2019-07-11 12:07:09 +01001684bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1685 const TensorInfo& output,
1686 const StackDescriptor& descriptor,
1687 Optional<std::string&> reasonIfUnsupported) const
1688{
1689 ignore_unused(descriptor);
1690
1691 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001692 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001693 {
1694 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001695 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001696 DataType::QuantisedAsymm8,
1697 DataType::QuantisedSymm16
1698 };
1699
1700 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1701 "Reference stack: output type not supported");
1702 for (const TensorInfo* input : inputs)
1703 {
1704 BOOST_ASSERT(input != nullptr);
1705 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1706 "Reference stack: input type not supported");
1707
1708 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1709 "Reference stack: input and output types mismatched.");
1710 }
1711
1712 return supported;
1713}
1714
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001715bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1716 const TensorInfo& output,
1717 const StridedSliceDescriptor& descriptor,
1718 Optional<std::string&> reasonIfUnsupported) const
1719{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001720 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001721 bool supported = true;
1722
1723 std::array<DataType,3> supportedTypes =
1724 {
1725 DataType::Float32,
1726 DataType::QuantisedAsymm8,
1727 DataType::QuantisedSymm16
1728 };
1729
1730 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1731 "Reference StridedSlice: input type not supported");
1732
1733 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1734 "Reference StridedSlice: output type not supported");
1735
1736 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1737 "Reference StridedSlice: input and output types are mismatched");
1738
1739 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001740}
1741
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001742bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1743 const TensorInfo& input1,
1744 const TensorInfo& output,
1745 Optional<std::string&> reasonIfUnsupported) const
1746{
Sadik Armagan2999a022019-04-09 14:20:12 +01001747 bool supported = true;
1748
Matthew Jackson9bff1442019-09-12 09:08:23 +01001749 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001750 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001751 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001752 DataType::QuantisedAsymm8,
1753 DataType::QuantisedSymm16
1754 };
1755
1756 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1757 "Reference subtraction: input 0 is not a supported type.");
1758
1759 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1760 "Reference subtraction: input 1 is not a supported type.");
1761
1762 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1763 "Reference subtraction: output is not a supported type.");
1764
1765 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1766                                  "Reference subtraction: input 0 and input 1 types are mismatched");
1767
1768 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1769 "Reference subtraction: input and output types are mismatched");
1770
1771 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1772 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1773
1774 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001775}
1776
Matteo Martincighab9e5252019-06-13 17:27:46 +01001777bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1778 const TensorInfo& alpha,
1779 const TensorInfo& output,
1780 Optional<std::string&> reasonIfUnsupported) const
1781{
1782 bool supported = true;
1783
Matthew Jackson9bff1442019-09-12 09:08:23 +01001784 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001785 {
1786 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001787 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001788 DataType::QuantisedAsymm8,
1789 DataType::QuantisedSymm16
1790 };
1791
1792 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1793 "PReLU: input is not a supported type.");
1794
1795 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1796 "PReLU: alpha is not a supported type.");
1797
1798 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1799 "PReLU: output is not a supported type.");
1800
1801 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1802 "PReLU: input, alpha and output types are mismatched");
1803
1804 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1805 "PReLU: shapes are not suitable for implicit broadcast");
1806
1807 return supported;
1808}
1809
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001810bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1811 const TensorInfo& output,
1812 const TransposeConvolution2dDescriptor& descriptor,
1813 const TensorInfo& weights,
1814 const Optional<TensorInfo>& biases,
1815 Optional<std::string&> reasonIfUnsupported) const
1816{
Derek Lamberti901ea112019-12-10 22:07:09 +00001817 boost::ignore_unused(descriptor);
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001818 bool supported = true;
1819
Matthew Jackson252df3a2019-09-11 09:19:18 +01001820 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001821 {
1822 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001823 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001824 DataType::QuantisedAsymm8,
1825 DataType::QuantisedSymm16
1826 };
1827
1828 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1829 "Reference TransposeConvolution2d: input is not a supported type.");
1830
1831 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1832 "Reference TransposeConvolution2d: output is not a supported type.");
1833
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001834 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1835 "Reference TransposeConvolution2d: input and output types mismatched.");
1836
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001837
1838 const DataType inputType = input.GetDataType();
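    // For quantized inputs the weights may be quantized per-tensor (QuantisedAsymm8) or
    // per-channel (QuantizedSymm8PerAxis); otherwise the weights must match the input type.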
1839 if (inputType == DataType::QuantisedAsymm8)
1840 {
1841 std::array<DataType, 2> supportedWeightTypes =
1842 {
1843 DataType::QuantisedAsymm8,
1844 DataType::QuantizedSymm8PerAxis
1845 };
1846
1847 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1848 "Reference TransposeConvolution2d: weights type not supported for "
1849 "quantized input.");
1850 }
1851 else
1852 {
1853 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1854 "Reference TransposeConvolution2d: weights is not a supported type.");
1855
1856 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1857 "Reference TransposeConvolution2d: input and weights types mismatched.");
1858 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001859
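    // Optional biases must be Float32, Float16 or Signed32 (Signed32 typically being the
    // bias type used with quantized inputs).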
1860 if (biases.has_value())
1861 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001862 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001863 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001864 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001865 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001866 DataType::Signed32
1867 };
1868 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1869 "Reference TransposeConvolution2d: biases is not a supported type.");
1870 }
1871
1872 return supported;
1873}
1874
arovir011c7c81b2018-10-08 11:34:28 +01001875} // namespace armnn