blob: 8edf14c8f8c2e316508113bbc99c5c0f409cf3fe [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefGatherWorkload.hpp"
7
8#include "Gather.hpp"
9#include "Profiling.hpp"
10#include "RefWorkloadUtils.hpp"
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011#include <ResolveType.hpp>
narpra014951d842019-01-18 16:53:53 +000012
13namespace armnn
14{
15
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010016void RefGatherWorkload::Execute() const
narpra014951d842019-01-18 16:53:53 +000017{
narpra014951d842019-01-18 16:53:53 +000018 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefGatherWorkload_Execute");
19
20 const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
21 const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
22 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
23
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010024 std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo0, m_Data.m_Inputs[0]->Map());
25 Decoder<float>& decoder = *decoderPtr;
26
narpra014951d842019-01-18 16:53:53 +000027 const int32_t* indicesData = GetInputTensorData<int32_t>(1, m_Data);
narpra014951d842019-01-18 16:53:53 +000028
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010029 std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
30 Encoder<float>& encoder = *encoderPtr;
31
32 Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder);
narpra014951d842019-01-18 16:53:53 +000033}
34
narpra014951d842019-01-18 16:53:53 +000035} //namespace armnn