blob: b6bdf23ffacc72bada1c7411fe14560418ab0ead [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <backendsCommon/Workload.hpp>
9#include <backendsCommon/WorkloadData.hpp>
10#include "Decoders.hpp"
11#include "Encoders.hpp"
12
13namespace armnn
14{
15
16class RefConvolution2dWorkload : public BaseWorkload<Convolution2dQueueDescriptor>
17{
18public:
19 explicit RefConvolution2dWorkload(const Convolution2dQueueDescriptor& descriptor,
20 const WorkloadInfo& info);
21
22 void PostAllocationConfigure() override;
23
24 virtual void Execute() const override;
25
26private:
27 std::unique_ptr<ScopedCpuTensorHandle> m_Weight;
28 std::unique_ptr<ScopedCpuTensorHandle> m_Bias;
29
30 std::unique_ptr<Decoder<float>> m_InputDecoder;
31 std::unique_ptr<Encoder<float>> m_OutputEncoder;
32 std::unique_ptr<Decoder<float>> m_FilterDecoder;
33 std::unique_ptr<Decoder<float>> m_BiasDecoder;
34
35 TensorShape m_InputShape;
36 TensorShape m_OutputShape;
37 TensorShape m_FilterShape;
38};
39
40} //namespace armnn
41