//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLayerSupport.hpp"
#include "RefBackendId.hpp"

#include <DataLayoutIndexed.hpp>
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>

#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>

#include <backendsCommon/BackendRegistry.hpp>
#include <backendsCommon/LayerSupportRules.hpp>
#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/core/ignore_unused.hpp>

#include <vector>
#include <algorithm>
#include <array>

using namespace boost;

namespace armnn
{

namespace
{

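// Helper that dispatches a support query on the tensor data type. Only the Float32 and
// QuantisedAsymm8 cases are forwarded; the remaining data types fall through to FalseFunc.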
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>,
                                         &FalseFunc<Params...>,
                                         std::forward<Params>(params)...);
}

} // anonymous namespace

namespace
{

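// Builds the reason string reported when a tensor does not have the expected number of dimensions.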
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              std::string& layerStr,
                                              std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}

} // anonymous namespace

bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
                                     Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference abs: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output types not matching");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference abs: input and output shapes have different number of total elements");

    return supported;
}

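// Activation is checked against both the tensor data types and the activation function requested in the descriptor.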
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const ActivationDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference activation: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
                                  "Reference activation: input and output shapes are of different rank.");


    struct ActivationFunctionSupported : public Rule
    {
        ActivationFunctionSupported(const ActivationDescriptor& desc)
        {
            switch(desc.m_Function)
            {
                case ActivationFunction::Abs:
                case ActivationFunction::BoundedReLu:
                case ActivationFunction::LeakyReLu:
                case ActivationFunction::Linear:
                case ActivationFunction::ReLu:
                case ActivationFunction::Sigmoid:
                case ActivationFunction::SoftReLu:
                case ActivationFunction::Sqrt:
                case ActivationFunction::Square:
                case ActivationFunction::TanH:
                {
                    m_Res = true;
                    break;
                }
                default:
                {
                    m_Res = false;
                    break;
                }
            }
        }
    };

    // Function is supported
    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
                                  "Reference activation: function not supported.");

    return supported;
}

bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference addition: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference addition: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference addition: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference addition: shapes are not suitable for implicit broadcast.");

    return supported;
}

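// ArgMinMax produces indices, so the output tensor must be Signed32 whatever the input type is.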
bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const armnn::TensorInfo &output,
                                           const armnn::ArgMinMaxDescriptor &descriptor,
                                           armnn::Optional<std::string &> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference ArgMinMax: input is not a supported type.");
    supported &= CheckSupportRule(TypeIs(output, DataType::Signed32), reasonIfUnsupported,
                                  "Reference ArgMinMax: output type not supported");

    return supported;
}

bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
                                                    const TensorInfo& output,
                                                    const TensorInfo& mean,
                                                    const TensorInfo& variance,
                                                    const TensorInfo& beta,
                                                    const TensorInfo& gamma,
                                                    const BatchNormalizationDescriptor& descriptor,
                                                    Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference batch normalization: input and output types are mismatched");

    supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: mean is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: variance is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: beta is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
                                  "Reference batch normalization: gamma is not a supported type.");

    return supported;
}

bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const BatchToSpaceNdDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;

    std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
    std::string inputTensorStr = "input";
    std::string outputTensorStr = "output";

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference BatchToSpaceNd: input and output types mismatched.");

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    output.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    outputTensorStr).data());

    supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
                                  reasonIfUnsupported,
                                  CreateIncorrectDimensionsErrorMsg(4,
                                                                    input.GetNumDimensions(),
                                                                    batchToSpaceNdLayerStr,
                                                                    inputTensorStr).data());

    return supported;
}

bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference concatenation: output type not supported");
    for (const TensorInfo* input : inputs)
    {
        BOOST_ASSERT(input != nullptr);
        supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
                                      "Reference concatenation: input type not supported");

        supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
                                      "Reference concatenation: input and output types mismatched.");
    }

    return supported;
}

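// Constant layers have no inputs, so only the output data type needs to be validated.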
bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Signed32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                            "Reference constant: output is not a supported type.");
}

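// The conversion layers accept exactly one direction each: Float16 input with Float32 output here,
// and the reverse in IsConvertFp32ToFp16Supported below.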
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseInputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,
                                          &TrueFunc<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,
                                          &FalseOutputFuncF32<>,
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}

bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
                                               const TensorInfo& output,
                                               const Convolution2dDescriptor& descriptor,
                                               const TensorInfo& weights,
                                               const Optional<TensorInfo>& biases,
                                               Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference convolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference convolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference convolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes = {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference convolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;
}

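// Debug passes the tensor data through unchanged, so it only requires matching input and output types.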
bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference debug: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference debug: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
                                              const TensorInfo& output,
                                              const DepthToSpaceDescriptor& descriptor,
                                              Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthToSpace: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthToSpace: input and output types are mismatched");

    return supported;
}

bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
                                                      const TensorInfo& output,
                                                      const DepthwiseConvolution2dDescriptor& descriptor,
                                                      const TensorInfo& weights,
                                                      const Optional<TensorInfo>& biases,
                                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: output is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: weights is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and output types mismatched.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference DepthwiseConvolution2d: input and weights types mismatched.");

    if (biases.has_value())
    {
        std::array<DataType,3> biasesSupportedTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };
        supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
                                      "Reference DepthwiseConvolution2d: biases is not a supported type.");
    }
    ignore_unused(descriptor);

    return supported;

}

bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,2> supportedInputTypes = {
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                  "Reference dequantize: input type not supported.");

    std::array<DataType,1> supportedOutputTypes = {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
                                  "Reference dequantize: output type not supported.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference dequantize: input and output shapes have different num total elements.");

    return supported;
}

bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
                                                      const armnn::TensorInfo& input1,
                                                      const armnn::DetectionPostProcessDescriptor& descriptor,
                                                      armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,3> supportedInputTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
                                  "Reference DetectionPostProcess: input 1 is not a supported type.");

    return supported;
}

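// Dilated depthwise convolution applies the same rules as depthwise convolution; the dilation
// parameters are carried in the same descriptor.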
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}

bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference division: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference division: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference division: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference division: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference division: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
                                       const TensorInfo& input1,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference equal: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference equal: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference equal: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference equal: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
                                                  const FakeQuantizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    bool supported = true;

    std::array<DataType,1> supportedTypes =
    {
        DataType::Float32
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference fake quantization: input type not supported.");

    return supported;
}

bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       const TensorInfo& output,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(output);
    bool supported = true;

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Floor: output type not supported.");

    return supported;
}

bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                const TensorInfo& weights,
                                                const TensorInfo& biases,
                                                const FullyConnectedDescriptor& descriptor,
                                                Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    // Define supported types.
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Fully Connected: input and output types mismatched.");

    supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                  "Reference Fully Connected: weights type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
                                  "Reference Fully Connected: input and weight types mismatched.");

    if (descriptor.m_BiasEnabled)
    {
        // Define supported types for bias.
        std::array<DataType, 3> supportedBiasTypes =
        {
            DataType::Float32,
            DataType::Float16,
            DataType::Signed32
        };

        supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type not supported.");

        supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
                                      "Reference Fully Connected: bias and weight types mismatch.");

        supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
                                      "Reference Fully Connected: bias type inferred from weights is incompatible.");

    }

    return supported;
}

bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
                                        const armnn::TensorInfo& input1,
                                        const armnn::TensorInfo& output,
                                        armnn::Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Gather: output type not supported");

    supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
                                  "Reference Gather: indices (input1) type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference Gather: input and output types not matching");

    return supported;
}

bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference greater: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference greater: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference greater: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference greater: shapes are not suitable for implicit broadcast.");

    return supported;
}

bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                       Optional<std::string&> reasonIfUnsupported) const
{
    return true;
}

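// Instance normalization is only supported for Float32 and Float16 tensors in the reference backend.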
bool RefLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
                                                       const TensorInfo& output,
                                                       const InstanceNormalizationDescriptor& descriptor,
                                                       Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference Instance Normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference Instance Normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
                                                 const TensorInfo& output,
                                                 const L2NormalizationDescriptor& descriptor,
                                                 Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    // Define supported types
    std::array<DataType, 4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    bool supported = true;

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: input type not supported.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference L2normalization: output type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output types mismatched.");

    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
                                  "Reference L2normalization: input and output shapes have different "
                                  "num total elements.");

    return supported;
}

bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
                                            const TensorInfo& output,
                                            const LogSoftmaxDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);

    std::array<DataType, 2> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16
    };

    bool supported = true;
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: input type not supported");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference LogSoftmax: output type not supported");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference LogSoftmax: input and output types do not match");

    return supported;
}

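// LSTM support depends on the descriptor flags: CIFG, peephole, projection and layer normalization
// each add weight tensors whose types must match the input.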
bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& outputStateIn,
                                      const TensorInfo& cellStateIn,
                                      const TensorInfo& scratchBuffer,
                                      const TensorInfo& outputStateOut,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
                                      const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    ignore_unused(descriptor);
    ignore_unused(paramsInfo);

    bool supported = true;

    std::array<DataType,2> supportedTypes = {
        DataType::Float32,
        DataType::QuantisedSymm16
    };

    // check inputs and outputs
    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Lstm: input is not a supported type.");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateIn types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                  "Reference Lstm: input and scratchBuffer types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and outputStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                  "Reference Lstm: input and cellStateOut types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Lstm: input and output types are mismatched");
    // check layer parameters
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and CellBias types are mismatched");
    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
                                  "Reference Lstm: input and OutputGateBias types are mismatched");
    if (!descriptor.m_CifgEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
                                      "Reference Lstm: input and InputGateBias types are mismatched");
        if (descriptor.m_PeepholeEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
        }
    }
    if (descriptor.m_PeepholeEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
    }
    if (descriptor.m_ProjectionEnabled)
    {
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
                                          "Reference Lstm: input and ProjectionBias types are mismatched");
        }
    }
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
                                          reasonIfUnsupported,
                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
        }
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
                                      reasonIfUnsupported,
                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
    }

    return supported;
}

bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;

    std::array<DataType,4> supportedTypes = {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 0 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: input 1 is not a supported type.");

    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
                                  "Reference maximum: output is not a supported type.");

    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
                                  "Reference maximum: input 0 and Input 1 types are mismatched");

    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
                                  "Reference maximum: input and output types are mismatched");

    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
                                  "Reference maximum: shapes are not suitable for implicit broadcast.");

    return supported;
}

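// Mean also validates that the output rank is consistent with the reduction axes and the m_KeepDims flag.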
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,4> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}

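// Merger is handled by the same rules as Concat; the check is simply forwarded.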
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001132bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001133 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001134 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001135 Optional<std::string&> reasonIfUnsupported) const
1136{
Jim Flynne242f2d2019-05-22 14:24:13 +01001137 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001138}
1139
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001140bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1141 const TensorInfo &output,
1142 Optional<std::string &> reasonIfUnsupported) const
1143{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001144 bool supported = true;
1145
1146 std::array<DataType,5> supportedTypes =
1147 {
1148 DataType::Float32,
1149 DataType::Float16,
1150 DataType::QuantisedAsymm8,
1151 DataType::QuantisedSymm16,
1152 DataType::Boolean
1153 };
1154
1155 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1156 "Reference MemCopy: input type not supported");
1157
1158 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1159 "Reference MemCopy: output type not supported");
1160
1161 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1162 "Reference MemCopy: input and output types are mismatched");
1163
1164 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001165}
1166
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001167bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1168 const TensorInfo& input1,
1169 const TensorInfo& output,
1170 Optional<std::string&> reasonIfUnsupported) const
1171{
Sadik Armagan2999a022019-04-09 14:20:12 +01001172 bool supported = true;
1173
Matthew Jackson9bff1442019-09-12 09:08:23 +01001174 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001175 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001176 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001177 DataType::QuantisedAsymm8,
1178 DataType::QuantisedSymm16
1179 };
1180
1181 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1182 "Reference minimum: input 0 is not a supported type.");
1183
1184 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1185 "Reference minimum: input 1 is not a supported type.");
1186
1187 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1188 "Reference minimum: output is not a supported type.");
1189
1190 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1191                                  "Reference minimum: input 0 and input 1 types are mismatched");
1192
1193 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1194 "Reference minimum: input and output types are mismatched");
1195
1196 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1197 "Reference minimum: shapes are not suitable for implicit broadcast.");
1198
1199 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001200}
1201
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001202bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1203 const TensorInfo& input1,
1204 const TensorInfo& output,
1205 Optional<std::string&> reasonIfUnsupported) const
1206{
Sadik Armagan2999a022019-04-09 14:20:12 +01001207 bool supported = true;
1208
Matthew Jackson252df3a2019-09-11 09:19:18 +01001209 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001210 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001211 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001212 DataType::QuantisedAsymm8,
1213 DataType::QuantisedSymm16
1214 };
1215
1216 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1217 "Reference multiplication: input 0 is not a supported type.");
1218
1219 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1220 "Reference multiplication: input 1 is not a supported type.");
1221
1222 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1223 "Reference multiplication: output is not a supported type.");
1224
1225 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1226                                  "Reference multiplication: input 0 and input 1 types are mismatched");
1227
1228 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1229 "Reference multiplication: input and output types are mismatched");
1230
1231 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1232 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1233
1234 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001235}
1236
1237bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1238 const TensorInfo& output,
1239 const NormalizationDescriptor& descriptor,
1240 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001241{
Nina Drozd661dfa72018-10-02 11:14:17 +01001242 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001243
1244 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001245 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001246 {
1247 DataType::Float16,
1248 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001249 DataType::QuantisedAsymm8,
1250 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001251 };
1252
1253 bool supported = true;
1254
1255 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1256 "Reference normalization: input type not supported.");
1257
1258 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1259 "Reference normalization: output type not supported.");
1260
1261 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1262 "Reference normalization: input and output shapes have different "
1263 "num total elements.");
1264
1265 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001266}
1267
1268bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1269 Optional<std::string&> reasonIfUnsupported) const
1270{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001271 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001272}
1273
1274bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1275 const TensorInfo& output,
1276 const PadDescriptor& descriptor,
1277 Optional<std::string&> reasonIfUnsupported) const
1278{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001279 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001280 bool supported = true;
1281
1282    // Define supported input and output types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001283 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001284 {
1285 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001286 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001287 DataType::QuantisedAsymm8,
1288 DataType::QuantisedSymm16
1289 };
1290
1291 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1292 "Reference pad: input is not a supported type.");
1293
1294 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1295 "Reference pad: output is not a supported type.");
1296
1297 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1298 "Reference pad: input and output types are mismatched.");
1299
1300 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001301}
1302
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001303bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1304 const TensorInfo& output,
1305 const PermuteDescriptor& descriptor,
1306 Optional<std::string&> reasonIfUnsupported) const
1307{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001308 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001309 bool supported = true;
1310
1311    // Define supported input and output types.
1312 std::array<DataType,3> supportedTypes =
1313 {
1314 DataType::Float32,
1315 DataType::QuantisedAsymm8,
1316 DataType::QuantisedSymm16
1317 };
1318
1319 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1320 "Reference permute: input is not a supported type.");
1321
1322 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1323 "Reference permute: output is not a supported type.");
1324
1325 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1326 "Reference permute: input and output types are mismatched.");
1327
1328 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001329}
1330
1331bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1332 const TensorInfo& output,
1333 const Pooling2dDescriptor& descriptor,
1334 Optional<std::string&> reasonIfUnsupported) const
1335{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001336 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001337 bool supported = true;
1338
1339    // Define supported input and output types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001340 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001341 {
1342 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001343 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001344 DataType::QuantisedAsymm8,
1345 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001346 };
1347
1348 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1349                                  "Reference pooling2d: input is not a supported type.");
1350
1351 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1352                                  "Reference pooling2d: output is not a supported type.");
1353
1354 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1355                                  "Reference pooling2d: input and output types are mismatched.");
1356
1357 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001358}
1359
Derek Lamberti5f400d62019-03-25 15:41:58 +00001360bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1361 const TensorInfo& output,
1362 Optional<std::string&> reasonIfUnsupported) const
1363{
1364 bool supported = true;
1365
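    // Quantize converts a float tensor into a quantized one, so the input and output use separate type lists.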
1366    // Define supported input types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001367 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001368 DataType::Float32,
1369 };
1370
1371 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1372 "Reference quantize: input type not supported.");
1373
1374 // Define supported output types.
1375 std::array<DataType,2> supportedOutputTypes = {
1376 DataType::QuantisedAsymm8,
1377 DataType::QuantisedSymm16
1378 };
1379 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1380 "Reference quantize: output type not supported.");
1381
1382 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1383 "Reference quantize: input and output shapes have different num total elements.");
1384
1385 return supported;
1386}
1387
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001388bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001389 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001390 Optional<std::string&> reasonIfUnsupported) const
1391{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001392 ignore_unused(descriptor);
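    // Reshape only reinterprets the tensor shape, so validating the input data type is sufficient here.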
Nina Drozd2f2778f2019-05-27 10:37:05 +01001393    // Define supported input types.
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001394    std::array<DataType,5> supportedInputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001395 {
1396 DataType::Float32,
1397 DataType::Float16,
Narumol Prangnawarat0718ee92019-09-13 16:53:38 +01001398 DataType::Signed32,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001399 DataType::QuantisedAsymm8,
1400 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001401 };
1402    return CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1403 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001404}
1405
1406bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001407 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001408 Optional<std::string&> reasonIfUnsupported) const
1409{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001410 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001411 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001412 {
1413 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001414 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001415 DataType::QuantisedAsymm8,
1416 DataType::QuantisedSymm16
1417 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001418
1419 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1420 "Reference ResizeBilinear: input type not supported");
1421
1422 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1423 "Reference ResizeBilinear: output type not supported");
1424
1425 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1426 "Reference ResizeBilinear: input and output types not matching");
1427
1428 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001429}
1430
Teresa Charlin970f43b2019-07-01 13:51:07 +01001431bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1432 const TensorInfo& output,
1433 const ResizeDescriptor& descriptor,
1434 Optional<std::string&> reasonIfUnsupported) const
1435{
1436 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001437 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001438 {
1439 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001440 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001441 DataType::QuantisedAsymm8,
1442 DataType::QuantisedSymm16
1443 };
1444
1445 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1446 "Reference Resize: input type not supported");
1447
1448 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1449 "Reference Resize: output type not supported");
1450
1451 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1452 "Reference Resize: input and output types not matching");
1453
1454 return supported;
1455}
1456
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001457bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1458 const TensorInfo& output,
1459 Optional<std::string&> reasonIfUnsupported) const
1460{
nikraj010421e7f2019-06-14 09:40:34 +01001461 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001462 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001463 {
1464 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001465 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001466 DataType::QuantisedAsymm8,
1467 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001468 };
1469
1470 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1471 "Reference rsqrt: input type not supported");
1472
1473 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1474 "Reference rsqrt: output type not supported");
1475
1476 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1477 "Reference rsqrt: input and output types not matching");
1478
1479 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1480                                  "Reference rsqrt: input and output shapes have different number of total elements");
1481
1482 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001483}
1484
Aron Virginas-Tar92b9f872019-09-17 17:27:04 +01001485bool RefLayerSupport::IsSliceSupported(const TensorInfo& input,
1486 const TensorInfo& output,
1487 const SliceDescriptor& descriptor,
1488 Optional<std::string&> reasonIfUnsupported) const
1489{
1490 ignore_unused(descriptor);
1491 bool supported = true;
1492
1493 std::array<DataType, 3> supportedTypes =
1494 {
1495 DataType::Float32,
1496 DataType::QuantisedAsymm8,
1497 DataType::QuantisedSymm16
1498 };
1499
1500 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1501 "Reference Slice: input type not supported");
1502
1503 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1504 "Reference Slice: output type not supported");
1505
1506 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1507 "Reference Slice: input and output types are mismatched");
1508
1509 return supported;
1510}
1511
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001512bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1513 const TensorInfo& output,
1514 const SoftmaxDescriptor& descriptor,
1515 Optional<std::string&> reasonIfUnsupported) const
1516{
1517    ignore_unused(descriptor);
nikraj01248683f2019-05-29 16:46:50 +01001518 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001519 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001520 {
1521 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001522 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001523 DataType::QuantisedAsymm8,
1524 DataType::QuantisedSymm16
1525 };
1526
1527 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001528                                  "Reference Softmax: input type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001529
1530 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001531                                  "Reference Softmax: output type not supported");
nikraj01248683f2019-05-29 16:46:50 +01001532
1533 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Aron Virginas-Tare662a942019-10-14 15:12:00 +01001534                                  "Reference Softmax: input and output types are mismatched");
nikraj01248683f2019-05-29 16:46:50 +01001535
1536 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001537}
1538
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001539bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1540 const TensorInfo& output,
1541 const SpaceToBatchNdDescriptor& descriptor,
1542 Optional<std::string&> reasonIfUnsupported) const
1543{
1544    ignore_unused(descriptor);
nikraj01120522a2019-05-31 11:33:07 +01001545 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001546 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001547 {
1548 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001549 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001550 DataType::QuantisedAsymm8,
1551 DataType::QuantisedSymm16
1552 };
1553
1554 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1555 "Reference SpaceToBatchNd: input type not supported");
1556
1557 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1558 "Reference SpaceToBatchNd: output type not supported");
1559
1560 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1561 "Reference SpaceToBatchNd: input and output types are mismatched");
1562
1563 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001564}
1565
Keith Davisa57eccb2019-06-14 17:33:22 +01001566bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001567 const TensorInfo& output,
1568 const SpaceToDepthDescriptor& descriptor,
1569 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001570{
1571
1572 ignore_unused(descriptor);
1573 bool supported = true;
1574
Matthew Jackson9bff1442019-09-12 09:08:23 +01001575 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001576 {
1577 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001578 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001579 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001580 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001581 };
1582
1583 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1584 "Reference SpaceToDepth: input type not supported");
1585
1586 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1587 "Reference SpaceToDepth: output type not supported");
1588
1589 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1590 "Reference SpaceToDepth: input and output types are mismatched");
1591
1592 return supported;
1593}
1594
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001595bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1596 const ViewsDescriptor& descriptor,
1597 Optional<std::string&> reasonIfUnsupported) const
1598{
1599 ignore_unused(descriptor);
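    // This overload only receives the input tensor, so only the input type can be validated here.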
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001600 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001601 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001602 {
1603 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001604 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001605 DataType::QuantisedAsymm8,
1606 DataType::QuantisedSymm16
1607 };
1608
1609 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1610 "Reference splitter: input type not supported");
1611
1612 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001613}
1614
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001615bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1616 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1617 const ViewsDescriptor& descriptor,
1618 Optional<std::string&> reasonIfUnsupported) const
1619{
1620 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001621 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001622 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001623 {
1624 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001625 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001626 DataType::QuantisedAsymm8,
1627 DataType::QuantisedSymm16
1628 };
1629
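    // Validate the input type, then check that every output view uses a supported type that matches the input.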
1630    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1631                                  "Reference splitter: input type not supported");
1632    for (const TensorInfo& output : outputs)
1633    {
1634        supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1635                                      "Reference splitter: output type not supported");
1636
1637 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1638 "Reference splitter: input and output types mismatched.");
1639 }
1640
1641 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001642}
1643
Matthew Jackson81e601c2019-07-11 12:07:09 +01001644bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1645 const TensorInfo& output,
1646 const StackDescriptor& descriptor,
1647 Optional<std::string&> reasonIfUnsupported) const
1648{
1649 ignore_unused(descriptor);
1650
1651 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001652 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001653 {
1654 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001655 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001656 DataType::QuantisedAsymm8,
1657 DataType::QuantisedSymm16
1658 };
1659
1660 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1661 "Reference stack: output type not supported");
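    // Every input to the stack must be a supported type and must match the output type.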
1662 for (const TensorInfo* input : inputs)
1663 {
1664 BOOST_ASSERT(input != nullptr);
1665 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1666 "Reference stack: input type not supported");
1667
1668 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1669 "Reference stack: input and output types mismatched.");
1670 }
1671
1672 return supported;
1673}
1674
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001675bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1676 const TensorInfo& output,
1677 const StridedSliceDescriptor& descriptor,
1678 Optional<std::string&> reasonIfUnsupported) const
1679{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001680 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001681 bool supported = true;
1682
1683 std::array<DataType,3> supportedTypes =
1684 {
1685 DataType::Float32,
1686 DataType::QuantisedAsymm8,
1687 DataType::QuantisedSymm16
1688 };
1689
1690 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1691 "Reference StridedSlice: input type not supported");
1692
1693 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1694 "Reference StridedSlice: output type not supported");
1695
1696 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1697 "Reference StridedSlice: input and output types are mismatched");
1698
1699 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001700}
1701
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001702bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1703 const TensorInfo& input1,
1704 const TensorInfo& output,
1705 Optional<std::string&> reasonIfUnsupported) const
1706{
Sadik Armagan2999a022019-04-09 14:20:12 +01001707 bool supported = true;
1708
Matthew Jackson9bff1442019-09-12 09:08:23 +01001709 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001710 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001711 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001712 DataType::QuantisedAsymm8,
1713 DataType::QuantisedSymm16
1714 };
1715
1716 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1717 "Reference subtraction: input 0 is not a supported type.");
1718
1719 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1720 "Reference subtraction: input 1 is not a supported type.");
1721
1722 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1723 "Reference subtraction: output is not a supported type.");
1724
1725 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1726                                  "Reference subtraction: input 0 and input 1 types are mismatched");
1727
1728 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1729 "Reference subtraction: input and output types are mismatched");
1730
1731 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1732 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1733
1734 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001735}
1736
Matteo Martincighab9e5252019-06-13 17:27:46 +01001737bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1738 const TensorInfo& alpha,
1739 const TensorInfo& output,
1740 Optional<std::string&> reasonIfUnsupported) const
1741{
1742 bool supported = true;
1743
Matthew Jackson9bff1442019-09-12 09:08:23 +01001744 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001745 {
1746 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001747 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001748 DataType::QuantisedAsymm8,
1749 DataType::QuantisedSymm16
1750 };
1751
1752 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1753 "PReLU: input is not a supported type.");
1754
1755 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1756 "PReLU: alpha is not a supported type.");
1757
1758 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1759 "PReLU: output is not a supported type.");
1760
1761 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1762 "PReLU: input, alpha and output types are mismatched");
1763
1764 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1765 "PReLU: shapes are not suitable for implicit broadcast");
1766
1767 return supported;
1768}
1769
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001770bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1771 const TensorInfo& output,
1772 const TransposeConvolution2dDescriptor& descriptor,
1773 const TensorInfo& weights,
1774 const Optional<TensorInfo>& biases,
1775 Optional<std::string&> reasonIfUnsupported) const
1776{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001777 bool supported = true;
1778
Matthew Jackson252df3a2019-09-11 09:19:18 +01001779 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001780 {
1781 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001782 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001783 DataType::QuantisedAsymm8,
1784 DataType::QuantisedSymm16
1785 };
1786
1787 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1788 "Reference TransposeConvolution2d: input is not a supported type.");
1789
1790 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1791 "Reference TransposeConvolution2d: output is not a supported type.");
1792
1793 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1794 "Reference TransposeConvolution2d: weights is not a supported type.");
1795
1796 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1797 "Reference TransposeConvolution2d: input and output types mismatched.");
1798
1799 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1800 "Reference TransposeConvolution2d: input and weights types mismatched.");
1801
1802 if (biases.has_value())
1803 {
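        // Float biases accompany float inputs; Signed32 is the bias type typically used with quantized inputs.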
Matthew Jackson252df3a2019-09-11 09:19:18 +01001804 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001805 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001806 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001807 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001808 DataType::Signed32
1809 };
1810 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1811 "Reference TransposeConvolution2d: biases is not a supported type.");
1812 }
1813
1814 return supported;
1815}
1816
arovir011c7c81b2018-10-08 11:34:28 +01001817} // namespace armnn