//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClBackend.hpp"
#include "ClBackendContext.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"
#include "ClImportTensorHandleFactory.hpp"
#include "ClLayerSupport.hpp"
#include "ClTensorHandleFactory.hpp"
#include "ClWorkloadFactory.hpp"

#include <armnn/BackendRegistry.hpp>
#include <armnn/Descriptors.hpp>

#include <aclCommon/ArmComputeSubgraphUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/BaseMemoryManager.hpp>

#include <armnn/backends/IBackendContext.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClReduceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"

#include <Optimizer.hpp>

#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/CL/CLBufferAllocator.h>

namespace armnn
{

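// Returns the static identifier of this backend. ClBackendId() resolves to
// the "GpuAcc" backend name used when selecting backends at runtime.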
const BackendId& ClBackend::GetIdStatic()
{
    static const BackendId s_Id{ClBackendId()};
    return s_Id;
}

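// Creates the backend's memory manager, wrapping Compute Library's
// CLBufferAllocator for OpenCL buffer allocations.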
IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const
{
    return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
}

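// The CreateWorkloadFactory overloads below build the factory that turns
// layers into CL workloads. A minimal usage sketch (illustrative only;
// 'backend' and 'registry' are hypothetical local names, not part of this
// file):
//
//     ClBackend backend;
//     TensorHandleFactoryRegistry registry;
//     IBackendInternal::IWorkloadFactoryPtr factory = backend.CreateWorkloadFactory(registry);
//
// The registry-based overloads additionally register a ClMemoryManager and
// this backend's tensor handle factories, so tensor memory strategies can be
// chosen during optimization.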
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const ModelOptions& modelOptions) const
{
    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry) const
{
    auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry, const ModelOptions& modelOptions) const
{
    auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry,
    const ModelOptions& modelOptions,
    MemorySourceFlags inputFlags,
    MemorySourceFlags outputFlags) const
{
    auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

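// Lists the tensor handle factories this backend can use, in order of
// preference: the standard CL factory first, then the import factory.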
std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferences() const
{
    return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic(),
                                                         ClImportTensorHandleFactory::GetIdStatic()};
}

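// Registers this backend's memory manager and tensor handle factories with
// the given registry. Without explicit flags, the import factory defaults to
// importing from and exporting to Malloc memory.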
void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry)
{
    auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());

    registry.RegisterMemoryManager(mgr);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));
}

void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry,
                                              MemorySourceFlags inputFlags,
                                              MemorySourceFlags outputFlags)
{
    auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());

    registry.RegisterMemoryManager(mgr);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));
}

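// Creates the ClBackendContext that carries backend-wide state for the
// lifetime of a runtime, configured from the runtime's creation options.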
IBackendInternal::IBackendContextPtr ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const
{
    return IBackendContextPtr{new ClBackendContext{options}};
}

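// No backend-specific profiling context is provided; an empty pointer is
// returned.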
IBackendInternal::IBackendProfilingContextPtr ClBackend::CreateBackendProfilingContext(
    const IRuntime::CreationOptions&, IBackendProfilingPtr&)
{
    return IBackendProfilingContextPtr{};
}

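// No global optimizations are registered here; backend-specific rewrites are
// applied in OptimizeSubgraphView below.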
IBackendInternal::Optimizations ClBackend::GetOptimizations() const
{
    return Optimizations{};
}

IBackendInternal::IBackendSpecificModelContextPtr ClBackend::CreateBackendSpecificModelContext(
    const ModelOptions& modelOptions) const
{
    return IBackendSpecificModelContextPtr{new ClBackendModelContext{modelOptions}};
}

IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport() const
{
    static ILayerSupportSharedPtr layerSupport
    {
        new ClLayerSupport(IBackendInternal::IBackendSpecificModelContextPtr{})
    };
    return layerSupport;
}

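// Note: the static local below is initialised on the first call, so the
// model options passed on subsequent calls are ignored.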
IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport(const ModelOptions& modelOptions) const
{
    static ILayerSupportSharedPtr layerSupport
    {
        new ClLayerSupport(CreateBackendSpecificModelContext(modelOptions))
    };
    return layerSupport;
}

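// Reports whether the GpuAcc backend advertises the queried capability.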
bool ClBackend::HasCapability(BackendCapability capabilityClass) const
{
    auto search = gpuAccCapabilities.find(capabilityClass);
    if (search != gpuAccCapabilities.end())
    {
        return true;
    }
    return false;
}

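// Backend-specific graph rewrites. Walks the subgraph backwards and applies
// two optimizations:
//   1. Fuses an Activation layer into a preceding Convolution2d,
//      DepthwiseConvolution2d, FullyConnected, BatchNormalization, Addition,
//      Division, Multiplication or Subtraction layer, whenever the
//      corresponding Cl*Validate check accepts the fused descriptor.
//   2. Splits a Reduce layer with multiple axes into a chain of single-axis
//      Reduce layers.
// Layers that remain untouched are reported back to the optimizer.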
OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
                                                  const ModelOptions& modelOptions) const
{
    OptimizationViews optimizationViews;

    auto it = subgraph.end();
    bool isFastMathEnabled = false;
    std::map<LayerGuid, Layer*> untouched;

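    // Record every layer in the subgraph; layers are erased from this map as
    // they are fused, so whatever remains is reported as untouched.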
    while (it != subgraph.begin())
    {
        --it;
        Layer& base = **it;
        untouched.insert({base.GetGuid(), &base});
    }

    it = subgraph.end();
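    // When built with Compute Library support, read the FastMathEnabled
    // option from the backend's model options.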
#if defined(ARMCOMPUTECL_ENABLED)
    IBackendInternal::IBackendSpecificModelContextPtr modelContextPtr = CreateBackendSpecificModelContext(modelOptions);

    if (modelContextPtr)
    {
        auto clModelOptions = dynamic_cast<ClBackendModelContext*>(modelContextPtr.get());
        if (clModelOptions)
        {
            isFastMathEnabled = clModelOptions->IsFastMathEnabled();
        }
    }
#endif
    while (it != subgraph.begin())
    {
        --it;
        Layer& base = **it;

        // Fuse activation into previous layer if supported by backend
        if ((base.GetType() == LayerType::DepthwiseConvolution2d || base.GetType() == LayerType::Convolution2d
            || base.GetType() == LayerType::BatchNormalization || base.GetType() == LayerType::FullyConnected
            || base.GetType() == LayerType::Addition || base.GetType() == LayerType::Multiplication
            || base.GetType() == LayerType::Subtraction || base.GetType() == LayerType::Division)
            && (base.GetAdditionalInformation<ActivationDescriptor>() == nullptr))
        {
            for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
            {
                if (output->GetNumConnections() == 1)
                {
                    for (auto&& childInput : output->GetConnections())
                    {
                        if ((childInput->GetOwningLayer().GetType() == LayerType::Activation) &&
                            (checkDataTypeInputandOutput(childInput->GetOwningLayer())))
                        {
                            Layer& child = childInput->GetOwningLayer();

                            auto* activationLayer = PolymorphicDowncast<ActivationLayer*>(&child);

                            const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") +
                                                     base.GetName();

                            // Get params from activation layer
                            ActivationDescriptor activationDesc = activationLayer->GetParameters();

                            if (base.GetType() == LayerType::Convolution2d)
                            {
                                Convolution2dLayer* baseLayer = PolymorphicDowncast<Convolution2dLayer*>(&base);

                                Optional<TensorInfo> biases;

                                if (baseLayer->GetParameters().m_BiasEnabled)
                                {
                                    biases = baseLayer->m_Bias->GetTensorInfo();
                                }

                                arm_compute::Status status = ClConvolution2dWorkloadValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->GetParameters(),
                                    baseLayer->m_Weight->GetTensorInfo(),
                                    biases,
                                    isFastMathEnabled,
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithWeightsAndBiases<Convolution2dLayer>(optimizationViews,
                                                                                      baseLayer,
                                                                                      activationLayer,
                                                                                      activationDesc,
                                                                                      name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::DepthwiseConvolution2d)
                            {
                                DepthwiseConvolution2dLayer* baseLayer =
                                    PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&base);

                                Optional<TensorInfo> biases;

                                if (baseLayer->GetParameters().m_BiasEnabled)
                                {
                                    biases = baseLayer->m_Bias->GetTensorInfo();
                                }

                                arm_compute::Status status = ClDepthwiseConvolutionWorkloadValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->GetParameters(),
                                    baseLayer->m_Weight->GetTensorInfo(),
                                    biases,
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithWeightsAndBiases<DepthwiseConvolution2dLayer>(optimizationViews,
                                                                                               baseLayer,
                                                                                               activationLayer,
                                                                                               activationDesc,
                                                                                               name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::FullyConnected)
                            {
                                FullyConnectedLayer* baseLayer = PolymorphicDowncast<FullyConnectedLayer*>(&base);

                                arm_compute::Status status = ClFullyConnectedWorkloadValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->m_Weight->GetTensorInfo(),
                                    baseLayer->m_Bias->GetTensorInfo(),
                                    baseLayer->GetParameters(),
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithWeightsAndBiases<FullyConnectedLayer>(optimizationViews,
                                                                                       baseLayer,
                                                                                       activationLayer,
                                                                                       activationDesc,
                                                                                       name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::BatchNormalization)
                            {
                                BatchNormalizationLayer* baseLayer =
                                    PolymorphicDowncast<BatchNormalizationLayer*>(&base);

                                arm_compute::Status status = ClBatchNormalizationValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->m_Mean->GetTensorInfo(),
                                    baseLayer->m_Variance->GetTensorInfo(),
                                    baseLayer->m_Beta->GetTensorInfo(),
                                    baseLayer->m_Gamma->GetTensorInfo(),
                                    baseLayer->GetParameters(),
                                    &activationDesc);

                                if (status)
                                {
                                    BatchNormalizationLayer* replacementLayer =
                                        FuseLayerWithParameters<BatchNormalizationLayer>(optimizationViews,
                                                                                         baseLayer,
                                                                                         activationLayer,
                                                                                         activationDesc,
                                                                                         name);

                                    replacementLayer->m_Beta     = std::move(baseLayer->m_Beta);
                                    replacementLayer->m_Gamma    = std::move(baseLayer->m_Gamma);
                                    replacementLayer->m_Mean     = std::move(baseLayer->m_Mean);
                                    replacementLayer->m_Variance = std::move(baseLayer->m_Variance);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Addition)
                            {
                                AdditionLayer* baseLayer = PolymorphicDowncast<AdditionLayer*>(&base);

                                arm_compute::Status status = ClAdditionValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<AdditionLayer>(optimizationViews,
                                                                              baseLayer,
                                                                              activationLayer,
                                                                              activationDesc,
                                                                              name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Division)
                            {
                                DivisionLayer* baseLayer = PolymorphicDowncast<DivisionLayer*>(&base);

                                arm_compute::Status status = ClDivisionWorkloadValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<DivisionLayer>(optimizationViews,
                                                                              baseLayer,
                                                                              activationLayer,
                                                                              activationDesc,
                                                                              name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Multiplication)
                            {
                                MultiplicationLayer* baseLayer = PolymorphicDowncast<MultiplicationLayer*>(&base);

                                arm_compute::Status status = ClMultiplicationWorkloadValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<MultiplicationLayer>(optimizationViews,
                                                                                    baseLayer,
                                                                                    activationLayer,
                                                                                    activationDesc,
                                                                                    name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Subtraction)
                            {
                                SubtractionLayer* baseLayer = PolymorphicDowncast<SubtractionLayer*>(&base);

                                arm_compute::Status status = ClSubtractionValidate(
                                    baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                    activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                    &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<SubtractionLayer>(optimizationViews,
                                                                                 baseLayer,
                                                                                 activationLayer,
                                                                                 activationDesc,
                                                                                 name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                        }
                    }
                }
            }
        }

        // Separate reduce layer with multiple axes into multiple reduce layers with 1 axis.
        if (base.GetType() == LayerType::Reduce)
        {
            ReduceLayer* baseLayer = PolymorphicDowncast<ReduceLayer*>(&base);
            ReduceDescriptor reduceDescriptor = baseLayer->GetParameters();

            if (!reduceDescriptor.m_vAxis.empty() && reduceDescriptor.m_vAxis.size() > 1)
            {
                // Add new layers to the graph and connect them.
                std::vector<Layer*> layers = ChainReduceLayers<ReduceLayer>(optimizationViews,
                                                                            baseLayer,
                                                                            reduceDescriptor);

                // Replace existing baseLayer with new subgraph.
                ReplaceLayers<ReduceLayer>(optimizationViews, baseLayer, layers);
                untouched.erase(baseLayer->GetGuid());
            }
        }
    }

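    // If nothing was rewritten, hand the whole subgraph back untouched;
    // otherwise report the layers that were left alone.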
    if (optimizationViews.GetSubstitutions().empty())
    {
        optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
    }
    else
    {
        ReportUntouchedLayers(optimizationViews, untouched);
    }

    return optimizationViews;
}

} // namespace armnn