blob: 4cc0f7c1c215124f0e355ad98dcebc7e07de37eb [file] [log] [blame]
Matteo Martincigh28dcab62018-10-19 16:40:03 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ClMeanWorkload.hpp"
7
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <cl/ClTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
Matteo Martincigh28dcab62018-10-19 16:40:03 +010010
11#include "ClWorkloadUtils.hpp"
12
Matteo Martincigh28dcab62018-10-19 16:40:03 +010013namespace armnn
14{
15using namespace armcomputetensorutils;
16
17arm_compute::Status ClMeanValidate(const TensorInfo& input,
18 const TensorInfo& output,
19 const MeanDescriptor& desc)
20{
21 const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
22 const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23
Matthew Benthamfd899962018-12-31 15:49:42 +000024 arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclInputInfo.num_dimensions(),
25 input.GetNumDimensions(),
26 desc.m_Axis);
Matteo Martincigh28dcab62018-10-19 16:40:03 +010027
28 return arm_compute::CLReduceMean::validate(&aclInputInfo, coords, desc.m_KeepDims, &aclOutputInfo);
29}
30
Sadik Armagane9444752020-12-02 11:28:58 +000031ClMeanWorkload::ClMeanWorkload(const MeanQueueDescriptor& descriptor,
32 const WorkloadInfo& info,
33 const arm_compute::CLCompileContext& clCompileContext)
Matteo Martincigh28dcab62018-10-19 16:40:03 +010034 : BaseWorkload<MeanQueueDescriptor>(descriptor, info)
35{
36 m_Data.ValidateInputsOutputs("ClMeanWorkload", 1, 1);
37
38 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
39 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
40
Matthew Benthamfd899962018-12-31 15:49:42 +000041 arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(input.info()->num_dimensions(),
42 info.m_InputTensorInfos[0].GetNumDimensions(),
43 m_Data.m_Parameters.m_Axis);
Matteo Martincigh28dcab62018-10-19 16:40:03 +010044
Sadik Armagane9444752020-12-02 11:28:58 +000045 m_Layer.configure(clCompileContext, &input, coords, m_Data.m_Parameters.m_KeepDims, &output);
Matteo Martincigh28dcab62018-10-19 16:40:03 +010046}
47
48void ClMeanWorkload::Execute() const
49{
50 ARMNN_SCOPED_PROFILING_EVENT_CL("ClMeanWorkload_Execute");
51 m_Layer.run();
52}
53
54} //namespace armnn