blob: 068487039b871cc75f359830ec27e15710665d33 [file] [log] [blame]
Teresa Charlin9ad2e5b2020-04-10 22:34:48 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ClGatherWorkload.hpp"
7#include "ClWorkloadUtils.hpp"
8#include <aclCommon/ArmComputeUtils.hpp>
9#include <cl/ClTensorHandle.hpp>
10
11using namespace armnn::armcomputetensorutils;
12
13namespace armnn
14{
15arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input,
16 const TensorInfo& indices,
17 const TensorInfo& output)
18{
19 const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
20 const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices);
21 const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
22
23 int aclAxis = ComputeAclAxis(0, input);
24
25 return arm_compute::CLGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis);
26}
27
28ClGatherWorkload::ClGatherWorkload(const GatherQueueDescriptor& descriptor,
29 const WorkloadInfo& info)
30 : BaseWorkload<GatherQueueDescriptor>(descriptor, info)
31{
32 m_Data.ValidateInputsOutputs("ClGatherWorkload", 1, 1);
33
34 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
35 arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
36 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
37
38 int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]);
39
40 m_Layer.configure(&input, &indices, &output, aclAxis);
41};
42
43void ClGatherWorkload::Execute() const
44{
45 ARMNN_SCOPED_PROFILING_EVENT_CL("ClGatherWorkload_Execute");
46 RunClFunction(m_Layer, CHECK_LOCATION());
47}
48} // namespace armnn