blob: c4a4f7f59394b4236ea4fe86e011a7cc7b3aa7a5 [file] [log] [blame]
//
// Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefStridedSliceWorkload.hpp"
Matteo Martincighe851b3d2019-05-28 14:31:20 +01007#include "RefWorkloadUtils.hpp"
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +00008#include "StridedSlice.hpp"
9
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000010namespace armnn
11{
12
Matteo Martincighe851b3d2019-05-28 14:31:20 +010013RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
14 const WorkloadInfo& info)
Finn Williams73c547d2022-02-15 20:47:34 +000015 : RefBaseWorkload(descriptor, info)
Matteo Martincighe851b3d2019-05-28 14:31:20 +010016{}
17
18void RefStridedSliceWorkload::Execute() const
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000019{
Finn Williamsb8181f72021-04-07 10:23:21 +010020 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
Nattapat Chaimanowong1216b582018-11-23 15:33:41 +000021}
22
Matthew Sloyan2d213a72022-06-30 17:13:04 +010023void RefStridedSliceWorkload::ExecuteAsync(ExecutionData& executionData)
Mike Kelly386ff1a2021-03-29 15:04:50 +010024{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010025 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
26 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010027}
Mike Kelly386ff1a2021-03-29 15:04:50 +010028
Finn Williamsb8181f72021-04-07 10:23:21 +010029void RefStridedSliceWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
30{
Mike Kelly7cbe7812023-07-25 17:37:33 +010031 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefStridedSliceWorkload_Execute");
Finn Williamsb8181f72021-04-07 10:23:21 +010032
33 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
34 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
Mike Kelly386ff1a2021-03-29 15:04:50 +010035
36 DataType inputDataType = inputInfo.GetDataType();
37 DataType outputDataType = outputInfo.GetDataType();
38
39 ARMNN_ASSERT(inputDataType == outputDataType);
40 IgnoreUnused(outputDataType);
41
42 StridedSlice(inputInfo,
43 m_Data.m_Parameters,
Finn Williamsb8181f72021-04-07 10:23:21 +010044 inputs[0]->Map(),
45 outputs[0]->Map(),
Mike Kelly386ff1a2021-03-29 15:04:50 +010046 GetDataTypeSize(inputDataType));
47}
48
Matteo Martincighe851b3d2019-05-28 14:31:20 +010049} // namespace armnn