blob: aab8c852c1d185604dd55c968e818589232bb7af [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
#include <armnn/backends/Workload.hpp>

#include <arm_compute/runtime/IFunction.h>
#include <arm_compute/core/Error.h>
#include <arm_compute/runtime/CL/CLTensor.h>

#include <memory>
Matthew Bentham14e46692018-09-20 15:35:30 +010013
telsoa01c577f2c2018-08-31 09:22:23 +010014namespace armnn
15{
16
17arm_compute::Status ClDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
18 const TensorInfo& output,
19 const DepthwiseConvolution2dDescriptor& descriptor,
20 const TensorInfo& weights,
Mike Kelly07810fc2020-11-12 10:58:48 +000021 const Optional<TensorInfo>& biases,
22 const ActivationDescriptor* activationDescriptor = nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +010023
Matthew Benthamd8777392018-10-08 09:38:55 +010024class ClDepthwiseConvolutionWorkload : public BaseWorkload<DepthwiseConvolution2dQueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +010025{
26public:
Matthew Benthamd8777392018-10-08 09:38:55 +010027 using BaseWorkload<DepthwiseConvolution2dQueueDescriptor>::m_Data;
telsoa01c577f2c2018-08-31 09:22:23 +010028
Matthew Benthamd8777392018-10-08 09:38:55 +010029 ClDepthwiseConvolutionWorkload(const DepthwiseConvolution2dQueueDescriptor& descriptor,
Sadik Armagane9444752020-12-02 11:28:58 +000030 const WorkloadInfo& info,
31 const arm_compute::CLCompileContext& clCompileContext);
Matthew Benthamd8777392018-10-08 09:38:55 +010032
33 void Execute() const override;
telsoa01c577f2c2018-08-31 09:22:23 +010034
35protected:
36 std::unique_ptr<arm_compute::IFunction> m_DepthwiseConvolutionLayer;
37
38 std::unique_ptr<arm_compute::CLTensor> m_KernelTensor;
39 std::unique_ptr<arm_compute::CLTensor> m_BiasTensor;
40
41 void FreeUnusedTensors();
42};
43
44} //namespace armnn