//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <DataLayoutIndexed.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/BackendRegistry.hpp>

#include <backendsCommon/LayerSupportRules.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace
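
// Illustrative use of the helper above (the layer and tensor names are example values that
// happen to match the ones used by IsMeanSupported further down in this file):
//
//     std::string layerStr  = "Mean";
//     std::string tensorStr = "output";
//     // Produces: "Reference Mean: Expected 4 dimensions but got 2 dimensions instead,
//     //            for the 'output' tensor."
//     std::string msg = CreateIncorrectDimensionsErrorMsg(4, 2, layerStr, tensorStr);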

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}

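// Minimal sketch of how these queries are typically driven (assumes a caller constructing a
// RefLayerSupport instance directly; the shape and data type below are arbitrary). Each rule
// check accumulates into 'supported' and a failure message is written to 'reason':
//
//     RefLayerSupport layerSupport;
//     TensorInfo info(TensorShape({ 1, 2, 2, 1 }), DataType::Float32);
//     std::string reason;
//     bool ok = layerSupport.IsAbsSupported(info, info, Optional<std::string&>(reason));
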
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");

    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

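// The local ActivationFunctionSupported struct above follows the general pattern for ad-hoc
// checks: derive from Rule, set m_Res in the constructor, and pass an instance to
// CheckSupportRule() like any stock rule. A hypothetical rule rejecting empty tensors would
// follow the same shape (sketch only, not used anywhere in this file):
//
//     struct HasNonZeroElements : public Rule
//     {
//         HasNonZeroElements(const TensorInfo& info) { m_Res = (info.GetNumElements() > 0); }
//     };
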
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16,
        DataType::Signed32
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
                                            const TensorInfo& input1,
                                            const TensorInfo& output,
                                            const ComparisonDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    boost::ignore_unused(descriptor);

    std::array<DataType, 4> supportedInputTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference comparison: input 0 is not a supported type");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference comparison: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
                                  "Reference comparison: output is not of type Boolean");

    return supported;
}

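// The element-wise comparison wrappers further down in this file (IsEqualSupported and
// IsGreaterSupported) forward to the function above with the corresponding operation, e.g.:
//
//     IsComparisonSupported(input0, input1, output,
//                           ComparisonDescriptor(ComparisonOperation::Equal),
//                           reasonIfUnsupported);
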
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference convolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

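// Note on the quantized path above: with a QuantisedAsymm8 input, the weights may be either
// QuantisedAsymm8 or per-channel QuantizedSymm8PerAxis. A caller might therefore describe
// per-tensor quantized weights as, for example (the scale and offset values are arbitrary):
//
//     TensorInfo weightsInfo(TensorShape({ 16, 3, 3, 3 }), DataType::QuantisedAsymm8,
//                            /*quantizationScale=*/0.05f, /*quantizationOffset=*/128);
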
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    const DataType inputType = input.GetDataType();
    if (inputType == DataType::QuantisedAsymm8)
    {
        std::array<DataType, 2> supportedWeightTypes =
        {
            DataType::QuantisedAsymm8,
            DataType::QuantizedSymm8PerAxis
        };

        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights type not supported for quantized input.");
    }
    else
    {
        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: weights is not a supported type.");

        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: input and weights types mismatched.");
    }

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,2> supportedOutputTypes = {
        DataType::Float32,
        DataType::Float16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Equal),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Define supported types for bias.
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");
    }

    return supported;
}

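// Note on the bias rules above (inferred from the rule names, not from their implementation):
// BiasAndWeightsTypesMatch requires the bias type implied by the weight type, which for float
// weights is the same float type and for quantized weights is conventionally Signed32. A bias
// for a QuantisedAsymm8 fully connected layer would therefore typically be described as
// something like (the shape is illustrative):
//
//     TensorInfo biasInfo(TensorShape({ 16 }), DataType::Signed32);
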
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    return IsComparisonSupported(input0,
                                 input1,
                                 output,
                                 ComparisonDescriptor(ComparisonOperation::Greater),
                                 reasonIfUnsupported);
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // Check inputs and outputs.
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // Check layer parameters.
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and ProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}

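// Worked example of the rank rule above (values are illustrative): reducing a rank-4 input over
// two axes with m_KeepDims == false expects a rank-2 output, i.e. outputDim = 4 - 2; reducing
// over every axis (or whenever outputDim would reach 0) expects a rank-1 output instead.
//
//     MeanDescriptor desc;
//     desc.m_Axis     = { 1, 2 };   // e.g. reduce H and W of an NHWC tensor
//     desc.m_KeepDims = false;      // so a rank-4 input expects a rank-2 output
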
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001160bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001161 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001162 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001163 Optional<std::string&> reasonIfUnsupported) const
1164{
Jim Flynne242f2d2019-05-22 14:24:13 +01001165 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001166}
1167
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001168bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1169 const TensorInfo &output,
1170 Optional<std::string &> reasonIfUnsupported) const
1171{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001172 bool supported = true;
1173
1174 std::array<DataType,5> supportedTypes =
1175 {
1176 DataType::Float32,
1177 DataType::Float16,
1178 DataType::QuantisedAsymm8,
1179 DataType::QuantisedSymm16,
1180 DataType::Boolean
1181 };
1182
1183 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1184 "Reference MemCopy: input type not supported");
1185
1186 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1187 "Reference MemCopy: output type not supported");
1188
1189 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1190 "Reference MemCopy: input and output types are mismatched");
1191
1192 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001193}
1194
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001195bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1196 const TensorInfo& input1,
1197 const TensorInfo& output,
1198 Optional<std::string&> reasonIfUnsupported) const
1199{
Sadik Armagan2999a022019-04-09 14:20:12 +01001200 bool supported = true;
1201
Matthew Jackson9bff1442019-09-12 09:08:23 +01001202 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001203 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001204 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001205 DataType::QuantisedAsymm8,
1206 DataType::QuantisedSymm16
1207 };
1208
1209 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1210 "Reference minimum: input 0 is not a supported type.");
1211
1212 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1213 "Reference minimum: input 1 is not a supported type.");
1214
1215 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1216 "Reference minimum: output is not a supported type.");
1217
1218 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1219 "Reference minimum: input 0 and Input 1 types are mismatched");
1220
1221 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1222 "Reference minimum: input and output types are mismatched");
1223
1224 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1225 "Reference minimum: shapes are not suitable for implicit broadcast.");
1226
1227 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001228}
1229
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001230bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1231 const TensorInfo& input1,
1232 const TensorInfo& output,
1233 Optional<std::string&> reasonIfUnsupported) const
1234{
Sadik Armagan2999a022019-04-09 14:20:12 +01001235 bool supported = true;
1236
Matthew Jackson252df3a2019-09-11 09:19:18 +01001237 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001238 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001239 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001240 DataType::QuantisedAsymm8,
1241 DataType::QuantisedSymm16
1242 };
1243
1244 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1245 "Reference multiplication: input 0 is not a supported type.");
1246
1247 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1248 "Reference multiplication: input 1 is not a supported type.");
1249
1250 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1251 "Reference multiplication: output is not a supported type.");
1252
1253 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1254 "Reference multiplication: input 0 and Input 1 types are mismatched");
1255
1256 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1257 "Reference multiplication: input and output types are mismatched");
1258
1259 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1260 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1261
1262 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001263}
1264
1265bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1266 const TensorInfo& output,
1267 const NormalizationDescriptor& descriptor,
1268 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001269{
Nina Drozd661dfa72018-10-02 11:14:17 +01001270 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001271
1272 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001273 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001274 {
1275 DataType::Float16,
1276 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001277 DataType::QuantisedAsymm8,
1278 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001279 };
1280
1281 bool supported = true;
1282
1283 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1284 "Reference normalization: input type not supported.");
1285
1286 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1287 "Reference normalization: output type not supported.");
1288
1289 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1290 "Reference normalization: input and output shapes have different "
1291 "num total elements.");
1292
1293 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001294}
1295
1296bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1297 Optional<std::string&> reasonIfUnsupported) const
1298{
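    // The reference backend accepts any tensor as an output layer, so no checks are required.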
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001299 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001300}
1301
1302bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1303 const TensorInfo& output,
1304 const PadDescriptor& descriptor,
1305 Optional<std::string&> reasonIfUnsupported) const
1306{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001307 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001308 bool supported = true;
1309
1310 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001311 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001312 {
1313 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001314 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001315 DataType::QuantisedAsymm8,
1316 DataType::QuantisedSymm16
1317 };
1318
1319 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1320 "Reference pad: input is not a supported type.");
1321
1322 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1323 "Reference pad: output is not a supported type.");
1324
1325 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1326 "Reference pad: input and output types are mismatched.");
1327
1328 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001329}
1330
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001331bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1332 const TensorInfo& output,
1333 const PermuteDescriptor& descriptor,
1334 Optional<std::string&> reasonIfUnsupported) const
1335{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001336 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001337 bool supported = true;
1338
1339 // Define supported output and inputs types.
1340 std::array<DataType,3> supportedTypes =
1341 {
1342 DataType::Float32,
1343 DataType::QuantisedAsymm8,
1344 DataType::QuantisedSymm16
1345 };
1346
1347 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1348 "Reference permute: input is not a supported type.");
1349
1350 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1351 "Reference permute: output is not a supported type.");
1352
1353 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1354 "Reference permute: input and output types are mismatched.");
1355
1356 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001357}
1358
1359bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1360 const TensorInfo& output,
1361 const Pooling2dDescriptor& descriptor,
1362 Optional<std::string&> reasonIfUnsupported) const
1363{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001364 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001365 bool supported = true;
1366
1367 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001368 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001369 {
1370 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001371 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001372 DataType::QuantisedAsymm8,
1373 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001374 };
1375
1376 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1377 "Reference poolind2d: input is not a supported type.");
1378
1379 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1380 "Reference poolind2d: output is not a supported type.");
1381
1382 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1383 "Reference poolind2d: input and output types are mismatched.");
1384
1385 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001386}
1387
Derek Lamberti5f400d62019-03-25 15:41:58 +00001388bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1389 const TensorInfo& output,
1390 Optional<std::string&> reasonIfUnsupported) const
1391{
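    // Quantize converts a Float32 tensor into one of the supported quantized formats,
    // so the input and output use separate type lists.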
1392 bool supported = true;
1393
1394 // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001395 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001396 DataType::Float32,
1397 };
1398
1399 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1400 "Reference quantize: input type not supported.");
1401
1402 // Define supported output types.
1403 std::array<DataType,2> supportedOutputTypes = {
1404 DataType::QuantisedAsymm8,
1405 DataType::QuantisedSymm16
1406 };
1407 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1408 "Reference quantize: output type not supported.");
1409
1410 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1411 "Reference quantize: input and output shapes have different num total elements.");
1412
1413 return supported;
1414}
1415
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001416bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001417 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001418 Optional<std::string&> reasonIfUnsupported) const
1419{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001420 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001421 // Define supported types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001422 std::array<DataType,5> supportedTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001423 {
1424 DataType::Float32,
1425 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001426 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001427 DataType::QuantisedAsymm8,
1428 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001429 };
1430 return CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001432}
1433
1434bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001435 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001436 Optional<std::string&> reasonIfUnsupported) const
1437{
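    // ResizeBilinear applies the same type rules as the generic Resize check below.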
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001438 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001439 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001440 {
1441 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001442 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001443 DataType::QuantisedAsymm8,
1444 DataType::QuantisedSymm16
1445 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001446
1447 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1448 "Reference ResizeBilinear: input type not supported");
1449
1450 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1451 "Reference ResizeBilinear: output type not supported");
1452
1453 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1454 "Reference ResizeBilinear: input and output types not matching");
1455
1456 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001457}
1458
Teresa Charlin970f43b2019-07-01 13:51:07 +01001459bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1460 const TensorInfo& output,
1461 const ResizeDescriptor& descriptor,
1462 Optional<std::string&> reasonIfUnsupported) const
1463{
1464 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001465 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001466 {
1467 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001468 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001469 DataType::QuantisedAsymm8,
1470 DataType::QuantisedSymm16
1471 };
1472
1473 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1474 "Reference Resize: input type not supported");
1475
1476 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1477 "Reference Resize: output type not supported");
1478
1479 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1480 "Reference Resize: input and output types not matching");
1481
1482 return supported;
1483}
1484
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001485bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1486 const TensorInfo& output,
1487 Optional<std::string&> reasonIfUnsupported) const
1488{
nikraj010421e7f2019-06-14 09:40:34 +01001489 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001490 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001491 {
1492 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001493 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001494 DataType::QuantisedAsymm8,
1495 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001496 };
1497
1498 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1499 "Reference rsqrt: input type not supported");
1500
1501 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1502 "Reference rsqrt: output type not supported");
1503
1504 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1505 "Reference rsqrt: input and output types not matching");
1506
1507 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1508 "Reference Rsqrt: input and output shapes have different number of total elements");
1509
1510 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001511}
1512
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001513bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1514 const TensorInfo& output,
1515 const SliceDescriptor& descriptor,
1516 Optional<std::string&> reasonIfUnsupported) const
1517{
1518 ignore_unused(descriptor);
1519 bool supported = true;
1520
1521 std::array<DataType, 3> supportedTypes =
1522 {
1523 DataType::Float32,
1524 DataType::QuantisedAsymm8,
1525 DataType::QuantisedSymm16
1526 };
1527
1528 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1529 "Reference Slice: input type not supported");
1530
1531 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1532 "Reference Slice: output type not supported");
1533
1534 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1535 "Reference Slice: input and output types are mismatched");
1536
1537 return supported;
1538}
1539
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001540bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1541 const TensorInfo& output,
1542 const SoftmaxDescriptor& descriptor,
1543 Optional<std::string&> reasonIfUnsupported) const
1544{
1545 ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001546 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001547 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001548 {
1549 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001550 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001551 DataType::QuantisedAsymm8,
1552 DataType::QuantisedSymm16
1553 };
1554
1555 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001556 "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001557
1558 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001559 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001560
1561 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001562 "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001563
1564 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001565}
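// Illustrative usage sketch (not part of this file): how a caller might query Softmax support on the
// reference backend. The tensor shape, descriptor and variable names below are assumed for the example.
//
//     armnn::RefLayerSupport layerSupport;
//     armnn::TensorInfo info({ 1, 10 }, armnn::DataType::Float32);
//     armnn::SoftmaxDescriptor softmaxDesc;
//     std::string reason;
//     bool supported = layerSupport.IsSoftmaxSupported(info, info, softmaxDesc,
//                                                      armnn::Optional<std::string&>(reason));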
1566
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001567bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1568 const TensorInfo& output,
1569 const SpaceToBatchNdDescriptor& descriptor,
1570 Optional<std::string&> reasonIfUnsupported) const
1571{
1572 ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001573 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001574 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001575 {
1576 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001577 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001578 DataType::QuantisedAsymm8,
1579 DataType::QuantisedSymm16
1580 };
1581
1582 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1583 "Reference SpaceToBatchNd: input type not supported");
1584
1585 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1586 "Reference SpaceToBatchNd: output type not supported");
1587
1588 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1589 "Reference SpaceToBatchNd: input and output types are mismatched");
1590
1591 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001592}
1593
Keith Davisa57eccb2019-06-14 17:33:22 +01001594bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001595 const TensorInfo& output,
1596 const SpaceToDepthDescriptor& descriptor,
1597 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001598{
1599
1600 ignore_unused(descriptor);
1601 bool supported = true;
1602
Matthew Jackson9bff1442019-09-12 09:08:23 +01001603 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001604 {
1605 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001606 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001607 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001608 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001609 };
1610
1611 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1612 "Reference SpaceToDepth: input type not supported");
1613
1614 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1615 "Reference SpaceToDepth: output type not supported");
1616
1617 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1618 "Reference SpaceToDepth: input and output types are mismatched");
1619
1620 return supported;
1621}
1622
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001623bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1624 const ViewsDescriptor& descriptor,
1625 Optional<std::string&> reasonIfUnsupported) const
1626{
1627 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001628 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001629 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001630 {
1631 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001632 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001633 DataType::QuantisedAsymm8,
1634 DataType::QuantisedSymm16
1635 };
1636
1637 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1638 "Reference splitter: input type not supported");
1639
1640 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001641}
1642
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001643bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1644 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1645 const ViewsDescriptor& descriptor,
1646 Optional<std::string&> reasonIfUnsupported) const
1647{
1648 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001649 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001650 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001651 {
1652 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001653 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001654 DataType::QuantisedAsymm8,
1655 DataType::QuantisedSymm16
1656 };
1657
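    // Validate the input type once, then check each output view against the supported types and the input type.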
1658 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1659 "Reference splitter: output type not supported");
1660 for (const TensorInfo output : outputs)
1661 {
1662 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1663 "Reference splitter: input type not supported");
1664
1665 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1666 "Reference splitter: input and output types mismatched.");
1667 }
1668
1669 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001670}
1671
Matthew Jackson81e601c2019-07-11 12:07:09 +01001672bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1673 const TensorInfo& output,
1674 const StackDescriptor& descriptor,
1675 Optional<std::string&> reasonIfUnsupported) const
1676{
1677 ignore_unused(descriptor);
1678
1679 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001680 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001681 {
1682 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001683 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001684 DataType::QuantisedAsymm8,
1685 DataType::QuantisedSymm16
1686 };
1687
1688 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1689 "Reference stack: output type not supported");
1690 for (const TensorInfo* input : inputs)
1691 {
1692 BOOST_ASSERT(input != nullptr);
1693 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1694 "Reference stack: input type not supported");
1695
1696 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1697 "Reference stack: input and output types mismatched.");
1698 }
1699
1700 return supported;
1701}
1702
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001703bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1704 const TensorInfo& output,
1705 const StridedSliceDescriptor& descriptor,
1706 Optional<std::string&> reasonIfUnsupported) const
1707{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001708 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001709 bool supported = true;
1710
1711 std::array<DataType,3> supportedTypes =
1712 {
1713 DataType::Float32,
1714 DataType::QuantisedAsymm8,
1715 DataType::QuantisedSymm16
1716 };
1717
1718 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1719 "Reference StridedSlice: input type not supported");
1720
1721 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1722 "Reference StridedSlice: output type not supported");
1723
1724 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1725 "Reference StridedSlice: input and output types are mismatched");
1726
1727 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001728}
1729
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001730bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1731 const TensorInfo& input1,
1732 const TensorInfo& output,
1733 Optional<std::string&> reasonIfUnsupported) const
1734{
Sadik Armagan2999a022019-04-09 14:20:12 +01001735 bool supported = true;
1736
Matthew Jackson9bff1442019-09-12 09:08:23 +01001737 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001738 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001739 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001740 DataType::QuantisedAsymm8,
1741 DataType::QuantisedSymm16
1742 };
1743
1744 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1745 "Reference subtraction: input 0 is not a supported type.");
1746
1747 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1748 "Reference subtraction: input 1 is not a supported type.");
1749
1750 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1751 "Reference subtraction: output is not a supported type.");
1752
1753 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1754 "Reference subtraction: input 0 and Input 1 types are mismatched");
1755
1756 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1757 "Reference subtraction: input and output types are mismatched");
1758
1759 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1760 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1761
1762 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001763}
1764
Matteo Martincighab9e5252019-06-13 17:27:46 +01001765bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1766 const TensorInfo& alpha,
1767 const TensorInfo& output,
1768 Optional<std::string&> reasonIfUnsupported) const
1769{
1770 bool supported = true;
1771
Matthew Jackson9bff1442019-09-12 09:08:23 +01001772 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001773 {
1774 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001775 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001776 DataType::QuantisedAsymm8,
1777 DataType::QuantisedSymm16
1778 };
1779
1780 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1781 "PReLU: input is not a supported type.");
1782
1783 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1784 "PReLU: alpha is not a supported type.");
1785
1786 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1787 "PReLU: output is not a supported type.");
1788
1789 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1790 "PReLU: input, alpha and output types are mismatched");
1791
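    // Alpha only needs to be broadcast-compatible with the input, not an exact shape match.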
1792 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1793 "PReLU: shapes are not suitable for implicit broadcast");
1794
1795 return supported;
1796}
1797
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001798bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1799 const TensorInfo& output,
1800 const TransposeConvolution2dDescriptor& descriptor,
1801 const TensorInfo& weights,
1802 const Optional<TensorInfo>& biases,
1803 Optional<std::string&> reasonIfUnsupported) const
1804{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001805 bool supported = true;
1806
Matthew Jackson252df3a2019-09-11 09:19:18 +01001807 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001808 {
1809 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001810 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001811 DataType::QuantisedAsymm8,
1812 DataType::QuantisedSymm16
1813 };
1814
1815 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1816 "Reference TransposeConvolution2d: input is not a supported type.");
1817
1818 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1819 "Reference TransposeConvolution2d: output is not a supported type.");
1820
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001821 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1822 "Reference TransposeConvolution2d: input and output types mismatched.");
1823
1825 const DataType inputType = input.GetDataType();
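    // Quantized inputs accept either per-tensor QAsymm8 weights or per-axis quantized QSymm8 weights.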
1826 if (inputType == DataType::QuantisedAsymm8)
1827 {
1828 std::array<DataType, 2> supportedWeightTypes =
1829 {
1830 DataType::QuantisedAsymm8,
1831 DataType::QuantizedSymm8PerAxis
1832 };
1833
1834 supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
1835 "Reference TransposeConvolution2d: weights type not supported for "
1836 "quantized input.");
1837 }
1838 else
1839 {
1840 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1841 "Reference TransposeConvolution2d: weights is not a supported type.");
1842
1843 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1844 "Reference TransposeConvolution2d: input and weights types mismatched.");
1845 }
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001846
1847 if (biases.has_value())
1848 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001849 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001850 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001851 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001852 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001853 DataType::Signed32
1854 };
1855 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1856 "Reference TransposeConvolution2d: biases is not a supported type.");
1857 }
1858
1859 return supported;
1860}
1861
arovir011c7c81b2018-10-08 11:34:28 +01001862} // namespace armnn