blob: 9d5cde30b63eea5575111f65b0bb893af27487ba [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
David Beckac42efd2018-09-26 17:41:13 +01008#include <backends/Workload.hpp>
arovir01a6824102018-08-28 17:40:45 +01009#include <boost/optional.hpp>
10
Matthew Bentham14e46692018-09-20 15:35:30 +010011#include <arm_compute/runtime/CL/CLFunctions.h>
12
telsoa01c577f2c2018-08-31 09:22:23 +010013namespace armnn
14{
15
16arm_compute::Status ClDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
17 const TensorInfo& output,
18 const DepthwiseConvolution2dDescriptor& descriptor,
19 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +010020 const boost::optional<TensorInfo>& biases);
telsoa01c577f2c2018-08-31 09:22:23 +010021
22template<armnn::DataType... dataTypes>
23class ClDepthwiseConvolutionBaseWorkload : public TypedWorkload<DepthwiseConvolution2dQueueDescriptor, dataTypes...>
24{
25public:
26 using TypedWorkload<DepthwiseConvolution2dQueueDescriptor, dataTypes...>::m_Data;
27
28 ClDepthwiseConvolutionBaseWorkload(const DepthwiseConvolution2dQueueDescriptor& descriptor,
29 const WorkloadInfo& info);
30
31protected:
32 std::unique_ptr<arm_compute::IFunction> m_DepthwiseConvolutionLayer;
33
34 std::unique_ptr<arm_compute::CLTensor> m_KernelTensor;
35 std::unique_ptr<arm_compute::CLTensor> m_BiasTensor;
36
37 void FreeUnusedTensors();
38};
39
40} //namespace armnn