blob: 7293c830acd15443e16d50df09625917995411e8 [file] [log] [blame]
Matthew Benthamd8067922018-10-03 17:18:04 +01001//
Teresa Charlin588cbdf2022-01-19 15:55:37 +00002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
Matthew Benthamd8067922018-10-03 17:18:04 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9#include <armnn/Descriptors.hpp>
10
Teresa Charlin588cbdf2022-01-19 15:55:37 +000011#include "ClBaseWorkload.hpp"
Matthew Benthamd8067922018-10-03 17:18:04 +010012
13#include <arm_compute/runtime/CL/functions/CLConvolutionLayer.h>
14#include <arm_compute/runtime/MemoryManagerOnDemand.h>
15
Narumol Prangnawarate2af6f42022-01-28 17:59:18 +000016#include <cl/ICLTensorProxy.hpp>
17
Matthew Benthamd8067922018-10-03 17:18:04 +010018#include <memory>
19
20namespace armnn
21{
22
23arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
24 const TensorInfo& output,
25 const Convolution2dDescriptor& descriptor,
26 const TensorInfo& weights,
Sadik Armagan045f6be2020-09-10 13:37:32 +010027 const Optional<TensorInfo>& biases,
Mike Kelly07810fc2020-11-12 10:58:48 +000028 bool isFastMathEnabled = false,
29 const ActivationDescriptor* activationDescriptor = nullptr);
Matthew Benthamd8067922018-10-03 17:18:04 +010030
Teresa Charlin588cbdf2022-01-19 15:55:37 +000031class ClConvolution2dWorkload : public ClBaseWorkload<Convolution2dQueueDescriptor>
Matthew Benthamd8067922018-10-03 17:18:04 +010032{
33public:
Sadik Armagan04a72972020-09-14 15:44:18 +010034 ClConvolution2dWorkload(const Convolution2dQueueDescriptor& descriptor,
35 const WorkloadInfo& info,
36 std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager,
Sadik Armagane9444752020-12-02 11:28:58 +000037 const arm_compute::CLCompileContext& clCompileContext,
Sadik Armagan04a72972020-09-14 15:44:18 +010038 const bool isFastMathEnabled = false);
Matthew Benthamd8067922018-10-03 17:18:04 +010039 void Execute() const override;
40
Sadik Armagan04a72972020-09-14 15:44:18 +010041 arm_compute::ConvolutionMethod GetConvolutionMethod() const;
42
David Monahan3826ab62022-02-21 12:26:16 +000043 bool SupportsTensorHandleReplacement() const override
44 {
45 // NCHW DataLayout on ACL still uses paddding for alignment on the Conv2d workload so importing is unreliable.
46 if (m_Data.m_Parameters.m_DataLayout == DataLayout::NCHW)
47 {
48 return false;
49 }
50 else
51 {
52 return true;
53 }
54 }
55
Finn Williams73c547d2022-02-15 20:47:34 +000056
Narumol Prangnawarate2af6f42022-01-28 17:59:18 +000057protected:
58 void Reconfigure() override;
59
Matthew Benthamd8067922018-10-03 17:18:04 +010060private:
61 mutable arm_compute::CLConvolutionLayer m_ConvolutionLayer;
62
Sadik Armagan04a72972020-09-14 15:44:18 +010063 arm_compute::ConvolutionMethod m_ConvolutionMethod;
64
Narumol Prangnawarate2af6f42022-01-28 17:59:18 +000065 std::unique_ptr<ICLTensorProxy> m_InputProxy;
66 std::unique_ptr<ICLTensorProxy> m_OutputProxy;
Matthew Benthamd8067922018-10-03 17:18:04 +010067};
68
69} //namespace armnn
70