blob: 470b6a883d5f30cb7b09b4679730d909675c82cc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ClMeanWorkload.hpp"
7
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <cl/ClTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
Matteo Martincigh28dcab62018-10-19 16:40:03 +010010
11#include "ClWorkloadUtils.hpp"
12
Matteo Martincigh28dcab62018-10-19 16:40:03 +010013namespace armnn
14{
15using namespace armcomputetensorutils;
16
17arm_compute::Status ClMeanValidate(const TensorInfo& input,
18 const TensorInfo& output,
19 const MeanDescriptor& desc)
20{
21 const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
22 const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23
Matthew Benthamfd899962018-12-31 15:49:42 +000024 arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclInputInfo.num_dimensions(),
25 input.GetNumDimensions(),
26 desc.m_Axis);
Matteo Martincigh28dcab62018-10-19 16:40:03 +010027
28 return arm_compute::CLReduceMean::validate(&aclInputInfo, coords, desc.m_KeepDims, &aclOutputInfo);
29}
30
31ClMeanWorkload::ClMeanWorkload(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info)
32 : BaseWorkload<MeanQueueDescriptor>(descriptor, info)
33{
34 m_Data.ValidateInputsOutputs("ClMeanWorkload", 1, 1);
35
36 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
37 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
38
Matthew Benthamfd899962018-12-31 15:49:42 +000039 arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(input.info()->num_dimensions(),
40 info.m_InputTensorInfos[0].GetNumDimensions(),
41 m_Data.m_Parameters.m_Axis);
Matteo Martincigh28dcab62018-10-19 16:40:03 +010042
43 m_Layer.configure(&input, coords, m_Data.m_Parameters.m_KeepDims, &output);
44}
45
46void ClMeanWorkload::Execute() const
47{
48 ARMNN_SCOPED_PROFILING_EVENT_CL("ClMeanWorkload_Execute");
49 m_Layer.run();
50}
51
52} //namespace armnn