//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClBackend.hpp"
#include "ClBackendContext.hpp"
#include "ClBackendDefaultAllocator.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"
#include "ClImportTensorHandleFactory.hpp"
#include "ClLayerSupport.hpp"
#include "ClTensorHandleFactory.hpp"
#include "ClWorkloadFactory.hpp"

#include <armnn/BackendRegistry.hpp>
#include <armnn/Descriptors.hpp>

#include <aclCommon/ArmComputeSubgraphUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/BaseMemoryManager.hpp>

#include <armnn/backends/IBackendContext.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClReduceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"

#include <Optimizer.hpp>

#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/CL/CLBufferAllocator.h>

namespace armnn
{

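// Static identifier for the GPU (OpenCL) backend; ClBackendId() resolves to
// the well-known "GpuAcc" backend id used to request this backend.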
const BackendId& ClBackend::GetIdStatic()
{
    static const BackendId s_Id{ClBackendId()};
    return s_Id;
}

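// Creates this backend's memory manager. If the application registered a
// custom allocator, it backs all ClMemoryManager allocations; otherwise the
// default arm_compute CLBufferAllocator is used.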
IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const
{
    if (m_UsingCustomAllocator)
    {
        return std::make_unique<ClMemoryManager>(m_CustomAllocator);
    }
    return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
}

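// Workload factories turn optimized layers into CL workloads. The
// ModelOptions overloads also thread backend options (e.g. the FastMath
// setting) through a ClBackendModelContext.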
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const ModelOptions& modelOptions) const
{
    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

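// This overload additionally creates the tensor handle factories and wires
// them into the registry: a copy factory (ClTensorHandleFactory) and an
// import factory (ClImportTensorHandleFactory), registered as a pair in both
// directions so the runtime can pick copy or zero-copy import when moving
// tensors across backend boundaries.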
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry) const
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
    std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc));

    registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
    registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::move(factory));
    registry.RegisterFactory(std::move(importFactory));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry, const ModelOptions& modelOptions) const
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
    std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc));

    registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
    registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::move(factory));
    registry.RegisterFactory(std::move(importFactory));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

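// As above, but the caller chooses which memory sources the import factory
// accepts for inputs and outputs (e.g. Malloc, DmaBuf) instead of the
// Malloc-only default used by the other overloads.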
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry,
    const ModelOptions& modelOptions,
    MemorySourceFlags inputFlags,
    MemorySourceFlags outputFlags) const
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
    std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
        inputFlags, outputFlags);

    registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
    registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::move(factory));
    registry.RegisterFactory(std::move(importFactory));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

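// Preferred tensor handle factories, most preferred first: the standard CL
// factory, then the import factory.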
std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferences() const
{
    return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic(),
                                                         ClImportTensorHandleFactory::GetIdStatic()};
}

void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry)
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
    std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc));

    registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
    registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::move(factory));
    registry.RegisterFactory(std::move(importFactory));
}

void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry,
                                              MemorySourceFlags inputFlags,
                                              MemorySourceFlags outputFlags)
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
    std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
        inputFlags, outputFlags);

    registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
    registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::move(factory));
    registry.RegisterFactory(std::move(importFactory));
}

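// Context objects: the backend context holds per-runtime CL state, no
// profiling context is provided, and the backend-specific model context
// carries per-network options parsed from ModelOptions.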
IBackendInternal::IBackendContextPtr ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const
{
    return IBackendContextPtr{new ClBackendContext{options}};
}

IBackendInternal::IBackendProfilingContextPtr ClBackend::CreateBackendProfilingContext(
    const IRuntime::CreationOptions&, IBackendProfilingPtr&)
{
    return IBackendProfilingContextPtr{};
}

IBackendInternal::IBackendSpecificModelContextPtr ClBackend::CreateBackendSpecificModelContext(
    const ModelOptions& modelOptions) const
{
    return IBackendSpecificModelContextPtr{new ClBackendModelContext{modelOptions}};
}

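// Note: both GetLayerSupport overloads cache the ILayerSupport object in a
// function-local static, so the ModelOptions seen on the first call remain
// in effect for all later calls.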
IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport() const
{
    static ILayerSupportSharedPtr layerSupport
    {
        new ClLayerSupport(IBackendInternal::IBackendSpecificModelContextPtr{})
    };
    return layerSupport;
}

IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport(const ModelOptions& modelOptions) const
{
    static ILayerSupportSharedPtr layerSupport
    {
        new ClLayerSupport(CreateBackendSpecificModelContext(modelOptions))
    };
    return layerSupport;
}

std::unique_ptr<ICustomAllocator> ClBackend::GetDefaultAllocator() const
{
    return std::make_unique<ClBackendDefaultAllocator>();
}

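// Backend-specific subgraph optimization. Two rewrites are applied:
// 1. Fuse an Activation layer into a supported preceding layer (convolution,
//    fully connected, batch normalization, or elementwise arithmetic) when
//    the corresponding CL workload validates the fused configuration, e.g.
//    Convolution2d -> ReLU becomes a single fused convolution workload and
//    the intermediate tensor disappears.
// 2. Split a Reduce layer with multiple axes into a chain of single-axis
//    Reduce layers.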
OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
                                                  const ModelOptions& modelOptions) const
{
    OptimizationViews optimizationViews;

    auto it = subgraph.endIConnectable();
    bool isFastMathEnabled = false;
    std::map<LayerGuid, Layer*> untouched;

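    // First pass: record every layer in the subgraph. Entries are erased as
    // layers are fused or replaced below; whatever remains is reported back
    // as untouched.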
    while (it != subgraph.beginIConnectable())
    {
        --it;
        Layer& base = *(PolymorphicDowncast<Layer*>(*it));
        untouched.insert({base.GetGuid(), &base});
    }

    it = subgraph.endIConnectable();
#if defined(ARMCOMPUTECL_ENABLED)
    IBackendInternal::IBackendSpecificModelContextPtr modelContextPtr = CreateBackendSpecificModelContext(modelOptions);

    if (modelContextPtr)
    {
        auto clModelOptions = dynamic_cast<ClBackendModelContext*>(modelContextPtr.get());
        if (clModelOptions)
        {
            isFastMathEnabled = clModelOptions->IsFastMathEnabled();
        }
    }
#endif
    while (it != subgraph.beginIConnectable())
    {
        --it;
        Layer& base = *(PolymorphicDowncast<Layer*>(*it));

        // Fuse activation into previous layer if supported by backend
        if ((base.GetType() == LayerType::DepthwiseConvolution2d || base.GetType() == LayerType::Convolution2d
             || base.GetType() == LayerType::BatchNormalization || base.GetType() == LayerType::FullyConnected
             || base.GetType() == LayerType::Addition || base.GetType() == LayerType::Multiplication
             || base.GetType() == LayerType::Subtraction || base.GetType() == LayerType::Division)
            && (base.GetAdditionalInformation<ActivationDescriptor>() == nullptr))
        {
            for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
            {
                if (output->GetNumConnections() == 1)
                {
                    for (auto&& childInput : output->GetConnections())
                    {
                        if ((childInput->GetOwningLayer().GetType() == LayerType::Activation) &&
                            (checkDataTypeInputandOutput(childInput->GetOwningLayer())))
                        {
                            Layer& child = childInput->GetOwningLayer();

                            auto* activationLayer = PolymorphicDowncast<ActivationLayer*>(&child);

                            const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") +
                                                     base.GetName();

                            // Get params from activation layer
                            ActivationDescriptor activationDesc = activationLayer->GetParameters();

                            if (base.GetType() == LayerType::Convolution2d)
                            {
                                Convolution2dLayer* baseLayer = PolymorphicDowncast<Convolution2dLayer*>(&base);

                                Optional<TensorInfo> biases;

                                if (baseLayer->GetParameters().m_BiasEnabled)
                                {
                                    biases = baseLayer->m_Bias->GetTensorInfo();
                                }

                                arm_compute::Status status = ClConvolution2dWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        baseLayer->m_Weight->GetTensorInfo(),
                                        biases,
                                        isFastMathEnabled,
                                        &activationDesc);

                                if (status)
                                {
                                    FuseConvolution2dLayer<Convolution2dLayer>(optimizationViews,
                                                                               baseLayer,
                                                                               activationLayer,
                                                                               activationDesc,
                                                                               name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::DepthwiseConvolution2d)
                            {
                                DepthwiseConvolution2dLayer* baseLayer =
                                        PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&base);

                                Optional<TensorInfo> biases;

                                if (baseLayer->GetParameters().m_BiasEnabled)
                                {
                                    biases = baseLayer->m_Bias->GetTensorInfo();
                                }

                                arm_compute::Status status = ClDepthwiseConvolutionWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        baseLayer->m_Weight->GetTensorInfo(),
                                        biases,
                                        &activationDesc);

                                if (status)
                                {
                                    FuseDepthwiseConvolution2dLayer<DepthwiseConvolution2dLayer>(optimizationViews,
                                                                                                 baseLayer,
                                                                                                 activationLayer,
                                                                                                 activationDesc,
                                                                                                 name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::FullyConnected)
                            {
                                FullyConnectedLayer* baseLayer = PolymorphicDowncast<FullyConnectedLayer*>(&base);

                                arm_compute::Status status = ClFullyConnectedWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->m_Weight->GetTensorInfo(),
                                        baseLayer->m_Bias->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseFullyConnectedLayer<FullyConnectedLayer>(optimizationViews,
                                                                                 baseLayer,
                                                                                 activationLayer,
                                                                                 activationDesc,
                                                                                 name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::BatchNormalization)
                            {
                                BatchNormalizationLayer* baseLayer =
                                        PolymorphicDowncast<BatchNormalizationLayer*>(&base);

                                arm_compute::Status status = ClBatchNormalizationValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->m_Mean->GetTensorInfo(),
                                        baseLayer->m_Variance->GetTensorInfo(),
                                        baseLayer->m_Beta->GetTensorInfo(),
                                        baseLayer->m_Gamma->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        &activationDesc);

                                if (status)
                                {
                                    BatchNormalizationLayer* replacementLayer =
                                            FuseBatchNormalizationLayer<BatchNormalizationLayer>(optimizationViews,
                                                                                                 baseLayer,
                                                                                                 activationLayer,
                                                                                                 activationDesc,
                                                                                                 name);

                                    replacementLayer->m_Beta = std::move(baseLayer->m_Beta);
                                    replacementLayer->m_Gamma = std::move(baseLayer->m_Gamma);
                                    replacementLayer->m_Mean = std::move(baseLayer->m_Mean);
                                    replacementLayer->m_Variance = std::move(baseLayer->m_Variance);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Addition)
                            {
                                AdditionLayer* baseLayer = PolymorphicDowncast<AdditionLayer*>(&base);

                                arm_compute::Status status = ClAdditionValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseAdditionLayer<AdditionLayer>(optimizationViews,
                                                                     baseLayer,
                                                                     activationLayer,
                                                                     activationDesc,
                                                                     name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Division)
                            {
                                DivisionLayer* baseLayer = PolymorphicDowncast<DivisionLayer*>(&base);

                                arm_compute::Status status = ClDivisionWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseDivisionLayer<DivisionLayer>(optimizationViews,
                                                                     baseLayer,
                                                                     activationLayer,
                                                                     activationDesc,
                                                                     name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Multiplication)
                            {
                                MultiplicationLayer* baseLayer = PolymorphicDowncast<MultiplicationLayer*>(&base);

                                arm_compute::Status status = ClMultiplicationWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseMultiplicationLayer<MultiplicationLayer>(optimizationViews,
                                                                                 baseLayer,
                                                                                 activationLayer,
                                                                                 activationDesc,
                                                                                 name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Subtraction)
                            {
                                SubtractionLayer* baseLayer = PolymorphicDowncast<SubtractionLayer*>(&base);

                                arm_compute::Status status = ClSubtractionValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseSubtractionLayer<SubtractionLayer>(optimizationViews,
                                                                           baseLayer,
                                                                           activationLayer,
                                                                           activationDesc,
                                                                           name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                        }
                    }
                }
            }
        }

        // Separate reduce layer with multiple axes into multiple reduce layers with 1 axis.
        if (base.GetType() == LayerType::Reduce)
        {
            ReduceLayer* baseLayer = PolymorphicDowncast<ReduceLayer*>(&base);
            ReduceDescriptor reduceDescriptor = baseLayer->GetParameters();

            if (!reduceDescriptor.m_vAxis.empty() && reduceDescriptor.m_vAxis.size() > 1)
            {
                // Add new layers to the graph and connect them.
                std::vector<IConnectableLayer*> layers = ChainReduceLayers<ReduceLayer>(optimizationViews,
                                                                                        baseLayer,
                                                                                        reduceDescriptor);

                // Replace existing baselayer with new subgraph.
                ReplaceLayers<ReduceLayer>(optimizationViews, baseLayer, layers);
                untouched.erase(baseLayer->GetGuid());
            }
        }
    }

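    // If no substitutions were made the whole subgraph is passed through
    // unchanged; otherwise the layers left unmodified are reported so the
    // framework can still schedule them on this backend.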
    if (optimizationViews.GetSubstitutions().empty())
    {
        optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
    }
    else
    {
        ReportUntouchedLayers(optimizationViews, untouched);
    }

    return optimizationViews;
}

} // namespace armnn