blob: a5cc998b402039f123e43d12ba9e113e3a729d1a [file] [log] [blame]
narpra014951d842019-01-18 16:53:53 +00001//
Mike Kelly7cbe7812023-07-25 17:37:33 +01002// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
narpra014951d842019-01-18 16:53:53 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "RefGatherWorkload.hpp"
7
8#include "Gather.hpp"
9#include "Profiling.hpp"
10#include "RefWorkloadUtils.hpp"
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011#include <ResolveType.hpp>
Ciara Sookarryabd3c212023-10-11 17:04:04 +010012#include <fmt/format.h>
narpra014951d842019-01-18 16:53:53 +000013
14namespace armnn
15{
16
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010017void RefGatherWorkload::Execute() const
narpra014951d842019-01-18 16:53:53 +000018{
Finn Williamsb8181f72021-04-07 10:23:21 +010019 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
20}
21
Matthew Sloyan2d213a72022-06-30 17:13:04 +010022void RefGatherWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010023{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010024 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
25 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010026}
27
28void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
29{
Mike Kelly7cbe7812023-07-25 17:37:33 +010030 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefGatherWorkload_Execute");
narpra014951d842019-01-18 16:53:53 +000031
Finn Williamsb8181f72021-04-07 10:23:21 +010032 const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
33 const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
34 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
narpra014951d842019-01-18 16:53:53 +000035
Finn Williamsb8181f72021-04-07 10:23:21 +010036 std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->Map());
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010037 Decoder<float>& decoder = *decoderPtr;
38
Finn Williams01097942021-04-26 12:06:34 +010039 const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map());
Ciara Sookarryabd3c212023-10-11 17:04:04 +010040 // Check for negative indices, it could not be checked in validate as we do not have access to the values there
41 for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
42 {
43 if (indicesData[i] < 0)
44 {
45 throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i)));
46 }
47 }
narpra014951d842019-01-18 16:53:53 +000048
Finn Williamsb8181f72021-04-07 10:23:21 +010049 std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010050 Encoder<float>& encoder = *encoderPtr;
51
Teresa Charlin52664732020-06-29 16:27:03 +010052 Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder, m_Data.m_Parameters.m_Axis);
narpra014951d842019-01-18 16:53:53 +000053}
54
narpra014951d842019-01-18 16:53:53 +000055} //namespace armnn