blob: e06d8c51f5198d6e45cca64ea427bea595dfff50 [file] [log] [blame]
Nikhil Raj8599a412018-11-19 14:51:07 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include "ClMergerWorkload.hpp"
6#include "ClWorkloadUtils.hpp"
7#include <aclCommon/ArmComputeTensorUtils.hpp>
8#include <backendsCommon/CpuTensorHandle.hpp>
9#include <cl/ClTensorHandle.hpp>
10#include <cl/ClLayerSupport.hpp>
11
12#include <boost/polymorphic_pointer_cast.hpp>
13
14namespace armnn
15{
16using namespace armcomputetensorutils;
17
18arm_compute::Status ClMergerWorkloadValidate(const std::vector<const TensorInfo*>& inputs,
19 const TensorInfo& output,
20 const MergerDescriptor& descriptor)
21
22{
23 std::vector<arm_compute::TensorInfo> aclInputs;
24 for (const TensorInfo* input : inputs)
25 {
26 arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(*input, armnn::DataLayout::NCHW);
27 aclInputs.emplace_back(aclInputInfo);
28 }
29 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
30 arm_compute::DataLayoutDimension aclAxis = arm_compute::DataLayoutDimension::WIDTH;
31
32 std::vector<arm_compute::ITensorInfo*> aclInputPtrs;
33 for (arm_compute::ITensorInfo& input : aclInputs)
34 {
35 aclInputPtrs.emplace_back(&input);
36 }
37
38 return arm_compute::CLConcatenateLayer::validate(aclInputPtrs, &aclOutputInfo, aclAxis);
39
40}
41
42ClMergerWorkload::ClMergerWorkload(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info)
43: BaseWorkload<MergerQueueDescriptor>(descriptor, info)
44{
45 m_Execute = true;
46
47 unsigned int innerAxisOrder = descriptor.m_Parameters.GetNumDimensions() - descriptor.m_Parameters.GetConcatAxis();
48
49 if (innerAxisOrder != 1)
50 {
51 m_Execute = false;
52 return;
53 }
54
55 std::vector<arm_compute::ICLTensor *> aclInputs;
56 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(armnn::DataLayout::NCHW);
57 for (auto input : m_Data.m_Inputs)
58 {
59 arm_compute::ICLTensor& aclInput = boost::polymorphic_pointer_downcast<IClTensorHandle>(input)->GetTensor();
60 aclInput.info()->set_data_layout(aclDataLayout);
61 aclInputs.emplace_back(&aclInput);
62 }
63 arm_compute::ICLTensor& output = boost::polymorphic_pointer_downcast<IClTensorHandle>(
64 m_Data.m_Outputs[0])->GetTensor();
65 output.info()->set_data_layout(aclDataLayout);
66
67 arm_compute::DataLayoutDimension aclAxis = arm_compute::DataLayoutDimension::WIDTH;
68
69 m_Layer.configure(aclInputs, &output, aclAxis);
70
71 m_Layer.prepare();
72
73}
74
75void ClMergerWorkload::Execute() const
76{
77 if (m_Execute)
78 {
79 ARMNN_SCOPED_PROFILING_EVENT_CL("ClMergerWorkload_Execute");
80 m_Layer.run();
81 }
82
83}
84
85} //namespace armnn