//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClBackend.hpp"
#include "ClBackendContext.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"
#include "ClImportTensorHandleFactory.hpp"
#include "ClLayerSupport.hpp"
#include "ClTensorHandleFactory.hpp"
#include "ClWorkloadFactory.hpp"

#include <armnn/BackendRegistry.hpp>
#include <armnn/Descriptors.hpp>

#include <aclCommon/ArmComputeSubgraphUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/BaseMemoryManager.hpp>

#include <armnn/backends/IBackendContext.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClConvolution2dWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionWorkload.hpp"
#include "workloads/ClDivisionWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClMultiplicationWorkload.hpp"
#include "workloads/ClReduceWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"

#include <Optimizer.hpp>

#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/CL/CLBufferAllocator.h>

namespace armnn
{

const BackendId& ClBackend::GetIdStatic()
{
    static const BackendId s_Id{ClBackendId()};
    return s_Id;
}

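// Create the backend's memory manager: wrap the user-registered custom
// allocator when one is present, otherwise use arm_compute's default
// CLBufferAllocator.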
IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const
{
    if (m_UsingCustomAllocator)
    {
        return std::make_unique<ClMemoryManager>(m_CustomAllocator);
    }
    return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
}

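// Workload factory overloads that take an existing memory manager. The
// shared pointer is downcast to ClMemoryManager, which the CL workloads
// expect.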
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const ModelOptions& modelOptions) const
{
    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

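// Registry-based overloads: create the memory manager here, register it
// together with the standard and import tensor handle factories, then build
// the workload factory on top of them.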
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry) const
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
}

IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry, const ModelOptions& modelOptions) const
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

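// As above, but the import tensor handle factory is constructed with
// caller-supplied input/output memory source flags rather than the default
// MemorySource::Malloc.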
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    TensorHandleFactoryRegistry& registry,
    const ModelOptions& modelOptions,
    MemorySourceFlags inputFlags,
    MemorySourceFlags outputFlags) const
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));

    return std::make_unique<ClWorkloadFactory>(
        PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
}

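// The memory-managed factory is listed first, so it is preferred over the
// import factory when both can satisfy a connection.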
std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferences() const
{
    return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic(),
                                                         ClImportTensorHandleFactory::GetIdStatic()};
}

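// Standalone factory registration; both overloads mirror the registration
// performed in the CreateWorkloadFactory overloads above.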
void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry)
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
        static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));
}

void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry,
                                              MemorySourceFlags inputFlags,
                                              MemorySourceFlags outputFlags)
{
    std::shared_ptr<ClMemoryManager> memoryManager;
    if (m_UsingCustomAllocator)
    {
        memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
    }
    else
    {
        memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
    }

    registry.RegisterMemoryManager(memoryManager);
    registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
    registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));
}

IBackendInternal::IBackendContextPtr ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const
{
    return IBackendContextPtr{new ClBackendContext{options}};
}

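// Profiling is not currently supported by the CL backend, so an empty
// profiling context is returned.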
IBackendInternal::IBackendProfilingContextPtr ClBackend::CreateBackendProfilingContext(
    const IRuntime::CreationOptions&, IBackendProfilingPtr&)
{
    return IBackendProfilingContextPtr{};
}

IBackendInternal::IBackendSpecificModelContextPtr ClBackend::CreateBackendSpecificModelContext(
    const ModelOptions& modelOptions) const
{
    return IBackendSpecificModelContextPtr{new ClBackendModelContext{modelOptions}};
}

IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport() const
{
    static ILayerSupportSharedPtr layerSupport
    {
        new ClLayerSupport(IBackendInternal::IBackendSpecificModelContextPtr{})
    };
    return layerSupport;
}

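// Note: the static below is initialised on first use, so the model options
// passed to the first call are the ones every later call observes.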
IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport(const ModelOptions& modelOptions) const
{
    static ILayerSupportSharedPtr layerSupport
    {
        new ClLayerSupport(CreateBackendSpecificModelContext(modelOptions))
    };
    return layerSupport;
}

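// Backend-level optimizations applied to a subgraph assigned to the CL
// backend:
//   * fuse an Activation layer into a preceding Convolution2d,
//     DepthwiseConvolution2d, FullyConnected, BatchNormalization, Addition,
//     Division, Multiplication or Subtraction layer, provided the CL
//     workload validates the fused configuration;
//   * split a Reduce layer with multiple axes into a chain of single-axis
//     Reduce layers.
// Layers left unmodified are reported back as untouched.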
OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
                                                  const ModelOptions& modelOptions) const
{
    OptimizationViews optimizationViews;

    auto it = subgraph.end();
    bool isFastMathEnabled = false;
    std::map<LayerGuid, Layer*> untouched;

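    // Walk the subgraph backwards recording every layer; entries are erased
    // below as layers are fused or replaced, leaving only the untouched ones.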
    while (it != subgraph.begin())
    {
        --it;
        Layer& base = **it;
        untouched.insert({base.GetGuid(), &base});
    }

    it = subgraph.end();
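    // When built with CL support, read the backend's FastMathEnabled model
    // option; it is forwarded to the convolution validation below.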
#if defined(ARMCOMPUTECL_ENABLED)
    IBackendInternal::IBackendSpecificModelContextPtr modelContextPtr = CreateBackendSpecificModelContext(modelOptions);

    if (modelContextPtr)
    {
        auto clModelOptions = dynamic_cast<ClBackendModelContext*>(modelContextPtr.get());
        if (clModelOptions)
        {
            isFastMathEnabled = clModelOptions->IsFastMathEnabled();
        }
    }
#endif
    while (it != subgraph.begin())
    {
        --it;
        Layer& base = **it;

        // Fuse activation into previous layer if supported by backend
        if ((base.GetType() == LayerType::DepthwiseConvolution2d || base.GetType() == LayerType::Convolution2d
            || base.GetType() == LayerType::BatchNormalization || base.GetType() == LayerType::FullyConnected
            || base.GetType() == LayerType::Addition || base.GetType() == LayerType::Multiplication
            || base.GetType() == LayerType::Subtraction || base.GetType() == LayerType::Division)
            && (base.GetAdditionalInformation<ActivationDescriptor>() == nullptr))
        {
            for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
            {
                if (output->GetNumConnections() == 1)
                {
                    for (auto&& childInput : output->GetConnections())
                    {
                        if ((childInput->GetOwningLayer().GetType() == LayerType::Activation) &&
                            (checkDataTypeInputandOutput(childInput->GetOwningLayer())))
                        {
                            Layer& child = childInput->GetOwningLayer();

                            auto* activationLayer = PolymorphicDowncast<ActivationLayer*>(&child);

                            const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") +
                                                     base.GetName();

                            // Get params from activation layer
                            ActivationDescriptor activationDesc = activationLayer->GetParameters();

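                            // For each supported layer type, validate the CL
                            // workload with the activation folded in and only
                            // fuse when validation succeeds.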
                            if (base.GetType() == LayerType::Convolution2d)
                            {
                                Convolution2dLayer* baseLayer = PolymorphicDowncast<Convolution2dLayer*>(&base);

                                Optional<TensorInfo> biases;

                                if (baseLayer->GetParameters().m_BiasEnabled)
                                {
                                    biases = baseLayer->m_Bias->GetTensorInfo();
                                }

                                arm_compute::Status status = ClConvolution2dWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        baseLayer->m_Weight->GetTensorInfo(),
                                        biases,
                                        isFastMathEnabled,
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithWeightsAndBiases<Convolution2dLayer>(optimizationViews,
                                                                                      baseLayer,
                                                                                      activationLayer,
                                                                                      activationDesc,
                                                                                      name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::DepthwiseConvolution2d)
                            {
                                DepthwiseConvolution2dLayer* baseLayer =
                                        PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&base);

                                Optional<TensorInfo> biases;

                                if (baseLayer->GetParameters().m_BiasEnabled)
                                {
                                    biases = baseLayer->m_Bias->GetTensorInfo();
                                }

                                arm_compute::Status status = ClDepthwiseConvolutionWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        baseLayer->m_Weight->GetTensorInfo(),
                                        biases,
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithWeightsAndBiases<DepthwiseConvolution2dLayer>(optimizationViews,
                                                                                               baseLayer,
                                                                                               activationLayer,
                                                                                               activationDesc,
                                                                                               name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::FullyConnected)
                            {
                                FullyConnectedLayer* baseLayer = PolymorphicDowncast<FullyConnectedLayer*>(&base);

                                arm_compute::Status status = ClFullyConnectedWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->m_Weight->GetTensorInfo(),
                                        baseLayer->m_Bias->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithWeightsAndBiases<FullyConnectedLayer>(optimizationViews,
                                                                                       baseLayer,
                                                                                       activationLayer,
                                                                                       activationDesc,
                                                                                       name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::BatchNormalization)
                            {
                                BatchNormalizationLayer* baseLayer =
                                        PolymorphicDowncast<BatchNormalizationLayer*>(&base);

                                arm_compute::Status status = ClBatchNormalizationValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->m_Mean->GetTensorInfo(),
                                        baseLayer->m_Variance->GetTensorInfo(),
                                        baseLayer->m_Beta->GetTensorInfo(),
                                        baseLayer->m_Gamma->GetTensorInfo(),
                                        baseLayer->GetParameters(),
                                        &activationDesc);

                                if (status)
                                {
                                    BatchNormalizationLayer* replacementLayer =
                                            FuseLayerWithParameters<BatchNormalizationLayer>(optimizationViews,
                                                                                             baseLayer,
                                                                                             activationLayer,
                                                                                             activationDesc,
                                                                                             name);

                                    replacementLayer->m_Beta = std::move(baseLayer->m_Beta);
                                    replacementLayer->m_Gamma = std::move(baseLayer->m_Gamma);
                                    replacementLayer->m_Mean = std::move(baseLayer->m_Mean);
                                    replacementLayer->m_Variance = std::move(baseLayer->m_Variance);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Addition)
                            {
                                AdditionLayer* baseLayer = PolymorphicDowncast<AdditionLayer*>(&base);

                                arm_compute::Status status = ClAdditionValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<AdditionLayer>(optimizationViews,
                                                                              baseLayer,
                                                                              activationLayer,
                                                                              activationDesc,
                                                                              name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Division)
                            {
                                DivisionLayer* baseLayer = PolymorphicDowncast<DivisionLayer*>(&base);

                                arm_compute::Status status = ClDivisionWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<DivisionLayer>(optimizationViews,
                                                                              baseLayer,
                                                                              activationLayer,
                                                                              activationDesc,
                                                                              name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Multiplication)
                            {
                                MultiplicationLayer* baseLayer = PolymorphicDowncast<MultiplicationLayer*>(&base);

                                arm_compute::Status status = ClMultiplicationWorkloadValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<MultiplicationLayer>(optimizationViews,
                                                                                    baseLayer,
                                                                                    activationLayer,
                                                                                    activationDesc,
                                                                                    name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                            else if (base.GetType() == LayerType::Subtraction)
                            {
                                SubtractionLayer* baseLayer = PolymorphicDowncast<SubtractionLayer*>(&base);

                                arm_compute::Status status = ClSubtractionValidate(
                                        baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
                                        activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
                                        &activationDesc);

                                if (status)
                                {
                                    FuseLayerWithoutParameters<SubtractionLayer>(optimizationViews,
                                                                                 baseLayer,
                                                                                 activationLayer,
                                                                                 activationDesc,
                                                                                 name);
                                    untouched.erase(baseLayer->GetGuid());
                                    untouched.erase(activationLayer->GetGuid());
                                }
                            }
                        }
                    }
                }
            }
        }

        // Separate a Reduce layer with multiple axes into a chain of Reduce
        // layers, each reducing over a single axis.
        if (base.GetType() == LayerType::Reduce)
        {
            ReduceLayer* baseLayer = PolymorphicDowncast<ReduceLayer*>(&base);
            ReduceDescriptor reduceDescriptor = baseLayer->GetParameters();

            if (!reduceDescriptor.m_vAxis.empty() && reduceDescriptor.m_vAxis.size() > 1)
            {
                // Add new layers to the graph and connect them.
                std::vector<Layer*> layers = ChainReduceLayers<ReduceLayer>(optimizationViews,
                                                                            baseLayer,
                                                                            reduceDescriptor);

                // Replace the existing base layer with the new subgraph.
                ReplaceLayers<ReduceLayer>(optimizationViews, baseLayer, layers);
                untouched.erase(baseLayer->GetGuid());
            }
        }
    }

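    // If no substitutions were made, hand the whole subgraph back untouched;
    // otherwise report the individual layers that were left unmodified.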
Mike Kelly07810fc2020-11-12 10:58:48 +0000506 if (optimizationViews.GetSubstitutions().empty())
507 {
508 optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
509 }
Mike Kelly1ac690a2020-11-17 11:41:38 +0000510 else
511 {
512 ReportUntouchedLayers(optimizationViews, untouched);
513 }
Matteo Martincighc3ba50e2019-05-22 14:28:16 +0100514
515 return optimizationViews;
Matteo Martincighadddddb2019-01-24 14:06:23 +0000516}
517
David Beck9efb57d2018-11-05 13:40:33 +0000518} // namespace armnn