blob: 465d45cbae67a3df4373933291990e64add38aa8 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
telsoa014fcda012018-03-09 14:13:49 +00006#include "RefLayerSupport.hpp"
David Beck3e9e1152018-10-17 14:17:50 +01007#include "RefBackendId.hpp"
David Beck3cc9a622018-10-12 10:38:31 +01008
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01009#include <DataLayoutIndexed.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <InternalTypes.hpp>
11#include <LayerSupportCommon.hpp>
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +010012
telsoa014fcda012018-03-09 14:13:49 +000013#include <armnn/Types.hpp>
Derek Lamberti50db4e82019-03-13 14:16:15 +000014#include <armnn/Descriptors.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
David Beck111b5d92018-11-12 14:59:37 +000016#include <backendsCommon/BackendRegistry.hpp>
Derek Lambertif674aa02019-08-01 15:56:25 +010017#include <backendsCommon/LayerSupportRules.hpp>
Francis Murtagh46c09d02019-05-28 08:15:28 +010018#include <backendsCommon/test/WorkloadTestUtils.hpp>
David Beck3e9e1152018-10-17 14:17:50 +010019
telsoa014fcda012018-03-09 14:13:49 +000020#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
Derek Lamberti50db4e82019-03-13 14:16:15 +000022#include <vector>
23#include <algorithm>
24#include <array>
25
telsoa014fcda012018-03-09 14:13:49 +000026using namespace boost;
27
28namespace armnn
29{
30
namespace
{

// Dispatches a data-type-based support query to the shared generic helper.
// Only the Float32 and QAsymm8 slots are caller-provided here; the remaining
// slots are explicitly rejected with FalseFunc. Based on the convert-layer
// queries below, the generic slot order is:
//   Float16, Float32, QAsymm8, Signed32, Boolean.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>, // Float16: unsupported via this path
                                         floatFuncPtr,
                                         uint8FuncPtr,
                                         &FalseFunc<Params...>, // Signed32: unsupported via this path
                                         &FalseFunc<Params...>, // Boolean: unsupported via this path
                                         std::forward<Params>(params)...);
}

} // anonymous namespace
52
namespace
{

// Builds the standard error message reported when a tensor has the wrong
// number of dimensions for a reference layer.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name inserted into the message.
// @param tensorName Name of the offending tensor (e.g. "input", "output").
// @return The formatted error message.
//
// Fix: the string parameters are read-only, so they are now taken by const
// reference (the previous non-const references rejected temporaries and
// const strings for no benefit). The produced text is unchanged.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    return "Reference " + layerStr + ": Expected " + std::to_string(expected) +
           " dimensions but got " + std::to_string(actual) +
           " dimensions instead, for the '" + tensorName + "' tensor.";
}

} // anonymous namespace
Derek Lamberti50db4e82019-03-13 14:16:15 +000068
Sadik Armagan9199e582019-09-05 17:35:31 +010069bool RefLayerSupport::IsAbsSupported(const TensorInfo& input, const TensorInfo& output,
70 Optional<std::string&> reasonIfUnsupported) const
71{
72 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +010073 std::array<DataType,4> supportedTypes =
Sadik Armagan9199e582019-09-05 17:35:31 +010074 {
75 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +010076 DataType::Float16,
Sadik Armagan9199e582019-09-05 17:35:31 +010077 DataType::QuantisedAsymm8,
78 DataType::QuantisedSymm16
79 };
80
81 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
82 "Reference abs: input type not supported");
83
84 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
85 "Reference abs: output type not supported");
86
87 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
88 "Reference abs: input and output types not matching");
89
90 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
91 "Reference abs: input and output shapes have different number of total elements");
92
93 return supported;
94}
95
arovir011c7c81b2018-10-08 11:34:28 +010096bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
97 const TensorInfo& output,
98 const ActivationDescriptor& descriptor,
99 Optional<std::string&> reasonIfUnsupported) const
100{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000101 bool supported = true;
102
103 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100104 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000105 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100106 DataType::Float16,
Teresa Charlin18515e22019-04-24 10:17:46 +0100107 DataType::QuantisedAsymm8,
108 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000109 };
110
111 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
112 "Reference activation: input type not supported.");
113
114 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
115 "Reference activation: output type not supported.");
116
117 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
118 "Reference activation: input and output types mismatched.");
119
120 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
121 "Reference activation: input and output shapes are of different rank.");
122
123
124 struct ActivationFunctionSupported : public Rule
125 {
126 ActivationFunctionSupported(const ActivationDescriptor& desc)
127 {
128 switch(desc.m_Function)
129 {
130 case ActivationFunction::Abs:
131 case ActivationFunction::BoundedReLu:
132 case ActivationFunction::LeakyReLu:
133 case ActivationFunction::Linear:
134 case ActivationFunction::ReLu:
135 case ActivationFunction::Sigmoid:
136 case ActivationFunction::SoftReLu:
137 case ActivationFunction::Sqrt:
138 case ActivationFunction::Square:
139 case ActivationFunction::TanH:
140 {
141 m_Res = true;
142 break;
143 }
144 default:
145 {
146 m_Res = false;
147 break;
148 }
149 }
150 }
151 };
152
153 // Function is supported
154 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
155 "Reference activation: function not supported.");
156
157 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100158}
159
160bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
161 const TensorInfo& input1,
162 const TensorInfo& output,
163 Optional<std::string&> reasonIfUnsupported) const
164{
Derek Lamberti50db4e82019-03-13 14:16:15 +0000165 bool supported = true;
166
Matthew Jackson252df3a2019-09-11 09:19:18 +0100167 std::array<DataType,4> supportedTypes = {
Derek Lamberti50db4e82019-03-13 14:16:15 +0000168 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100169 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100170 DataType::QuantisedAsymm8,
171 DataType::QuantisedSymm16
Derek Lamberti50db4e82019-03-13 14:16:15 +0000172 };
173
174 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
175 "Reference addition: input 0 is not a supported type.");
176
177 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
178 "Reference addition: input 1 is not a supported type.");
179
180 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
181 "Reference addition: output is not a supported type.");
182
183 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
184 "Reference addition: input 0 and Input 1 types are mismatched");
185
186 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
187 "Reference addition: input and output types are mismatched");
188
189 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
190 "Reference addition: shapes are not suitable for implicit broadcast.");
191
192 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100193}
194
195bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
196 const TensorInfo& output,
197 const TensorInfo& mean,
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100198 const TensorInfo& variance,
arovir011c7c81b2018-10-08 11:34:28 +0100199 const TensorInfo& beta,
200 const TensorInfo& gamma,
201 const BatchNormalizationDescriptor& descriptor,
202 Optional<std::string&> reasonIfUnsupported) const
203{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100204 ignore_unused(descriptor);
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100205
Matthew Jackson9bff1442019-09-12 09:08:23 +0100206 std::array<DataType, 4> supportedTypes =
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100207 {
208 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100209 DataType::Float16,
Matteo Martincighf5507132019-06-04 10:59:47 +0100210 DataType::QuantisedAsymm8,
211 DataType::QuantisedSymm16
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100212 };
213
214 bool supported = true;
215
216 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
217 "Reference batch normalization: input is not a supported type.");
218
219 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
220 "Reference batch normalization: output is not a supported type.");
221
222 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
223 "Reference batch normalization: input and output types are mismatched");
224
225 supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
226 "Reference batch normalization: mean is not a supported type.");
227
228 supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
229 "Reference batch normalization: variance is not a supported type.");
230
231 supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
232 "Reference batch normalization: beta is not a supported type.");
233
234 supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
235 "Reference batch normalization: gamma is not a supported type.");
236
237 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100238}
239
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000240bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
241 const TensorInfo& output,
242 const BatchToSpaceNdDescriptor& descriptor,
243 Optional<std::string&> reasonIfUnsupported) const
244{
245 ignore_unused(descriptor);
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100246
247 bool supported = true;
248
249 std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
250 std::string inputTensorStr = "input";
251 std::string outputTensorStr = "output";
252
253 // Define supported types.
Matthew Jackson9bff1442019-09-12 09:08:23 +0100254 std::array<DataType,4> supportedTypes =
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100255 {
256 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100257 DataType::Float16,
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100258 DataType::QuantisedAsymm8,
259 DataType::QuantisedSymm16
260 };
261
262 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
263 "Reference BatchToSpaceNd: input type not supported.");
264
265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266 "Reference BatchToSpaceNd: output type not supported.");
267
268 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
269 "Reference BatchToSpaceNd: input and output types mismatched.");
270
271 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
272 reasonIfUnsupported,
273 CreateIncorrectDimensionsErrorMsg(4,
274 output.GetNumDimensions(),
275 batchToSpaceNdLayerStr,
276 outputTensorStr).data());
277
278 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
279 reasonIfUnsupported,
280 CreateIncorrectDimensionsErrorMsg(4,
281 input.GetNumDimensions(),
282 batchToSpaceNdLayerStr,
283 inputTensorStr).data());
284
285 return supported;
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +0000286}
287
Jim Flynn906f9462019-05-10 13:55:21 +0100288bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
289 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +0100290 const ConcatDescriptor& descriptor,
Jim Flynn906f9462019-05-10 13:55:21 +0100291 Optional<std::string&> reasonIfUnsupported) const
292{
Jim Flynne242f2d2019-05-22 14:24:13 +0100293 ignore_unused(descriptor);
294
295 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100296 std::array<DataType,4> supportedTypes =
Jim Flynne242f2d2019-05-22 14:24:13 +0100297 {
298 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100299 DataType::Float16,
Jim Flynne242f2d2019-05-22 14:24:13 +0100300 DataType::QuantisedAsymm8,
301 DataType::QuantisedSymm16
302 };
303
304 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
305 "Reference concatenation: output type not supported");
306 for (const TensorInfo* input : inputs)
307 {
Matthew Jackson81e601c2019-07-11 12:07:09 +0100308 BOOST_ASSERT(input != nullptr);
Jim Flynne242f2d2019-05-22 14:24:13 +0100309 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
310 "Reference concatenation: input type not supported");
311
312 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
313 "Reference concatenation: input and output types mismatched.");
314 }
315
316 return supported;
Jim Flynn906f9462019-05-10 13:55:21 +0100317}
318
arovir011c7c81b2018-10-08 11:34:28 +0100319bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
320 Optional<std::string&> reasonIfUnsupported) const
321{
Jim Flynne242f2d2019-05-22 14:24:13 +0100322 std::array<DataType,4> supportedTypes =
323 {
Nina Drozd58ef2c62019-05-16 12:09:18 +0100324 DataType::Float32,
325 DataType::Signed32,
326 DataType::QuantisedAsymm8,
327 DataType::QuantisedSymm16
328 };
329
330 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
331 "Reference constant: output is not a supported type.");
arovir011c7c81b2018-10-08 11:34:28 +0100332}
333
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only for a Float16 input converted to a Float32 output.
    // Each call checks one tensor; the function-pointer slots are ordered:
    // Float16, Float32, QAsymm8, Signed32, Boolean.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,           // Float16 input: OK
                                          &FalseInputFuncF32<>,  // Float32 input: rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output: rejected with reason
                                          &TrueFunc<>,           // Float32 output: OK
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
353
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    // Supported only for a Float32 input converted to a Float16 output.
    // Each call checks one tensor; the function-pointer slots are ordered:
    // Float16, Float32, QAsymm8, Signed32, Boolean.
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,  // Float16 input: rejected with reason
                                          &TrueFunc<>,           // Float32 input: OK
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output: OK
                                          &FalseOutputFuncF32<>, // Float32 output: rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
373
374bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
375 const TensorInfo& output,
376 const Convolution2dDescriptor& descriptor,
377 const TensorInfo& weights,
378 const Optional<TensorInfo>& biases,
379 Optional<std::string&> reasonIfUnsupported) const
380{
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100381 bool supported = true;
382
383 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100384 std::array<DataType,4> supportedTypes = {
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100385 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100386 DataType::Float16,
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100387 DataType::QuantisedAsymm8,
388 DataType::QuantisedSymm16
389 };
390
391 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100392 "Reference convolution2d: input is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100393
394 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100395 "Reference convolution2d: output is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100396
397 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100398 "Reference convolution2d: weights is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100399
400 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100401 "Reference convolution2d: input and output types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100402
403 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100404 "Reference convolution2d: input and weights types mismatched.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100405
406 if (biases.has_value())
407 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100408 std::array<DataType,3> biasesSupportedTypes = {
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100409 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100410 DataType::Float16,
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100411 DataType::Signed32
412 };
413 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
Les Belld7f29082019-05-30 09:08:51 +0100414 "Reference convolution2d: biases is not a supported type.");
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100415 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100416 ignore_unused(descriptor);
Mike Kelly2f80f6e2019-05-16 12:41:34 +0100417
418 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100419}
420
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000421bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
422 const TensorInfo& output,
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000423 Optional<std::string&> reasonIfUnsupported) const
424{
Narumol Prangnawarat47cfee92019-07-04 10:29:00 +0100425 bool supported = true;
426
427 std::array<DataType,3> supportedTypes =
428 {
429 DataType::Float32,
430 DataType::QuantisedAsymm8,
431 DataType::QuantisedSymm16
432 };
433
434 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
435 "Reference debug: input type not supported");
436
437 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
438 "Reference debug: output type not supported");
439
440 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
441 "Reference debug: input and output types are mismatched");
442
443 return supported;
Nattapat Chaimanowongcfdcadf2018-12-06 11:54:33 +0000444}
445
arovir011c7c81b2018-10-08 11:34:28 +0100446bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
447 const TensorInfo& output,
448 const DepthwiseConvolution2dDescriptor& descriptor,
449 const TensorInfo& weights,
450 const Optional<TensorInfo>& biases,
451 Optional<std::string&> reasonIfUnsupported) const
452{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100453 bool supported = true;
454
455 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100456 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100457 {
458 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100459 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100460 DataType::QuantisedAsymm8,
461 DataType::QuantisedSymm16
462 };
463
464 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
465 "Reference DepthwiseConvolution2d: input is not a supported type.");
466
467 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
468 "Reference DepthwiseConvolution2d: output is not a supported type.");
469
470 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
471 "Reference DepthwiseConvolution2d: weights is not a supported type.");
472
473 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
474 "Reference DepthwiseConvolution2d: input and output types mismatched.");
475
476 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
477 "Reference DepthwiseConvolution2d: input and weights types mismatched.");
478
479 if (biases.has_value())
480 {
Matthew Jackson252df3a2019-09-11 09:19:18 +0100481 std::array<DataType,3> biasesSupportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100482 {
483 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100484 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100485 DataType::Signed32
486 };
487 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
488 "Reference DepthwiseConvolution2d: biases is not a supported type.");
489 }
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100490 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100491
492 return supported;
493
arovir011c7c81b2018-10-08 11:34:28 +0100494}
495
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000496bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
497 const TensorInfo& output,
498 Optional<std::string&> reasonIfUnsupported) const
499{
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100500 bool supported = true;
501
502 std::array<DataType,2> supportedInputTypes = {
503 DataType::QuantisedAsymm8,
504 DataType::QuantisedSymm16
505 };
506
507 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
508 "Reference dequantize: input type not supported.");
509
Mike Kelly4992c342019-08-14 11:33:11 +0100510 std::array<DataType,1> supportedOutputTypes = {
511 DataType::Float32
Nattapat Chaimanowongafa4e3a2019-04-02 11:41:45 +0100512 };
513
514 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
515 "Reference dequantize: output type not supported.");
516
517 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
518 "Reference dequantize: input and output shapes have different num total elements.");
519
520 return supported;
Nattapat Chaimanowong8a54ac02019-03-29 15:25:04 +0000521}
522
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000523bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
524 const armnn::TensorInfo& input1,
525 const armnn::DetectionPostProcessDescriptor& descriptor,
526 armnn::Optional<std::string&> reasonIfUnsupported) const
527{
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100528 bool supported = true;
529
Mike Kelly4992c342019-08-14 11:33:11 +0100530 std::array<DataType,3> supportedInputTypes =
Aron Virginas-Tara37e1bd2019-06-06 16:08:30 +0100531 {
532 DataType::Float32,
533 DataType::QuantisedAsymm8,
534 DataType::QuantisedSymm16
535 };
536
537 supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
538 "Reference DetectionPostProcess: input 0 is not a supported type.");
539
540 supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
541 "Reference DetectionPostProcess: input 1 is not a supported type.");
542
543 return supported;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000544}
545
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    // Dilation imposes no extra constraints for the reference backend: support
    // is identical to the regular depthwise convolution query.
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
555
Aron Virginas-Taraece4ed2019-06-14 17:00:09 +0100556bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
arovir011c7c81b2018-10-08 11:34:28 +0100557 const TensorInfo& input1,
558 const TensorInfo& output,
559 Optional<std::string&> reasonIfUnsupported) const
560{
Sadik Armagan2999a022019-04-09 14:20:12 +0100561 bool supported = true;
562
Matthew Jackson9bff1442019-09-12 09:08:23 +0100563 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100564 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100565 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100566 DataType::QuantisedAsymm8,
567 DataType::QuantisedSymm16
568 };
569
570 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
571 "Reference division: input 0 is not a supported type.");
572
573 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
574 "Reference division: input 1 is not a supported type.");
575
576 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
577 "Reference division: output is not a supported type.");
578
579 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
580 "Reference division: input 0 and Input 1 types are mismatched");
581
582 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
583 "Reference division: input and output types are mismatched");
584
585 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
586 "Reference division: shapes are not suitable for implicit broadcast.");
587
588 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100589}
590
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000591bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
592 const TensorInfo& input1,
593 const TensorInfo& output,
594 Optional<std::string&> reasonIfUnsupported) const
595{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100596 bool supported = true;
597
Matthew Jackson9bff1442019-09-12 09:08:23 +0100598 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100599 {
600 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100601 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100602 DataType::QuantisedAsymm8,
603 DataType::QuantisedSymm16
604 };
605
606 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
607 "Reference equal: input 0 is not a supported type.");
608
609 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
610 "Reference equal: input 1 is not a supported type.");
611
612 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
613 "Reference equal: input 0 and Input 1 types are mismatched");
614
615 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
616 "Reference equal: shapes are not suitable for implicit broadcast.");
617
618 return supported;
FrancisMurtagh30cdfca2018-12-18 12:57:35 +0000619}
620
arovir011c7c81b2018-10-08 11:34:28 +0100621bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
622 const FakeQuantizationDescriptor& descriptor,
623 Optional<std::string&> reasonIfUnsupported) const
624{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100625 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100626 bool supported = true;
627
628 std::array<DataType,1> supportedTypes =
629 {
630 DataType::Float32
631 };
632
633 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
634 "Reference fake quantization: input type not supported.");
635
636 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100637}
638
639bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
640 const TensorInfo& output,
641 Optional<std::string&> reasonIfUnsupported) const
642{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100643 ignore_unused(output);
James Conroy83735b12019-05-30 16:36:59 +0100644 bool supported = true;
645
Matthew Jackson9bff1442019-09-12 09:08:23 +0100646 std::array<DataType,3> supportedTypes =
James Conroy83735b12019-05-30 16:36:59 +0100647 {
James Conroyb40d7102019-06-04 12:32:09 +0100648 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100649 DataType::Float16,
James Conroyb40d7102019-06-04 12:32:09 +0100650 DataType::QuantisedSymm16
James Conroy83735b12019-05-30 16:36:59 +0100651 };
652
653 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
654 "Reference Floor: input type not supported.");
655
656 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
657 "Reference Floor: output type not supported.");
658
659 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100660}
661
662bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
663 const TensorInfo& output,
664 const TensorInfo& weights,
665 const TensorInfo& biases,
666 const FullyConnectedDescriptor& descriptor,
667 Optional<std::string&> reasonIfUnsupported) const
668{
Francis Murtagh46c09d02019-05-28 08:15:28 +0100669 bool supported = true;
670
671 // Define supported types.
Matthew Jackson252df3a2019-09-11 09:19:18 +0100672 std::array<DataType,4> supportedTypes =
Francis Murtagh46c09d02019-05-28 08:15:28 +0100673 {
674 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100675 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100676 DataType::QuantisedAsymm8,
677 DataType::QuantisedSymm16
678 };
679
680 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
681 "Reference Fully Connected: input type not supported.");
682
683 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
684 "Reference Fully Connected: output type not supported.");
685
686 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
687 "Reference Fully Connected: input and output types mismatched.");
688
689 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
690 "Reference Fully Connected: weights type not supported.");
691
692 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
693 "Reference Fully Connected: input and weight types mismatched.");
694
695 if (descriptor.m_BiasEnabled)
696 {
697 // Defined supported types for bias
Matthew Jackson252df3a2019-09-11 09:19:18 +0100698 std::array<DataType, 3>
Francis Murtagh46c09d02019-05-28 08:15:28 +0100699 supportedBiasTypes =
700 {
701 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100702 DataType::Float16,
Francis Murtagh46c09d02019-05-28 08:15:28 +0100703 DataType::Signed32
704 };
705
706 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
707 "Reference Fully Connected: bias type not supported.");
708
709 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
710 "Reference Fully Connected: bias and weight types mismatch.");
711
712 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
713 "Reference Fully Connected: bias type inferred from weights is incompatible.");
714
715 }
716
717 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100718}
719
narpra014951d842019-01-18 16:53:53 +0000720bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
721 const armnn::TensorInfo& input1,
722 const armnn::TensorInfo& output,
723 armnn::Optional<std::string&> reasonIfUnsupported) const
724{
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100725 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +0100726 std::array<DataType,4> supportedTypes =
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100727 {
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100728 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100729 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100730 DataType::QuantisedAsymm8,
731 DataType::QuantisedSymm16
Ellen Norris-Thompsone0dbedf2019-06-24 09:23:38 +0100732 };
733
734 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
735 "Reference Gather: input type not supported");
736
737 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
738 "Reference Gather: output type not supported");
739
740 supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
741 "Reference Gather: indices (input1) type not supported");
742
743 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
744 "Reference Gather: input and output types not matching");
745
746 return supported;
narpra014951d842019-01-18 16:53:53 +0000747}
748
FrancisMurtagh878f0232018-12-19 10:56:15 +0000749bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
750 const TensorInfo& input1,
751 const TensorInfo& output,
752 Optional<std::string&> reasonIfUnsupported) const
753{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100754 bool supported = true;
755
Matthew Jackson9bff1442019-09-12 09:08:23 +0100756 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100757 {
758 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100759 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +0100760 DataType::QuantisedAsymm8,
761 DataType::QuantisedSymm16
762 };
763
764 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
765 "Reference greater: input 0 is not a supported type.");
766
767 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
768 "Reference greater: input 1 is not a supported type.");
769
770 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
771 "Reference greater: input 0 and Input 1 types are mismatched");
772
773 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
774 "Reference greater: shapes are not suitable for implicit broadcast.");
775
776 return supported;
FrancisMurtagh878f0232018-12-19 10:56:15 +0000777}
778
arovir011c7c81b2018-10-08 11:34:28 +0100779bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
780 Optional<std::string&> reasonIfUnsupported) const
781{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +0100782 return true;
arovir011c7c81b2018-10-08 11:34:28 +0100783}
784
785bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
786 const TensorInfo& output,
787 const L2NormalizationDescriptor& descriptor,
788 Optional<std::string&> reasonIfUnsupported) const
789{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100790 ignore_unused(descriptor);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100791 // Define supported types
Matthew Jackson252df3a2019-09-11 09:19:18 +0100792 std::array<DataType, 4> supportedTypes =
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100793 {
794 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100795 DataType::Float16,
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100796 DataType::QuantisedAsymm8,
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100797 DataType::QuantisedSymm16
798 };
799
800 bool supported = true;
801
802 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
803 "Reference L2normalization: input type not supported.");
804
805 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
806 "Reference L2normalization: output type not supported.");
807
808 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
809 "Reference L2normalization: input and output types mismatched.");
810
811 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
812 "Reference L2normalization: input and output shapes have different "
813 "num total elements.");
814
815 return supported;
arovir011c7c81b2018-10-08 11:34:28 +0100816}
817
818bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
819 const TensorInfo& outputStateIn,
820 const TensorInfo& cellStateIn,
821 const TensorInfo& scratchBuffer,
822 const TensorInfo& outputStateOut,
823 const TensorInfo& cellStateOut,
824 const TensorInfo& output,
825 const LstmDescriptor& descriptor,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100826 const LstmInputParamsInfo& paramsInfo,
827 Optional<std::string&> reasonIfUnsupported) const
arovir011c7c81b2018-10-08 11:34:28 +0100828{
telsoa01c577f2c2018-08-31 09:22:23 +0100829 ignore_unused(descriptor);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100830 ignore_unused(paramsInfo);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100831
832 bool supported = true;
833
834 std::array<DataType,2> supportedTypes = {
Conor Kennedyb9971c92019-05-07 07:14:23 +0100835 DataType::Float32,
836 DataType::QuantisedSymm16
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100837 };
838
Jan Eilersd01a83c2019-07-03 18:20:40 +0100839 // check inputs and outputs
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100840 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
841 "Reference Lstm: input is not a supported type.");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100842 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
843 "Reference Lstm: input and outputStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100844 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
845 "Reference Lstm: input and cellStateIn types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100846 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
847 "Reference Lstm: input and scratchBuffer types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100848 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
849 "Reference Lstm: input and outputStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100850 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
851 "Reference Lstm: input and cellStateOut types are mismatched");
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100852 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
853 "Reference Lstm: input and output types are mismatched");
Jan Eilersd01a83c2019-07-03 18:20:40 +0100854 // check layer parameters
Francis Murtaghbb590b42019-08-14 09:51:36 +0100855 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100856 "Reference Lstm: input and InputToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100857 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100858 "Reference Lstm: input and InputToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100859 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100860 "Reference Lstm: input and InputToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100861 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100862 "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100863 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToCellWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100864 "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100865 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100866 "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100867 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100868 "Reference Lstm: input and ForgetGateBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100869 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100870 "Reference Lstm: input and CellBias types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100871 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100872 "Reference Lstm: input and OutputGateBias types are mismatched");
873 if (!descriptor.m_CifgEnabled)
874 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100875 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputToInputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100876 "Reference Lstm: input and InputToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100877 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetRecurrentToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +0100878 reasonIfUnsupported,
879 "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100880 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100881 "Reference Lstm: input and InputGateBias types are mismatched");
882 if (descriptor.m_PeepholeEnabled)
883 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100884 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToInputWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +0100885 reasonIfUnsupported,
886 "Reference Lstm: input and CellToInputWeights types are mismatched");
887 }
888 }
889 if (descriptor.m_PeepholeEnabled)
890 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100891 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToForgetWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100892 "Reference Lstm: input and CellToForgetWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100893 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellToOutputWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100894 "Reference Lstm: input and CellToOutputWeights types are mismatched");
895 }
896 if (descriptor.m_ProjectionEnabled)
897 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100898 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionWeights()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100899 "Reference Lstm: input and mProjectionWeights types are mismatched");
900 if (paramsInfo.m_ProjectionBias != nullptr)
901 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100902 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
Jan Eilersd01a83c2019-07-03 18:20:40 +0100903 "Reference Lstm: input and ProjectionBias types are mismatched");
904 }
905 }
906 if (descriptor.m_LayerNormEnabled)
907 {
908 if (!descriptor.m_CifgEnabled)
909 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100910 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +0100911 reasonIfUnsupported,
912 "Reference Lstm: input and InputLayerNormWeights types are mismatched");
913 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100914 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +0100915 reasonIfUnsupported,
916 "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100917 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +0100918 reasonIfUnsupported,
919 "Reference Lstm: input and CellLayerNormWeights types are mismatched");
Francis Murtaghbb590b42019-08-14 09:51:36 +0100920 supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputLayerNormWeights()),
Jan Eilersd01a83c2019-07-03 18:20:40 +0100921 reasonIfUnsupported,
922 "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
923 }
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100924
925 return supported;
telsoa01c577f2c2018-08-31 09:22:23 +0100926}
927
saoste012df12b32018-11-28 16:57:20 +0000928bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
929 const TensorInfo& input1,
930 const TensorInfo& output,
931 Optional<std::string&> reasonIfUnsupported) const
932{
Sadik Armagan2999a022019-04-09 14:20:12 +0100933 bool supported = true;
934
Matthew Jackson9bff1442019-09-12 09:08:23 +0100935 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +0100936 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100937 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +0100938 DataType::QuantisedAsymm8,
939 DataType::QuantisedSymm16
940 };
941
942 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
943 "Reference maximum: input 0 is not a supported type.");
944
945 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
946 "Reference maximum: input 1 is not a supported type.");
947
948 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
949 "Reference maximum: output is not a supported type.");
950
951 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
952 "Reference maximum: input 0 and Input 1 types are mismatched");
953
954 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
955 "Reference maximum: input and output types are mismatched");
956
957 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
958 "Reference maximum: shapes are not suitable for implicit broadcast.");
959
960 return supported;
saoste012df12b32018-11-28 16:57:20 +0000961}
962
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +0100963bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
964 const TensorInfo& output,
965 const MeanDescriptor& descriptor,
966 Optional<std::string&> reasonIfUnsupported) const
narpra0132b90462018-09-13 11:07:48 +0100967{
James Conroy4d1ff582019-06-10 17:06:39 +0100968 bool supported = true;
969 std::string meanLayerStr = "Mean";
970 std::string outputTensorStr = "output";
971
Matthew Jackson252df3a2019-09-11 09:19:18 +0100972 std::array<DataType,4> supportedTypes =
James Conroy4d1ff582019-06-10 17:06:39 +0100973 {
974 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +0100975 DataType::Float16,
James Conroyb80775f2019-06-11 11:25:30 +0100976 DataType::QuantisedAsymm8,
977 DataType::QuantisedSymm16
James Conroy4d1ff582019-06-10 17:06:39 +0100978 };
979
980 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
981 "Reference Mean: input type not supported.");
982
983 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
984 "Reference Mean: input and output types are mismatched");
985
986 if (descriptor.m_KeepDims)
987 {
988 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
989 reasonIfUnsupported,
990 CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
991 output.GetNumDimensions(),
992 meanLayerStr, outputTensorStr).data());
993 }
994 else if (descriptor.m_Axis.empty())
995 {
996 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
997 reasonIfUnsupported,
998 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
999 meanLayerStr, outputTensorStr).data());
1000 }
1001 else
1002 {
1003 auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());
1004
1005 if (outputDim > 0)
1006 {
1007 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
1008 reasonIfUnsupported,
1009 CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
1010 meanLayerStr, outputTensorStr).data());
1011 }
1012 else
1013 {
1014 supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
1015 reasonIfUnsupported,
1016 CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
1017 meanLayerStr, outputTensorStr).data());
1018 }
1019 }
1020
1021 return supported;
narpra0132b90462018-09-13 11:07:48 +01001022}
1023
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001024bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
Nikhil Raj8599a412018-11-19 14:51:07 +00001025 const TensorInfo& output,
Jim Flynne242f2d2019-05-22 14:24:13 +01001026 const MergerDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001027 Optional<std::string&> reasonIfUnsupported) const
1028{
Jim Flynne242f2d2019-05-22 14:24:13 +01001029 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001030}
1031
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001032bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1033 const TensorInfo &output,
1034 Optional<std::string &> reasonIfUnsupported) const
1035{
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001036 bool supported = true;
1037
1038 std::array<DataType,5> supportedTypes =
1039 {
1040 DataType::Float32,
1041 DataType::Float16,
1042 DataType::QuantisedAsymm8,
1043 DataType::QuantisedSymm16,
1044 DataType::Boolean
1045 };
1046
1047 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1048 "Reference MemCopy: input type not supported");
1049
1050 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1051 "Reference MemCopy: output type not supported");
1052
1053 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1054 "Reference MemCopy: input and output types are mismatched");
1055
1056 return supported;
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001057}
1058
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001059bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1060 const TensorInfo& input1,
1061 const TensorInfo& output,
1062 Optional<std::string&> reasonIfUnsupported) const
1063{
Sadik Armagan2999a022019-04-09 14:20:12 +01001064 bool supported = true;
1065
Matthew Jackson9bff1442019-09-12 09:08:23 +01001066 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001067 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001068 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001069 DataType::QuantisedAsymm8,
1070 DataType::QuantisedSymm16
1071 };
1072
1073 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1074 "Reference minimum: input 0 is not a supported type.");
1075
1076 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1077 "Reference minimum: input 1 is not a supported type.");
1078
1079 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1080 "Reference minimum: output is not a supported type.");
1081
1082 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1083 "Reference minimum: input 0 and Input 1 types are mismatched");
1084
1085 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1086 "Reference minimum: input and output types are mismatched");
1087
1088 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1089 "Reference minimum: shapes are not suitable for implicit broadcast.");
1090
1091 return supported;
Éanna Ó Catháin20e58802018-12-04 10:29:06 +00001092}
1093
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001094bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1095 const TensorInfo& input1,
1096 const TensorInfo& output,
1097 Optional<std::string&> reasonIfUnsupported) const
1098{
Sadik Armagan2999a022019-04-09 14:20:12 +01001099 bool supported = true;
1100
Matthew Jackson252df3a2019-09-11 09:19:18 +01001101 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001102 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001103 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001104 DataType::QuantisedAsymm8,
1105 DataType::QuantisedSymm16
1106 };
1107
1108 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1109 "Reference multiplication: input 0 is not a supported type.");
1110
1111 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1112 "Reference multiplication: input 1 is not a supported type.");
1113
1114 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1115 "Reference multiplication: output is not a supported type.");
1116
1117 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1118 "Reference multiplication: input 0 and Input 1 types are mismatched");
1119
1120 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1121 "Reference multiplication: input and output types are mismatched");
1122
1123 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1124 "Reference multiplication: shapes are not suitable for implicit broadcast.");
1125
1126 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001127}
1128
1129bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1130 const TensorInfo& output,
1131 const NormalizationDescriptor& descriptor,
1132 Optional<std::string&> reasonIfUnsupported) const
Nina Drozd661dfa72018-10-02 11:14:17 +01001133{
Nina Drozd661dfa72018-10-02 11:14:17 +01001134 ignore_unused(descriptor);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001135
1136 // Define supported types
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001137 std::array<DataType, 4> supportedTypes =
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001138 {
1139 DataType::Float16,
1140 DataType::Float32,
Matteo Martincigh6aeb7712019-06-05 17:23:29 +01001141 DataType::QuantisedAsymm8,
1142 DataType::QuantisedSymm16
Matteo Martincigh2fc70c52019-06-05 14:12:48 +01001143 };
1144
1145 bool supported = true;
1146
1147 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1148 "Reference normalization: input type not supported.");
1149
1150 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1151 "Reference normalization: output type not supported.");
1152
1153 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1154 "Reference normalization: input and output shapes have different "
1155 "num total elements.");
1156
1157 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001158}
1159
1160bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1161 Optional<std::string&> reasonIfUnsupported) const
1162{
Narumol Prangnawaratb6441e42019-06-04 11:22:00 +01001163 return true;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001164}
1165
1166bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1167 const TensorInfo& output,
1168 const PadDescriptor& descriptor,
1169 Optional<std::string&> reasonIfUnsupported) const
1170{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001171 ignore_unused(descriptor);
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001172 bool supported = true;
1173
1174 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001175 std::array<DataType,4> supportedTypes =
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001176 {
1177 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001178 DataType::Float16,
Narumol Prangnawarate6eaf662019-07-08 08:57:17 +01001179 DataType::QuantisedAsymm8,
1180 DataType::QuantisedSymm16
1181 };
1182
1183 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1184 "Reference pad: input is not a supported type.");
1185
1186 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1187 "Reference pad: output is not a supported type.");
1188
1189 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1190 "Reference pad: input and output types are mismatched.");
1191
1192 return supported;
Nina Drozd661dfa72018-10-02 11:14:17 +01001193}
1194
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001195bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1196 const TensorInfo& output,
1197 const PermuteDescriptor& descriptor,
1198 Optional<std::string&> reasonIfUnsupported) const
1199{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001200 ignore_unused(descriptor);
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +01001201 bool supported = true;
1202
1203 // Define supported output and inputs types.
1204 std::array<DataType,3> supportedTypes =
1205 {
1206 DataType::Float32,
1207 DataType::QuantisedAsymm8,
1208 DataType::QuantisedSymm16
1209 };
1210
1211 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1212 "Reference permute: input is not a supported type.");
1213
1214 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1215 "Reference permute: output is not a supported type.");
1216
1217 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1218 "Reference permute: input and output types are mismatched.");
1219
1220 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001221}
1222
1223bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1224 const TensorInfo& output,
1225 const Pooling2dDescriptor& descriptor,
1226 Optional<std::string&> reasonIfUnsupported) const
1227{
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001228 ignore_unused(descriptor);
Teresa Charlina3b20472019-06-06 11:12:32 +01001229 bool supported = true;
1230
1231 // Define supported output and inputs types.
Matthew Jackson252df3a2019-09-11 09:19:18 +01001232 std::array<DataType,4> supportedTypes =
Teresa Charlina3b20472019-06-06 11:12:32 +01001233 {
1234 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001235 DataType::Float16,
Teresa Charlin0434df62019-06-06 13:40:35 +01001236 DataType::QuantisedAsymm8,
1237 DataType::QuantisedSymm16
Teresa Charlina3b20472019-06-06 11:12:32 +01001238 };
1239
1240 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1241 "Reference poolind2d: input is not a supported type.");
1242
1243 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1244 "Reference poolind2d: output is not a supported type.");
1245
1246 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1247 "Reference poolind2d: input and output types are mismatched.");
1248
1249 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001250}
1251
Derek Lamberti5f400d62019-03-25 15:41:58 +00001252bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1253 const TensorInfo& output,
1254 Optional<std::string&> reasonIfUnsupported) const
1255{
1256 bool supported = true;
1257
1258 // Define supported output types.
Sadik Armagan1a816302019-07-29 17:16:40 +01001259 std::array<DataType,1> supportedInputTypes = {
Derek Lamberti5f400d62019-03-25 15:41:58 +00001260 DataType::Float32,
1261 };
1262
1263 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1264 "Reference quantize: input type not supported.");
1265
1266 // Define supported output types.
1267 std::array<DataType,2> supportedOutputTypes = {
1268 DataType::QuantisedAsymm8,
1269 DataType::QuantisedSymm16
1270 };
1271 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1272 "Reference quantize: output type not supported.");
1273
1274 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1275 "Reference quantize: input and output shapes have different num total elements.");
1276
1277 return supported;
1278}
1279
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001280bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001281 const ReshapeDescriptor& descriptor,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001282 Optional<std::string&> reasonIfUnsupported) const
1283{
Matteo Martincigh992d6dc2019-01-10 17:34:20 +00001284 ignore_unused(descriptor);
Nina Drozd2f2778f2019-05-27 10:37:05 +01001285 // Define supported output types.
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001286 std::array<DataType,4> supportedOutputTypes =
Nina Drozd2f2778f2019-05-27 10:37:05 +01001287 {
1288 DataType::Float32,
1289 DataType::Float16,
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001290 DataType::QuantisedAsymm8,
1291 DataType::QuantisedSymm16
Nina Drozd2f2778f2019-05-27 10:37:05 +01001292 };
1293 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1294 "Reference reshape: input type not supported.");
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001295}
1296
1297bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
Sadik Armaganc625f002018-12-17 11:32:16 +00001298 const TensorInfo& output,
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001299 Optional<std::string&> reasonIfUnsupported) const
1300{
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001301 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001302 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001303 {
1304 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001305 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001306 DataType::QuantisedAsymm8,
1307 DataType::QuantisedSymm16
1308 };
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +01001309
1310 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1311 "Reference ResizeBilinear: input type not supported");
1312
1313 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1314 "Reference ResizeBilinear: output type not supported");
1315
1316 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1317 "Reference ResizeBilinear: input and output types not matching");
1318
1319 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001320}
1321
Teresa Charlin970f43b2019-07-01 13:51:07 +01001322bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1323 const TensorInfo& output,
1324 const ResizeDescriptor& descriptor,
1325 Optional<std::string&> reasonIfUnsupported) const
1326{
1327 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001328 std::array<DataType,4> supportedTypes =
Teresa Charlin970f43b2019-07-01 13:51:07 +01001329 {
1330 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001331 DataType::Float16,
Teresa Charlin970f43b2019-07-01 13:51:07 +01001332 DataType::QuantisedAsymm8,
1333 DataType::QuantisedSymm16
1334 };
1335
1336 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1337 "Reference Resize: input type not supported");
1338
1339 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1340 "Reference Resize: output type not supported");
1341
1342 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1343 "Reference Resize: input and output types not matching");
1344
1345 return supported;
1346}
1347
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001348bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1349 const TensorInfo& output,
1350 Optional<std::string&> reasonIfUnsupported) const
1351{
nikraj010421e7f2019-06-14 09:40:34 +01001352 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001353 std::array<DataType,4> supportedTypes =
nikraj010421e7f2019-06-14 09:40:34 +01001354 {
1355 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001356 DataType::Float16,
nikraj0124d73212019-06-14 14:20:40 +01001357 DataType::QuantisedAsymm8,
1358 DataType::QuantisedSymm16
nikraj010421e7f2019-06-14 09:40:34 +01001359 };
1360
1361 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1362 "Reference rsqrt: input type not supported");
1363
1364 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1365 "Reference rsqrt: output type not supported");
1366
1367 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1368 "Reference rsqrt: input and output types not matching");
1369
1370 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1371 "Reference Rsqrt: input and output shapes have different number of total elements");
1372
1373 return supported;
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +00001374}
1375
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001376bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1377 const TensorInfo& output,
1378 const SoftmaxDescriptor& descriptor,
1379 Optional<std::string&> reasonIfUnsupported) const
1380{
1381 ignore_unused(output);
nikraj01248683f2019-05-29 16:46:50 +01001382 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001383 std::array<DataType,4> supportedTypes =
nikraj01248683f2019-05-29 16:46:50 +01001384 {
1385 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001386 DataType::Float16,
nikraj01248683f2019-05-29 16:46:50 +01001387 DataType::QuantisedAsymm8,
1388 DataType::QuantisedSymm16
1389 };
1390
1391 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1392 "Reference concatenation: output type not supported");
1393
1394 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1395 "Reference concatenation: input type not supported");
1396
1397 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1398 "Reference concatenation: input type not supported");
1399
1400 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001401}
1402
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001403bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1404 const TensorInfo& output,
1405 const SpaceToBatchNdDescriptor& descriptor,
1406 Optional<std::string&> reasonIfUnsupported) const
1407{
1408 ignore_unused(output);
nikraj01120522a2019-05-31 11:33:07 +01001409 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001410 std::array<DataType,4> supportedTypes =
nikraj01120522a2019-05-31 11:33:07 +01001411 {
1412 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001413 DataType::Float16,
nikraj01120522a2019-05-31 11:33:07 +01001414 DataType::QuantisedAsymm8,
1415 DataType::QuantisedSymm16
1416 };
1417
1418 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1419 "Reference SpaceToBatchNd: input type not supported");
1420
1421 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1422 "Reference SpaceToBatchNd: output type not supported");
1423
1424 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1425 "Reference SpaceToBatchNd: input and output types are mismatched");
1426
1427 return supported;
Nattapat Chaimanowong3ea76d52018-11-09 14:10:38 +00001428}
1429
Keith Davisa57eccb2019-06-14 17:33:22 +01001430bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
Keith Davis51910332019-06-26 15:28:43 +01001431 const TensorInfo& output,
1432 const SpaceToDepthDescriptor& descriptor,
1433 Optional<std::string&> reasonIfUnsupported) const
Keith Davisa57eccb2019-06-14 17:33:22 +01001434{
1435
1436 ignore_unused(descriptor);
1437 bool supported = true;
1438
Matthew Jackson9bff1442019-09-12 09:08:23 +01001439 std::array<DataType,4> supportedTypes =
Keith Davisa57eccb2019-06-14 17:33:22 +01001440 {
1441 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001442 DataType::Float16,
Keith Davisa57eccb2019-06-14 17:33:22 +01001443 DataType::QuantisedAsymm8,
James Conroyd2aa85e2019-07-01 17:12:40 +01001444 DataType::QuantisedSymm16
Keith Davisa57eccb2019-06-14 17:33:22 +01001445 };
1446
1447 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1448 "Reference SpaceToDepth: input type not supported");
1449
1450 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1451 "Reference SpaceToDepth: output type not supported");
1452
1453 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1454 "Reference SpaceToDepth: input and output types are mismatched");
1455
1456 return supported;
1457}
1458
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001459bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1460 const ViewsDescriptor& descriptor,
1461 Optional<std::string&> reasonIfUnsupported) const
1462{
1463 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001464 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001465 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001466 {
1467 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001468 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001469 DataType::QuantisedAsymm8,
1470 DataType::QuantisedSymm16
1471 };
1472
1473 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1474 "Reference splitter: input type not supported");
1475
1476 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001477}
1478
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001479bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1480 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1481 const ViewsDescriptor& descriptor,
1482 Optional<std::string&> reasonIfUnsupported) const
1483{
1484 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001485 bool supported = true;
Matthew Jackson9bff1442019-09-12 09:08:23 +01001486 std::array<DataType,4> supportedTypes =
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001487 {
1488 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001489 DataType::Float16,
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001490 DataType::QuantisedAsymm8,
1491 DataType::QuantisedSymm16
1492 };
1493
1494 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1495 "Reference splitter: output type not supported");
1496 for (const TensorInfo output : outputs)
1497 {
1498 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1499 "Reference splitter: input type not supported");
1500
1501 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1502 "Reference splitter: input and output types mismatched.");
1503 }
1504
1505 return supported;
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +01001506}
1507
Matthew Jackson81e601c2019-07-11 12:07:09 +01001508bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1509 const TensorInfo& output,
1510 const StackDescriptor& descriptor,
1511 Optional<std::string&> reasonIfUnsupported) const
1512{
1513 ignore_unused(descriptor);
1514
1515 bool supported = true;
Matthew Jacksone69c3992019-09-09 14:31:21 +01001516 std::array<DataType,4> supportedTypes =
Matthew Jackson81e601c2019-07-11 12:07:09 +01001517 {
1518 DataType::Float32,
Matthew Jacksone69c3992019-09-09 14:31:21 +01001519 DataType::Float16,
Matthew Jackson81e601c2019-07-11 12:07:09 +01001520 DataType::QuantisedAsymm8,
1521 DataType::QuantisedSymm16
1522 };
1523
1524 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1525 "Reference stack: output type not supported");
1526 for (const TensorInfo* input : inputs)
1527 {
1528 BOOST_ASSERT(input != nullptr);
1529 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1530 "Reference stack: input type not supported");
1531
1532 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1533 "Reference stack: input and output types mismatched.");
1534 }
1535
1536 return supported;
1537}
1538
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001539bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1540 const TensorInfo& output,
1541 const StridedSliceDescriptor& descriptor,
1542 Optional<std::string&> reasonIfUnsupported) const
1543{
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001544 ignore_unused(descriptor);
Narumol Prangnawaratf9ac3fd2019-07-03 14:55:57 +01001545 bool supported = true;
1546
1547 std::array<DataType,3> supportedTypes =
1548 {
1549 DataType::Float32,
1550 DataType::QuantisedAsymm8,
1551 DataType::QuantisedSymm16
1552 };
1553
1554 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1555 "Reference StridedSlice: input type not supported");
1556
1557 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1558 "Reference StridedSlice: output type not supported");
1559
1560 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1561 "Reference StridedSlice: input and output types are mismatched");
1562
1563 return supported;
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00001564}
1565
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001566bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1567 const TensorInfo& input1,
1568 const TensorInfo& output,
1569 Optional<std::string&> reasonIfUnsupported) const
1570{
Sadik Armagan2999a022019-04-09 14:20:12 +01001571 bool supported = true;
1572
Matthew Jackson9bff1442019-09-12 09:08:23 +01001573 std::array<DataType,4> supportedTypes = {
Sadik Armagan2999a022019-04-09 14:20:12 +01001574 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001575 DataType::Float16,
Sadik Armagan2999a022019-04-09 14:20:12 +01001576 DataType::QuantisedAsymm8,
1577 DataType::QuantisedSymm16
1578 };
1579
1580 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1581 "Reference subtraction: input 0 is not a supported type.");
1582
1583 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1584 "Reference subtraction: input 1 is not a supported type.");
1585
1586 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1587 "Reference subtraction: output is not a supported type.");
1588
1589 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1590 "Reference subtraction: input 0 and Input 1 types are mismatched");
1591
1592 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1593 "Reference subtraction: input and output types are mismatched");
1594
1595 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1596 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1597
1598 return supported;
Aron Virginas-Tarb5acbb72018-10-15 11:11:51 +01001599}
1600
Matteo Martincighab9e5252019-06-13 17:27:46 +01001601bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1602 const TensorInfo& alpha,
1603 const TensorInfo& output,
1604 Optional<std::string&> reasonIfUnsupported) const
1605{
1606 bool supported = true;
1607
Matthew Jackson9bff1442019-09-12 09:08:23 +01001608 std::array<DataType, 4> supportedTypes
Matteo Martincighab9e5252019-06-13 17:27:46 +01001609 {
1610 DataType::Float32,
Matthew Jackson9bff1442019-09-12 09:08:23 +01001611 DataType::Float16,
Matteo Martincighab9e5252019-06-13 17:27:46 +01001612 DataType::QuantisedAsymm8,
1613 DataType::QuantisedSymm16
1614 };
1615
1616 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1617 "PReLU: input is not a supported type.");
1618
1619 supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1620 "PReLU: alpha is not a supported type.");
1621
1622 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1623 "PReLU: output is not a supported type.");
1624
1625 supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1626 "PReLU: input, alpha and output types are mismatched");
1627
1628 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1629 "PReLU: shapes are not suitable for implicit broadcast");
1630
1631 return supported;
1632}
1633
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001634bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1635 const TensorInfo& output,
1636 const TransposeConvolution2dDescriptor& descriptor,
1637 const TensorInfo& weights,
1638 const Optional<TensorInfo>& biases,
1639 Optional<std::string&> reasonIfUnsupported) const
1640{
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001641 bool supported = true;
1642
Matthew Jackson252df3a2019-09-11 09:19:18 +01001643 std::array<DataType,4> supportedTypes =
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001644 {
1645 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001646 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001647 DataType::QuantisedAsymm8,
1648 DataType::QuantisedSymm16
1649 };
1650
1651 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1652 "Reference TransposeConvolution2d: input is not a supported type.");
1653
1654 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1655 "Reference TransposeConvolution2d: output is not a supported type.");
1656
1657 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1658 "Reference TransposeConvolution2d: weights is not a supported type.");
1659
1660 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1661 "Reference TransposeConvolution2d: input and output types mismatched.");
1662
1663 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1664 "Reference TransposeConvolution2d: input and weights types mismatched.");
1665
1666 if (biases.has_value())
1667 {
Matthew Jackson252df3a2019-09-11 09:19:18 +01001668 std::array<DataType,3> biasesSupportedTypes =
Aron Virginas-Tar651aafe2019-08-05 11:52:05 +01001669 {
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001670 DataType::Float32,
Matthew Jackson252df3a2019-09-11 09:19:18 +01001671 DataType::Float16,
Aron Virginas-Tar98180ef2019-06-26 15:02:47 +01001672 DataType::Signed32
1673 };
1674 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1675 "Reference TransposeConvolution2d: biases is not a supported type.");
1676 }
1677
1678 return supported;
1679}
1680
arovir011c7c81b2018-10-08 11:34:28 +01001681} // namespace armnn