blob: 6d7037f66045ed0a0169c2e91e6e34db1de5c31b [file] [log] [blame]
Ruomei Yan495852f2019-05-23 11:37:33 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include <backendsCommon/Workload.hpp>
6#include <backendsCommon/WorkloadData.hpp>
7#include "Decoders.hpp"
8#include "Encoders.hpp"
9
10#include <armnn/TypesUtils.hpp>
11
12namespace armnn
13{
14
15class RefDepthwiseConvolution2dWorkload : public BaseWorkload<DepthwiseConvolution2dQueueDescriptor> {
16public:
17 explicit RefDepthwiseConvolution2dWorkload(const DepthwiseConvolution2dQueueDescriptor &descriptor,
18 const WorkloadInfo &info);
19
20 void PostAllocationConfigure() override;
21
22 virtual void Execute() const override;
23
24private:
25
26 std::unique_ptr <ScopedCpuTensorHandle> m_Weight;
27 std::unique_ptr <ScopedCpuTensorHandle> m_Bias;
28
29 std::unique_ptr <Decoder<float>> m_InputDecoder;
30 std::unique_ptr <Encoder<float>> m_OutputEncoder;
31 std::unique_ptr <Decoder<float>> m_FilterDecoder;
32 std::unique_ptr <Decoder<float>> m_BiasDecoder;
33
34 TensorShape m_InputShape;
35 TensorShape m_OutputShape;
36 TensorShape m_FilterShape;
37};
38
39} //namespace armnn