blob: 27aec8ecddd520205c00dc8e7d18e7a013b30e04 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#pragma once
7
David Beckac42efd2018-09-26 17:41:13 +01008#include <backends/Workload.hpp>
arovir01a6824102018-08-28 17:40:45 +01009
Matthew Bentham14e46692018-09-20 15:35:30 +010010#include <arm_compute/runtime/CL/CLFunctions.h>
11
telsoa01c577f2c2018-08-31 09:22:23 +010012namespace armnn
13{
14
15arm_compute::Status ClDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
16 const TensorInfo& output,
17 const DepthwiseConvolution2dDescriptor& descriptor,
18 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010019 const Optional<TensorInfo>& biases);
telsoa01c577f2c2018-08-31 09:22:23 +010020
21template<armnn::DataType... dataTypes>
22class ClDepthwiseConvolutionBaseWorkload : public TypedWorkload<DepthwiseConvolution2dQueueDescriptor, dataTypes...>
23{
24public:
25 using TypedWorkload<DepthwiseConvolution2dQueueDescriptor, dataTypes...>::m_Data;
26
27 ClDepthwiseConvolutionBaseWorkload(const DepthwiseConvolution2dQueueDescriptor& descriptor,
28 const WorkloadInfo& info);
29
30protected:
31 std::unique_ptr<arm_compute::IFunction> m_DepthwiseConvolutionLayer;
32
33 std::unique_ptr<arm_compute::CLTensor> m_KernelTensor;
34 std::unique_ptr<arm_compute::CLTensor> m_BiasTensor;
35
36 void FreeUnusedTensors();
37};
38
39} //namespace armnn