blob: 9c2472800487d7b02e0613563f357033ca7f4246 [file] [log] [blame]
FinnWilliamsArm1fa19192019-08-02 17:26:31 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "NeonStridedSliceWorkload.hpp"
7
8#include "NeonWorkloadUtils.hpp"
9#include <neon/NeonTensorHandle.hpp>
10#include <aclCommon/ArmComputeUtils.hpp>
11#include <aclCommon/ArmComputeTensorUtils.hpp>
12
13
14namespace armnn
15{
16
17arm_compute::Status NeonStridedSliceWorkloadValidate(const TensorInfo& input,
18 const TensorInfo& output,
19 const StridedSliceDescriptor& descriptor)
20{
21 const arm_compute::TensorInfo aclInput = armcomputetensorutils::BuildArmComputeTensorInfo(input);
22 const arm_compute::TensorInfo aclOutput = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23
24 arm_compute::Coordinates starts;
25 arm_compute::Coordinates ends;
26 arm_compute::Coordinates strides;
27
28 std::tie(starts, ends, strides) = SetNeonStridedSliceData(descriptor.m_Begin,
29 descriptor.m_End,
30 descriptor.m_Stride);
31
32 int32_t begin_mask = descriptor.m_BeginMask;
33 int32_t end_mask = descriptor.m_EndMask;
34 int32_t shrink_axis_mask = descriptor.m_ShrinkAxisMask;
35
36 return arm_compute::NEStridedSlice::validate(&aclInput,
37 &aclOutput,
38 starts,
39 ends,
40 strides,
41 begin_mask,
42 end_mask,
43 shrink_axis_mask);
44}
45
46NeonStridedSliceWorkload::NeonStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
47 const WorkloadInfo& info)
48 : BaseWorkload<StridedSliceQueueDescriptor>(descriptor, info)
49{
50 m_Data.ValidateInputsOutputs("NeonStridedSliceWorkload", 1, 1);
51
52 arm_compute::ITensor& input = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
53 arm_compute::ITensor& output = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
54
55 arm_compute::Coordinates starts;
56 arm_compute::Coordinates ends;
57 arm_compute::Coordinates strides;
58
59 std::tie(starts, ends, strides) = SetNeonStridedSliceData(m_Data.m_Parameters.m_Begin,
60 m_Data.m_Parameters.m_End,
61 m_Data.m_Parameters.m_Stride);
62
63 int32_t begin_mask = m_Data.m_Parameters.m_BeginMask;
64 int32_t end_mask = m_Data.m_Parameters.m_EndMask;
65 int32_t shrink_axis_mask = m_Data.m_Parameters.m_ShrinkAxisMask;
66
67 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
68 input.info()->set_data_layout(aclDataLayout);
69 output.info()->set_data_layout(aclDataLayout);
70
71 auto layer = std::make_unique<arm_compute::NEStridedSlice>();
72
73 layer->configure(&input,
74 &output,
75 starts,
76 ends,
77 strides,
78 begin_mask,
79 end_mask,
80 shrink_axis_mask);
81 m_Layer.reset(layer.release());
82}
83
84void NeonStridedSliceWorkload::Execute() const
85{
86 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonStridedSliceWorkload_Execute");
87 m_Layer->run();
88}
89
90} //namespace armnn