blob: cde9f50d38734c79cf694ed27063f5bd72811bc5 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
David Beckac42efd2018-09-26 17:41:13 +01008#include <backends/Workload.hpp>
arovir01a6824102018-08-28 17:40:45 +01009
Matthew Benthamd8777392018-10-08 09:38:55 +010010#include <arm_compute/runtime/IFunction.h>
11#include <arm_compute/core/Error.h>
12#include <arm_compute/runtime/CL/CLTensor.h>
Matthew Bentham14e46692018-09-20 15:35:30 +010013
telsoa01c577f2c2018-08-31 09:22:23 +010014namespace armnn
15{
16
17arm_compute::Status ClDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
18 const TensorInfo& output,
19 const DepthwiseConvolution2dDescriptor& descriptor,
20 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010021 const Optional<TensorInfo>& biases);
telsoa01c577f2c2018-08-31 09:22:23 +010022
Matthew Benthamd8777392018-10-08 09:38:55 +010023class ClDepthwiseConvolutionWorkload : public BaseWorkload<DepthwiseConvolution2dQueueDescriptor>
telsoa01c577f2c2018-08-31 09:22:23 +010024{
25public:
Matthew Benthamd8777392018-10-08 09:38:55 +010026 using BaseWorkload<DepthwiseConvolution2dQueueDescriptor>::m_Data;
telsoa01c577f2c2018-08-31 09:22:23 +010027
Matthew Benthamd8777392018-10-08 09:38:55 +010028 ClDepthwiseConvolutionWorkload(const DepthwiseConvolution2dQueueDescriptor& descriptor,
29 const WorkloadInfo& info);
30
31 void Execute() const override;
telsoa01c577f2c2018-08-31 09:22:23 +010032
33protected:
34 std::unique_ptr<arm_compute::IFunction> m_DepthwiseConvolutionLayer;
35
36 std::unique_ptr<arm_compute::CLTensor> m_KernelTensor;
37 std::unique_ptr<arm_compute::CLTensor> m_BiasTensor;
38
39 void FreeUnusedTensors();
40};
41
42} //namespace armnn