blob: fa99e7f54d51a1237137822ab96d3339f8560e63 [file] [log] [blame]
Aron Virginas-Tar94c4fef2019-11-25 15:37:08 +00001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ClSliceWorkload.hpp"
7
8#include "ClWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11
12#include <cl/ClTensorHandle.hpp>
13
14#include <boost/cast.hpp>
15
16namespace armnn
17{
18
19arm_compute::Status ClSliceWorkloadValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const SliceDescriptor& descriptor)
22{
23 const arm_compute::TensorInfo aclInput = armcomputetensorutils::BuildArmComputeTensorInfo(input);
24 const arm_compute::TensorInfo aclOutput = armcomputetensorutils::BuildArmComputeTensorInfo(output);
25
26 arm_compute::Coordinates starts;
27 arm_compute::Coordinates ends;
28
29 std::tie(starts, ends) = SetClSliceData(descriptor.m_Begin, descriptor.m_Size);
30
31 return arm_compute::CLSlice::validate(&aclInput, &aclOutput, starts, ends);
32}
33
34ClSliceWorkload::ClSliceWorkload(const SliceQueueDescriptor& descriptor, const WorkloadInfo& info)
35 : BaseWorkload<SliceQueueDescriptor>(descriptor, info)
36{
37 m_Data.ValidateInputsOutputs("ClSliceWorkload", 1, 1);
38
39 arm_compute::ICLTensor& input = boost::polymorphic_downcast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
40 arm_compute::ICLTensor& output = boost::polymorphic_downcast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
41
42 arm_compute::Coordinates starts;
43 arm_compute::Coordinates ends;
44
45 std::tie(starts, ends) = SetClSliceData(m_Data.m_Parameters.m_Begin, m_Data.m_Parameters.m_Size);
46
47 m_SliceFunction.configure(&input, &output, starts, ends);
48}
49
50void ClSliceWorkload::Execute() const
51{
52 ARMNN_SCOPED_PROFILING_EVENT_CL("ClSliceWorkload_Execute");
53 RunClFunction(m_SliceFunction, CHECK_LOCATION());
54}
55
56} // namespace armnn