//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <DataLayoutIndexed.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <backendsCommon/LayerSupportRules.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

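// Forwards a data-type check to IsSupportedForDataTypeGeneric, supplying FalseFunc
// for the data types that the caller does not provide a functor for.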
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

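// Builds the error message reported when a tensor does not have the expected number of dimensions.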
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}

bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

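    // Rule that accepts only the activation functions implemented by the reference backend.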
    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float16,
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for "
                                      "quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Defined supported types for bias
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1160
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001161bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001162 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001163 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001164 Optional<std::string&> reasonIfUnsupported) const
1165{
Jim Flynne242f2d2019-05-22 14:24:13 +01001166 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001167}
1168
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001169bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1170 const TensorInfo &output,
1171 Optional<std::string &> reasonIfUnsupported) const
1172{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001173 bool supported = true;
1174
1175 std::array<DataType,5> supportedTypes =
1176 {
1177 DataType::Float32,
1178 DataType::Float16,
1179 DataType::QuantisedAsymm8,
1180 DataType::QuantisedSymm16,
1181 DataType::Boolean
1182 };
1183
1184 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1185 "Reference MemCopy: input type not supported");
1186
1187 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1188 "Reference MemCopy: output type not supported");
1189
1190 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1191 "Reference MemCopy: input and output types are mismatched");
1192
1193 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001194}
1195
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001196bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1197 const TensorInfo& input1,
1198 const TensorInfo& output,
1199 Optional<std::string&> reasonIfUnsupported) const
1200{
Sadik Armagan2999a022019-04-09 14:20:12 +01001201 bool supported = true;
1202
Matthew Jackson9bff1442019-09-12 09:08:23 +01001203 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001204 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001205 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001206 DataType::QuantisedAsymm8,
1207 DataType::QuantisedSymm16
1208 };
1209
1210 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1211 "Reference minimum: input 0 is not a supported type.");
1212
1213 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1214 "Reference minimum: input 1 is not a supported type.");
1215
1216 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1217 "Reference minimum: output is not a supported type.");
1218
1219 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1220                                  "Reference minimum: input 0 and input 1 types are mismatched");
1221
1222 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1223 "Reference minimum: input and output types are mismatched");
1224
1225 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1226 "Reference minimum: shapes are not suitable for implicit broadcast.");
1227
1228 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001229}
1230
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001231bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1232 const TensorInfo& input1,
1233 const TensorInfo& output,
1234 Optional<std::string&> reasonIfUnsupported) const
1235{
Sadik Armagan2999a022019-04-09 14:20:12 +01001236 bool supported = true;
1237
Matthew Jackson252df3a2019-09-11 09:19:18 +01001238 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001239 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001240 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001241 DataType::QuantisedAsymm8,
1242 DataType::QuantisedSymm16
1243 };
1244
1245 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1246 "Reference multiplication: input 0 is not a supported type.");
1247
1248 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1249 "Reference multiplication: input 1 is not a supported type.");
1250
1251 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1252 "Reference multiplication: output is not a supported type.");
1253
1254 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1255                                  "Reference multiplication: input 0 and input 1 types are mismatched");
1256
1257 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1258 "Reference multiplication: input and output types are mismatched");
1259
1260 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1261 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1262
1263 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001264}
1265
1266bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1267 const TensorInfo& output,
1268 const NormalizationDescriptor& descriptor,
1269 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001270{
Nina Drozd661dfa72018-10-02 11:14:17 +01001271 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001272
1273 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001274 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001275 {
1276 DataType::Float16,
1277 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001278 DataType::QuantisedAsymm8,
1279 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001280 };
1281
1282 bool supported = true;
1283
1284 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1285 "Reference normalization: input type not supported.");
1286
1287 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1288 "Reference normalization: output type not supported.");
1289
1290 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1291 "Reference normalization: input and output shapes have different "
1292 "num total elements.");
1293
1294 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001295}
1296
1297bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1298 Optional<std::string&> reasonIfUnsupported) const
1299{
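    // The reference backend places no restrictions on Output layers; any tensor is accepted.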
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001300 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001301}
1302
1303bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1304 const TensorInfo& output,
1305 const PadDescriptor& descriptor,
1306 Optional<std::string&> reasonIfUnsupported) const
1307{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001308 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001309 bool supported = true;
1310
1311    // Define supported output and input types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001312 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001313 {
1314 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001315 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001316 DataType::QuantisedAsymm8,
1317 DataType::QuantisedSymm16
1318 };
1319
1320 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1321 "Reference pad: input is not a supported type.");
1322
1323 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1324 "Reference pad: output is not a supported type.");
1325
1326 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1327 "Reference pad: input and output types are mismatched.");
1328
1329 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001330}
1331
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001332bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1333 const TensorInfo& output,
1334 const PermuteDescriptor& descriptor,
1335 Optional<std::string&> reasonIfUnsupported) const
1336{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001337 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001338 bool supported = true;
1339
1340    // Define supported output and input types.
1341 std::array<DataType,3> supportedTypes =
1342 {
1343 DataType::Float32,
1344 DataType::QuantisedAsymm8,
1345 DataType::QuantisedSymm16
1346 };
1347
1348 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1349 "Reference permute: input is not a supported type.");
1350
1351 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1352 "Reference permute: output is not a supported type.");
1353
1354 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1355 "Reference permute: input and output types are mismatched.");
1356
1357 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001358}
1359
1360bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1361 const TensorInfo& output,
1362 const Pooling2dDescriptor& descriptor,
1363 Optional<std::string&> reasonIfUnsupported) const
1364{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001365 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001366 bool supported = true;
1367
1368    // Define supported output and input types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001369 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001370 {
1371 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001372 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001373 DataType::QuantisedAsymm8,
1374 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001375 };
1376
1377 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1378                                  "Reference pooling2d: input is not a supported type.");
1379
1380 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1381                                  "Reference pooling2d: output is not a supported type.");
1382
1383 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1384                                  "Reference pooling2d: input and output types are mismatched.");
1385
1386 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001387}
1388
Derek Lamberti5f400d62019-03-25 15:41:58 +00001389bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1390 const TensorInfo& output,
1391 Optional<std::string&> reasonIfUnsupported) const
1392{
1393 bool supported = true;
1394
1395    // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001396 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001397 DataType::Float32,
1398 };
1399
1400 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1401 "Reference quantize: input type not supported.");
1402
1403 // Define supported output types.
1404 std::array<DataType,2> supportedOutputTypes = {
1405 DataType::QuantisedAsymm8,
1406 DataType::QuantisedSymm16
1407 };
1408 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1409 "Reference quantize: output type not supported.");
1410
1411 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1412 "Reference quantize: input and output shapes have different num total elements.");
1413
1414 return supported;
1415}
1416
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001417bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001418 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001419 Optional<std::string&> reasonIfUnsupported) const
1420{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001421 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001422    // Define supported types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001423    std::array<DataType,5> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001424 {
1425 DataType::Float32,
1426 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001427 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001428 DataType::QuantisedAsymm8,
1429 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001430 };
1431    return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1432 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001433}
1434
1435bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001436 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001437 Optional<std::string&> reasonIfUnsupported) const
1438{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001439 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001440 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001441 {
1442 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001443 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001444 DataType::QuantisedAsymm8,
1445 DataType::QuantisedSymm16
1446 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001447
1448 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1449 "Reference ResizeBilinear: input type not supported");
1450
1451 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1452 "Reference ResizeBilinear: output type not supported");
1453
1454 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1455 "Reference ResizeBilinear: input and output types not matching");
1456
1457 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001458}
1459
Teresa Charlin970f43b2019-07-01 13:51:07 +01001460bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1461 const TensorInfo& output,
1462 const ResizeDescriptor& descriptor,
1463 Optional<std::string&> reasonIfUnsupported) const
1464{
1465 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001466 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001467 {
1468 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001469 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001470 DataType::QuantisedAsymm8,
1471 DataType::QuantisedSymm16
1472 };
1473
1474 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1475 "Reference Resize: input type not supported");
1476
1477 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1478 "Reference Resize: output type not supported");
1479
1480 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1481 "Reference Resize: input and output types not matching");
1482
1483 return supported;
1484}
1485
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001486bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1487 const TensorInfo& output,
1488 Optional<std::string&> reasonIfUnsupported) const
1489{
nikraj010421e7f2019-06-14 09:40:34 +01001490 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001491 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001492 {
1493 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001494 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001495 DataType::QuantisedAsymm8,
1496 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001497 };
1498
1499 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1500 "Reference rsqrt: input type not supported");
1501
1502 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1503 "Reference rsqrt: output type not supported");
1504
1505 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1506 "Reference rsqrt: input and output types not matching");
1507
1508 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1509                                  "Reference rsqrt: input and output shapes have different number of total elements");
1510
1511 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001512}
1513
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001514bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1515 const TensorInfo& output,
1516 const SliceDescriptor& descriptor,
1517 Optional<std::string&> reasonIfUnsupported) const
1518{
1519 ignore_unused(descriptor);
1520 bool supported = true;
1521
1522 std::array<DataType, 3> supportedTypes =
1523 {
1524 DataType::Float32,
1525 DataType::QuantisedAsymm8,
1526 DataType::QuantisedSymm16
1527 };
1528
1529 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1530 "Reference Slice: input type not supported");
1531
1532 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1533 "Reference Slice: output type not supported");
1534
1535 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1536 "Reference Slice: input and output types are mismatched");
1537
1538 return supported;
1539}
1540
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001541bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1542 const TensorInfo& output,
1543 const SoftmaxDescriptor& descriptor,
1544 Optional<std::string&> reasonIfUnsupported) const
1545{
1546    ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001547 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001548 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001549 {
1550 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001551 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001552 DataType::QuantisedAsymm8,
1553 DataType::QuantisedSymm16
1554 };
1555
1556 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001557                                  "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001558
1559 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001560                                  "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001561
1562 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001563                                  "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001564
1565 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001566}
1567
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001568bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1569 const TensorInfo& output,
1570 const SpaceToBatchNdDescriptor& descriptor,
1571 Optional<std::string&> reasonIfUnsupported) const
1572{
1573    ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001574 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001575 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001576 {
1577 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001578 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001579 DataType::QuantisedAsymm8,
1580 DataType::QuantisedSymm16
1581 };
1582
1583 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1584 "Reference SpaceToBatchNd: input type not supported");
1585
1586 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1587 "Reference SpaceToBatchNd: output type not supported");
1588
1589 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1590 "Reference SpaceToBatchNd: input and output types are mismatched");
1591
1592 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001593}
1594
Keith Davisa57eccb2019-06-14 17:33:22 +01001595bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001596 const TensorInfo& output,
1597 const SpaceToDepthDescriptor& descriptor,
1598 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001599{
1600
1601 ignore_unused(descriptor);
1602 bool supported = true;
1603
Matthew Jackson9bff1442019-09-12 09:08:23 +01001604 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001605 {
1606 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001607 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001608 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001609 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001610 };
1611
1612 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1613 "Reference SpaceToDepth: input type not supported");
1614
1615 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1616 "Reference SpaceToDepth: output type not supported");
1617
1618 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1619 "Reference SpaceToDepth: input and output types are mismatched");
1620
1621 return supported;
1622}
1623
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001624bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1625 const ViewsDescriptor& descriptor,
1626 Optional<std::string&> reasonIfUnsupported) const
1627{
1628 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001629 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001630 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001631 {
1632 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001633 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001634 DataType::QuantisedAsymm8,
1635 DataType::QuantisedSymm16
1636 };
1637
1638 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1639 "Reference splitter: input type not supported");
1640
1641 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001642}
1643
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001644bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1645 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1646 const ViewsDescriptor& descriptor,
1647 Optional<std::string&> reasonIfUnsupported) const
1648{
1649 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001650 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001651 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001652 {
1653 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001654 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001655 DataType::QuantisedAsymm8,
1656 DataType::QuantisedSymm16
1657 };
1658
1659    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1660                                  "Reference splitter: input type not supported");
1661    for (const TensorInfo& output : outputs)
1662    {
1663        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1664                                      "Reference splitter: output type not supported");
1665
1666 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1667 "Reference splitter: input and output types mismatched.");
1668 }
1669
1670 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001671}
1672
Matthew Jackson81e601c2019-07-11 12:07:09 +01001673bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1674 const TensorInfo& output,
1675 const StackDescriptor& descriptor,
1676 Optional<std::string&> reasonIfUnsupported) const
1677{
1678 ignore_unused(descriptor);
1679
1680 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001681 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001682 {
1683 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001684 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001685 DataType::QuantisedAsymm8,
1686 DataType::QuantisedSymm16
1687 };
1688
1689 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1690 "Reference stack: output type not supported");
1691 for (const TensorInfo* input : inputs)
1692 {
1693 BOOST_ASSERT(input != nullptr);
1694 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1695 "Reference stack: input type not supported");
1696
1697 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1698 "Reference stack: input and output types mismatched.");
1699 }
1700
1701 return supported;
1702}
1703
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001704bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1705 const TensorInfo& output,
1706 const StridedSliceDescriptor& descriptor,
1707 Optional<std::string&> reasonIfUnsupported) const
1708{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001709 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001710 bool supported = true;
1711
1712 std::array<DataType,3> supportedTypes =
1713 {
1714 DataType::Float32,
1715 DataType::QuantisedAsymm8,
1716 DataType::QuantisedSymm16
1717 };
1718
1719 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1720 "Reference StridedSlice: input type not supported");
1721
1722 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1723 "Reference StridedSlice: output type not supported");
1724
1725 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1726 "Reference StridedSlice: input and output types are mismatched");
1727
1728 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001729}
1730
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001731bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1732 const TensorInfo& input1,
1733 const TensorInfo& output,
1734 Optional<std::string&> reasonIfUnsupported) const
1735{
Sadik Armagan2999a022019-04-09 14:20:12 +01001736 bool supported = true;
1737
Matthew Jackson9bff1442019-09-12 09:08:23 +01001738 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001739 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001740 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001741 DataType::QuantisedAsymm8,
1742 DataType::QuantisedSymm16
1743 };
1744
1745 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1746 "Reference subtraction: input 0 is not a supported type.");
1747
1748 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1749 "Reference subtraction: input 1 is not a supported type.");
1750
1751 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1752 "Reference subtraction: output is not a supported type.");
1753
1754 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1755                                  "Reference subtraction: input 0 and input 1 types are mismatched");
1756
1757 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1758 "Reference subtraction: input and output types are mismatched");
1759
1760 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1761 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1762
1763 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001764}
1765
Matteo Martincighab9e5252019-06-13 17:27:46 +01001766bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1767 const TensorInfo& alpha,
1768 const TensorInfo& output,
1769 Optional<std::string&> reasonIfUnsupported) const
1770{
1771 bool supported = true;
1772
Matthew Jackson9bff1442019-09-12 09:08:23 +01001773 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001774 {
1775 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001776 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001777 DataType::QuantisedAsymm8,
1778 DataType::QuantisedSymm16
1779 };
1780
1781 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1782                                  "Reference PReLU: input is not a supported type.");
1783
1784 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1785                                  "Reference PReLU: alpha is not a supported type.");
1786
1787 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1788                                  "Reference PReLU: output is not a supported type.");
1789
1790 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1791                                  "Reference PReLU: input, alpha and output types are mismatched");
1792
1793 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1794                                  "Reference PReLU: shapes are not suitable for implicit broadcast");
1795
1796 return supported;
1797}
1798
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001799bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1800 const TensorInfo& output,
1801 const TransposeConvolution2dDescriptor& descriptor,
1802 const TensorInfo& weights,
1803 const Optional<TensorInfo>& biases,
1804 Optional<std::string&> reasonIfUnsupported) const
1805{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001806 bool supported = true;
1807
Matthew Jackson252df3a2019-09-11 09:19:18 +01001808 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001809 {
1810 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001811 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001812 DataType::QuantisedAsymm8,
1813 DataType::QuantisedSymm16
1814 };
1815
1816 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1817 "Reference TransposeConvolution2d: input is not a supported type.");
1818
1819 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1820 "Reference TransposeConvolution2d: output is not a supported type.");
1821
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001822 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1823 "Reference TransposeConvolution2d: input and output types mismatched.");
1824
Aron Virginas-Tar94d3b932019-11-11 12:54:47 +00001825
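    // Weight type rules: a QuantisedAsymm8 input may be paired with QuantisedAsymm8 or per-axis
    // QuantizedSymm8PerAxis weights; for any other input type the weights must match the input type.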
1826 const DataType inputType = input.GetDataType();
1827 if (inputType == DataType::QuantisedAsymm8)
1828 {
1829 std::array<DataType, 2> supportedWeightTypes =
1830 {
1831 DataType::QuantisedAsymm8,
1832 DataType::QuantizedSymm8PerAxis
1833 };
1834
1835 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1836 "Reference TransposeConvolution2d: weights type not supported for "
1837 "quantized input.");
1838 }
1839 else
1840 {
1841 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1842 "Reference TransposeConvolution2d: weights is not a supported type.");
1843
1844 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1845 "Reference TransposeConvolution2d: input and weights types mismatched.");
1846 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001847
1848 if (biases.has_value())
1849 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001850 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001851 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001852 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001853 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001854 DataType::Signed32
1855 };
1856 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1857 "Reference TransposeConvolution2d: biases is not a supported type.");
1858 }
1859
1860 return supported;
1861}
1862
arovir011c7c81b2018-10-08 11:34:28 +01001863} // namespace armnn