Matthew Jackson | 81e601c | 2019-07-11 12:07:09 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefStackWorkload.hpp" |
| 7 | |
| 8 | #include "RefWorkloadUtils.hpp" |
| 9 | #include "Stack.hpp" |
| 10 | |
| 11 | #include <Profiling.hpp> |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | RefStackWorkload::RefStackWorkload(const StackQueueDescriptor& descriptor, |
| 17 | const WorkloadInfo& info) |
| 18 | : BaseWorkload(descriptor, info) |
| 19 | {} |
| 20 | |
| 21 | void RefStackWorkload::Execute() const |
| 22 | { |
| 23 | ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStackWorkload_Execute"); |
| 24 | |
| 25 | // Can perform a simple concatenation when axis == 0 |
| 26 | if (!m_Data.m_Parameters.m_Axis) |
| 27 | { |
| 28 | float* output = GetOutputTensorData<float>(0, m_Data); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 29 | ARMNN_ASSERT(output != nullptr); |
Matthew Jackson | 81e601c | 2019-07-11 12:07:09 +0100 | [diff] [blame] | 30 | |
| 31 | unsigned int numInputs = m_Data.m_Parameters.m_NumInputs; |
| 32 | unsigned int inputLength = GetTensorInfo(m_Data.m_Inputs[0]).GetNumElements(); |
| 33 | |
| 34 | for (unsigned int inputIdx=0; inputIdx<numInputs; ++inputIdx) |
| 35 | { |
| 36 | const float* input = GetInputTensorData<float>(inputIdx, m_Data); |
| 37 | for (unsigned int elmt=0; elmt<inputLength; ++elmt) |
| 38 | { |
| 39 | output[(inputIdx * inputLength) + elmt] = input[elmt]; |
| 40 | } |
| 41 | } |
| 42 | return; |
| 43 | } |
| 44 | |
| 45 | std::vector<std::unique_ptr<Decoder<float>>> inputDecoders; |
| 46 | for (unsigned int i=0; i<m_Data.m_Inputs.size(); ++i) |
| 47 | { |
| 48 | inputDecoders.push_back(MakeDecoder<float>(GetTensorInfo(m_Data.m_Inputs[i]), |
| 49 | m_Data.m_Inputs[i]->Map())); |
| 50 | } |
| 51 | std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(m_Data.m_Outputs[0]), |
| 52 | m_Data.m_Outputs[0]->Map()); |
| 53 | |
| 54 | Stack(m_Data, inputDecoders, *outputEncoder); |
| 55 | } |
| 56 | |
| 57 | } // namespace armnn |