blob: e51fa34233485760fb0944498dd43724ab04e3d8 [file] [log] [blame]
keidav01d74dc912018-12-10 18:16:07 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ClStridedSliceWorkload.hpp"
7
8#include "ClWorkloadUtils.hpp"
9
10#include <aclCommon/ArmComputeUtils.hpp>
11#include <aclCommon/ArmComputeTensorUtils.hpp>
12
13#include <backendsCommon/CpuTensorHandle.hpp>
14
15#include <cl/ClLayerSupport.hpp>
16#include <cl/ClTensorHandle.hpp>
17#include <cl/ClLayerSupport.hpp>
18
19namespace armnn
20{
21
22using namespace armcomputetensorutils;
23
24arm_compute::Status ClStridedSliceWorkloadValidate(const TensorInfo& input,
25 const TensorInfo& output,
26 const StridedSliceDescriptor& descriptor)
27{
28 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
29 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
30
31 arm_compute::Coordinates starts;
32 arm_compute::Coordinates ends;
33 arm_compute::Coordinates strides;
34
35 std::tie(starts, ends, strides) = SetClStridedSliceData(descriptor.m_Begin, descriptor.m_End, descriptor.m_Stride);
36
37 int32_t begin_mask = descriptor.m_BeginMask;
38 int32_t end_mask = descriptor.m_EndMask;
39 int32_t shrink_axis_mask = descriptor.m_ShrinkAxisMask;
40
41 return arm_compute::CLStridedSlice::validate(&aclInputInfo,
42 &aclOutputInfo,
43 starts,
44 ends,
45 strides,
46 begin_mask,
47 end_mask,
48 shrink_axis_mask);
49}
50
51ClStridedSliceWorkload::ClStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
52 const WorkloadInfo& info)
53 : BaseWorkload<StridedSliceQueueDescriptor>(descriptor, info)
54{
55 m_Data.ValidateInputsOutputs("ClStridedSliceWorkload", 1, 1);
56
57 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
58 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
59
60 arm_compute::Coordinates starts;
61 arm_compute::Coordinates ends;
62 arm_compute::Coordinates strides;
63
64 std::tie(starts, ends, strides) = SetClStridedSliceData(m_Data.m_Parameters.m_Begin,
65 m_Data.m_Parameters.m_End,
66 m_Data.m_Parameters.m_Stride);
67
68 int32_t begin_mask = m_Data.m_Parameters.m_BeginMask;
69 int32_t end_mask = m_Data.m_Parameters.m_EndMask;
70 int32_t shrink_axis_mask = m_Data.m_Parameters.m_ShrinkAxisMask;
71
72 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
73 input.info()->set_data_layout(aclDataLayout);
74 output.info()->set_data_layout(aclDataLayout);
75
76 m_StridedSliceLayer.configure(&input,
77 &output,
78 starts,
79 ends,
80 strides,
81 begin_mask,
82 end_mask,
83 shrink_axis_mask);
84}
85
86void ClStridedSliceWorkload::Execute() const
87{
88 ARMNN_SCOPED_PROFILING_EVENT_CL("ClStridedSliceWorkload_Execute");
89 RunClFunction(m_StridedSliceLayer, CHECK_LOCATION());
90}
91
92} //namespace armnn