//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonConcatWorkload.hpp"

#include "NeonWorkloadUtils.hpp"

#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <neon/NeonTensorHandle.hpp>

namespace armnn
{
using namespace armcomputetensorutils;

namespace
{
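// Converts the Arm NN concat axis (counted from the outermost dimension) into the
// Arm Compute Library axis, which is counted from the innermost dimension.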
size_t CalcAxis(const armnn::OriginsDescriptor& desc)
{
    return (desc.GetNumDimensions() - desc.GetConcatAxis()) - 1;
}
} //namespace

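// Checks whether the Arm Compute Library NEON concatenation function supports the given
// input/output tensor configuration, without constructing a workload.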
arm_compute::Status NeonConcatWorkloadValidate(const std::vector<const TensorInfo*>& inputs,
                                               const TensorInfo& output,
                                               const OriginsDescriptor& descriptor)
{
    std::vector<arm_compute::TensorInfo> aclInputs;
    for (const TensorInfo* input : inputs)
    {
        arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(*input, armnn::DataLayout::NCHW);
        aclInputs.emplace_back(aclInputInfo);
    }
    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
    std::vector<arm_compute::ITensorInfo*> aclInputPtrs;
    for (arm_compute::ITensorInfo& input : aclInputs)
    {
        aclInputPtrs.emplace_back(&input);
    }

    size_t aclAxis = CalcAxis(descriptor);
    return arm_compute::NEConcatenateLayer::validate(aclInputPtrs, &aclOutputInfo, aclAxis);
}

NeonConcatWorkload::NeonConcatWorkload(
    const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info)
    : BaseWorkload<ConcatQueueDescriptor>(descriptor, info)
{
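    // If every input handle is a sub-tensor (a view of another tensor), the concatenation is
    // already handled by the memory layout, so no ACL concat function needs to be configured.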
    bool allInputsAreSubtensors = true;

    // Check that all inputs are sub-tensors
    for (auto input : descriptor.m_Inputs)
    {
        if (!input->GetParent())
        {
            // Non sub-tensor input found so we need to execute the concat function
            allInputsAreSubtensors = false;
            break;
        }
    }

    if (allInputsAreSubtensors)
    {
        // Can skip configuring the concat function since it's not executed
        return;
    }

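    // Gather the underlying ACL tensors from the Arm NN tensor handles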
    std::vector<arm_compute::ITensor *> aclInputs;
    for (auto input : m_Data.m_Inputs)
    {
        arm_compute::ITensor& aclInput = boost::polymorphic_pointer_downcast<IAclTensorHandle>(input)->GetTensor();
        aclInputs.emplace_back(&aclInput);
    }
    arm_compute::ITensor& output = boost::polymorphic_pointer_downcast<IAclTensorHandle>(
        m_Data.m_Outputs[0])->GetTensor();

    // Create the layer function
    m_Layer.reset(new arm_compute::NEConcatenateLayer());

    // Configure input and output tensors
    size_t aclAxis = CalcAxis(descriptor.m_Parameters);
    m_Layer->configure(aclInputs, &output, aclAxis);

    // Prepare
    m_Layer->prepare();
}

void NeonConcatWorkload::Execute() const
{
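    // m_Layer is only created when at least one input is not a sub-tensor (see constructor)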
    if (m_Layer)
    {
        ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConcatWorkload_Execute");
        m_Layer->run();
    }
}

} //namespace armnn