blob: 3ba698ec4d194bd155ef20e20fdccd0a07de3741 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "ClStackWorkload.hpp"
6#include "ClWorkloadUtils.hpp"
7#include <aclCommon/ArmComputeTensorUtils.hpp>
8#include <backendsCommon/CpuTensorHandle.hpp>
9#include <cl/ClTensorHandle.hpp>
10#include <cl/ClLayerSupport.hpp>
11
12#include <arm_compute/core/Types.h>
13
14#include <boost/numeric/conversion/cast.hpp>
15#include <boost/polymorphic_pointer_cast.hpp>
16
17namespace armnn
18{
19using namespace armcomputetensorutils;
20
21namespace
22{
23int CalcAxis(const unsigned int axis, const unsigned int inputDimensions)
24{
25 const int intAxis = boost::numeric_cast<int>(axis);
26 return boost::numeric_cast<int>(inputDimensions) - intAxis;
27}
28} //namespace
29
30arm_compute::Status ClStackWorkloadValidate(const std::vector<const TensorInfo*>& inputs,
31 const TensorInfo& output,
32 const StackDescriptor& descriptor)
33{
34 std::vector<arm_compute::ITensorInfo*> aclInputPtrs;
35 arm_compute::TensorInfo aclInputInfo;
36 for (const TensorInfo* input : inputs)
37 {
38 aclInputInfo = BuildArmComputeTensorInfo(*input);
39 aclInputPtrs.emplace_back(&aclInputInfo);
40 }
41 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
42
43 int aclAxis = CalcAxis(descriptor.m_Axis, descriptor.m_InputShape.GetNumDimensions());
44
45 return arm_compute::CLStackLayer::validate(aclInputPtrs, aclAxis, &aclOutputInfo);
46}
47
48ClStackWorkload::ClStackWorkload(const StackQueueDescriptor& descriptor, const WorkloadInfo& info)
49: BaseWorkload<StackQueueDescriptor>(descriptor, info)
50{
51 std::vector<arm_compute::ICLTensor*> aclInputs;
52 for (auto input : m_Data.m_Inputs)
53 {
54 arm_compute::ICLTensor& aclInput = boost::polymorphic_pointer_downcast<IClTensorHandle>(input)->GetTensor();
55 aclInputs.emplace_back(&aclInput);
56 }
57 arm_compute::ICLTensor& output = boost::polymorphic_pointer_downcast<IClTensorHandle>(
58 m_Data.m_Outputs[0])->GetTensor();
59
60 m_Layer.reset(new arm_compute::CLStackLayer());
61 int aclAxis = CalcAxis(descriptor.m_Parameters.m_Axis, descriptor.m_Parameters.m_InputShape.GetNumDimensions());
62 m_Layer->configure(aclInputs, aclAxis, &output);
63}
64
65void ClStackWorkload::Execute() const
66{
67 if (m_Layer)
68 {
69 ARMNN_SCOPED_PROFILING_EVENT_CL("ClStackWorkload_Execute");
70 m_Layer->run();
71 }
72}
73
74} //namespace armnn