blob: fe97cb10661fdae1a2159768247fa6348940618a [file] [log] [blame]
Mike Kelly9b398322019-05-22 17:21:49 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefConvolution2dWorkload.hpp"
7
8#include "ConvImpl.hpp"
9#include "RefWorkloadUtils.hpp"
10
11#include "Profiling.hpp"
12
13namespace armnn
14{
Keith Davisb4dd5cc2022-04-07 11:32:00 +010015RefConvolution2dWorkload::RefConvolution2dWorkload(const Convolution2dQueueDescriptor& descriptor,
16 const WorkloadInfo& info)
Finn Williams73c547d2022-02-15 20:47:34 +000017 : RefBaseWorkload<Convolution2dQueueDescriptor>(descriptor, info)
Mike Kelly9b398322019-05-22 17:21:49 +010018{
Keith Davis554fa092021-07-20 11:25:22 +010019 WorkloadInfo detailsInfo;
20 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
21 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
Keith Davis554fa092021-07-20 11:25:22 +010022
23 // Report Profiling Details
Keith Davisf4874862021-08-09 16:49:18 +010024 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("RefConvolution2dWorkload_Construct",
Keith Davis5a64f222021-08-04 10:35:20 +010025 descriptor.m_Parameters,
26 detailsInfo,
27 this->GetGuid());
Keith Davisb4dd5cc2022-04-07 11:32:00 +010028}
Keith Davis554fa092021-07-20 11:25:22 +010029
Keith Davisb4dd5cc2022-04-07 11:32:00 +010030void RefConvolution2dWorkload::PostAllocationConfigure()
31{
32 PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs);
33}
Matthew Bentham4cefc412019-06-18 16:14:34 +010034
Keith Davisb4dd5cc2022-04-07 11:32:00 +010035void RefConvolution2dWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs,
36 std::vector<ITensorHandle*> outputs)
37{
38 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
39 ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
40 m_InputShape = inputInfo.GetShape();
41
42 const TensorInfo& rFilterInfo = GetTensorInfo(inputs[1]);
43 ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
Mike Kelly9b398322019-05-22 17:21:49 +010044 m_FilterShape = rFilterInfo.GetShape();
Keith Davisb4dd5cc2022-04-07 11:32:00 +010045 m_FilterDecoder = MakeDecoder<float>(rFilterInfo);
Mike Kelly9b398322019-05-22 17:21:49 +010046
Keith Davisb4dd5cc2022-04-07 11:32:00 +010047 if (m_Data.m_Parameters.m_BiasEnabled)
Mike Kelly9b398322019-05-22 17:21:49 +010048 {
Keith Davisb4dd5cc2022-04-07 11:32:00 +010049 const TensorInfo& biasInfo = GetTensorInfo(inputs[2]);
50 m_BiasDecoder = MakeDecoder<float>(biasInfo);
Mike Kelly9b398322019-05-22 17:21:49 +010051 }
Keith Davisb4dd5cc2022-04-07 11:32:00 +010052
53 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
54 m_OutputShape = outputInfo.GetShape();
Mike Kelly9b398322019-05-22 17:21:49 +010055}
56
Finn Williamsb8181f72021-04-07 10:23:21 +010057void RefConvolution2dWorkload::Execute() const
Mike Kelly9b398322019-05-22 17:21:49 +010058{
Finn Williamsb8181f72021-04-07 10:23:21 +010059 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
Mike Kelly9b398322019-05-22 17:21:49 +010060}
61
Keith Davis554fa092021-07-20 11:25:22 +010062void RefConvolution2dWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
Finn Williamsb8181f72021-04-07 10:23:21 +010063{
Keith Davisb4dd5cc2022-04-07 11:32:00 +010064 PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
65
Finn Williamsb8181f72021-04-07 10:23:21 +010066 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
67}
68
Keith Davis554fa092021-07-20 11:25:22 +010069void RefConvolution2dWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
70{
Keith Davis5a64f222021-08-04 10:35:20 +010071 ARMNN_SCOPED_PROFILING_EVENT_GUID(Compute::CpuRef, "RefConvolution2dWorkload_Execute", this->GetGuid());
Mike Kelly9b398322019-05-22 17:21:49 +010072
Finn Williamsb8181f72021-04-07 10:23:21 +010073 std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map());
74 std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map());
Matthew Benthamc394a6d2019-06-24 12:51:25 +010075
Keith Davisb4dd5cc2022-04-07 11:32:00 +010076 m_FilterDecoder->Reset(inputs[1]->Map());
77 if (m_Data.m_Parameters.m_BiasEnabled)
78 {
79 m_BiasDecoder->Reset(inputs[2]->Map());
80 }
Finn Williamsb8181f72021-04-07 10:23:21 +010081
Keith Davisb4dd5cc2022-04-07 11:32:00 +010082 Convolve(m_InputShape, *inputDecoder, m_OutputShape, *outputEncoder, m_FilterShape,
Mike Kelly9b398322019-05-22 17:21:49 +010083 *m_FilterDecoder, m_Data.m_Parameters.m_BiasEnabled, m_BiasDecoder.get(),
84 m_Data.m_Parameters.m_DataLayout, m_Data.m_Parameters.m_PadTop, m_Data.m_Parameters.m_PadLeft,
85 m_Data.m_Parameters.m_StrideX, m_Data.m_Parameters.m_StrideY,
86 m_Data.m_Parameters.m_DilationX, m_Data.m_Parameters.m_DilationY);
87}
88
Keith Davisb4dd5cc2022-04-07 11:32:00 +010089} //namespace armnn