//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "ClMergerWorkload.hpp"
6#include "ClWorkloadUtils.hpp"
7#include <aclCommon/ArmComputeTensorUtils.hpp>
8#include <backendsCommon/CpuTensorHandle.hpp>
9#include <cl/ClTensorHandle.hpp>
10#include <cl/ClLayerSupport.hpp>
11
Derek Lamberti0790dce2019-04-15 18:37:35 +010012#include <arm_compute/core/Types.h>
13
Nikhil Raj8599a412018-11-19 14:51:07 +000014#include <boost/polymorphic_pointer_cast.hpp>
15
16namespace armnn
17{
18using namespace armcomputetensorutils;
19
Derek Lamberti0790dce2019-04-15 18:37:35 +010020namespace
21{
22size_t CalcAxis(const MergerDescriptor& desc)
23{
24 return (desc.GetNumDimensions() - desc.GetConcatAxis()) - 1;
25}
26} //namespace
27
Nikhil Raj8599a412018-11-19 14:51:07 +000028arm_compute::Status ClMergerWorkloadValidate(const std::vector<const TensorInfo*>& inputs,
29 const TensorInfo& output,
30 const MergerDescriptor& descriptor)
Nikhil Raj8599a412018-11-19 14:51:07 +000031{
32 std::vector<arm_compute::TensorInfo> aclInputs;
33 for (const TensorInfo* input : inputs)
34 {
35 arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(*input, armnn::DataLayout::NCHW);
36 aclInputs.emplace_back(aclInputInfo);
37 }
38 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
Nikhil Raj8599a412018-11-19 14:51:07 +000039 std::vector<arm_compute::ITensorInfo*> aclInputPtrs;
40 for (arm_compute::ITensorInfo& input : aclInputs)
41 {
42 aclInputPtrs.emplace_back(&input);
43 }
44
Derek Lamberti0790dce2019-04-15 18:37:35 +010045 size_t aclAxis = CalcAxis(descriptor);
Nikhil Raj8599a412018-11-19 14:51:07 +000046 return arm_compute::CLConcatenateLayer::validate(aclInputPtrs, &aclOutputInfo, aclAxis);
Nikhil Raj8599a412018-11-19 14:51:07 +000047}
48
49ClMergerWorkload::ClMergerWorkload(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info)
50: BaseWorkload<MergerQueueDescriptor>(descriptor, info)
51{
Derek Lamberti0790dce2019-04-15 18:37:35 +010052 bool allInputsAreSubtensors = true;
Nikhil Raj8599a412018-11-19 14:51:07 +000053
Derek Lamberti0790dce2019-04-15 18:37:35 +010054 // Check that all inputs are sub-tensors
55 for (auto input : descriptor.m_Inputs)
Nikhil Raj8599a412018-11-19 14:51:07 +000056 {
Derek Lamberti0790dce2019-04-15 18:37:35 +010057 if (!input->GetParent())
58 {
59 // Non sub-tensor input found so we need to execute the merger function
60 allInputsAreSubtensors = false;
61 break;
62 }
63 }
64
65 if (allInputsAreSubtensors)
66 {
67 // Can skip configuring the merger function since it's not executed
Nikhil Raj8599a412018-11-19 14:51:07 +000068 return;
69 }
70
71 std::vector<arm_compute::ICLTensor *> aclInputs;
Nikhil Raj8599a412018-11-19 14:51:07 +000072 for (auto input : m_Data.m_Inputs)
73 {
74 arm_compute::ICLTensor& aclInput = boost::polymorphic_pointer_downcast<IClTensorHandle>(input)->GetTensor();
Nikhil Raj8599a412018-11-19 14:51:07 +000075 aclInputs.emplace_back(&aclInput);
76 }
77 arm_compute::ICLTensor& output = boost::polymorphic_pointer_downcast<IClTensorHandle>(
78 m_Data.m_Outputs[0])->GetTensor();
Nikhil Raj8599a412018-11-19 14:51:07 +000079
Derek Lamberti0790dce2019-04-15 18:37:35 +010080 // Create the layer function
81 m_Layer.reset(new arm_compute::CLConcatenateLayer());
Nikhil Raj8599a412018-11-19 14:51:07 +000082
Derek Lamberti0790dce2019-04-15 18:37:35 +010083 // Configure input and output tensors
84 size_t aclAxis = CalcAxis(descriptor.m_Parameters);
85 m_Layer->configure(aclInputs, &output, aclAxis);
Nikhil Raj8599a412018-11-19 14:51:07 +000086
Derek Lamberti0790dce2019-04-15 18:37:35 +010087 // Prepare
88 m_Layer->prepare();
Nikhil Raj8599a412018-11-19 14:51:07 +000089}
90
91void ClMergerWorkload::Execute() const
92{
Derek Lamberti0790dce2019-04-15 18:37:35 +010093 if (m_Layer)
Nikhil Raj8599a412018-11-19 14:51:07 +000094 {
95 ARMNN_SCOPED_PROFILING_EVENT_CL("ClMergerWorkload_Execute");
Derek Lamberti0790dce2019-04-15 18:37:35 +010096 m_Layer->run();
Nikhil Raj8599a412018-11-19 14:51:07 +000097 }
Nikhil Raj8599a412018-11-19 14:51:07 +000098}
99
100} //namespace armnn