blob: bf5b4708a3dad689522228ceb3458d86549181dc [file] [log] [blame]
Nikhil Raj68c2c902019-09-19 11:21:11 +01001//
Mike Kelly7cbe7812023-07-25 17:37:33 +01002// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
Nikhil Raj68c2c902019-09-19 11:21:11 +01003// SPDX-License-Identifier: MIT
4//
5
#include "RefArgMinMaxWorkload.hpp"

#include "ArgMinMax.hpp"
#include "RefWorkloadUtils.hpp"
#include "Decoders.hpp"
#include "Encoders.hpp"
#include "Profiling.hpp"

#include <armnn/Exceptions.hpp>
13
14namespace armnn
15{
16RefArgMinMaxWorkload::RefArgMinMaxWorkload(
17 const ArgMinMaxQueueDescriptor& descriptor,
18 const WorkloadInfo& info)
Finn Williams73c547d2022-02-15 20:47:34 +000019 : RefBaseWorkload<ArgMinMaxQueueDescriptor>(descriptor, info) {}
Nikhil Raj68c2c902019-09-19 11:21:11 +010020
Finn Williamsb8181f72021-04-07 10:23:21 +010021
Nikhil Raj68c2c902019-09-19 11:21:11 +010022void RefArgMinMaxWorkload::Execute() const
23{
Finn Williamsb8181f72021-04-07 10:23:21 +010024 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
25}
26
Matthew Sloyan2d213a72022-06-30 17:13:04 +010027void RefArgMinMaxWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010028{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010029 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
30 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010031}
32
33void RefArgMinMaxWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
34{
Mike Kelly7cbe7812023-07-25 17:37:33 +010035 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefArgMinMaxWorkload_Execute");
Nikhil Raj68c2c902019-09-19 11:21:11 +010036
Finn Williamsb8181f72021-04-07 10:23:21 +010037 const TensorInfo &inputTensorInfo = GetTensorInfo(inputs[0]);
Nikhil Raj68c2c902019-09-19 11:21:11 +010038
Finn Williamsb8181f72021-04-07 10:23:21 +010039 std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, inputs[0]->Map());
Nikhil Raj68c2c902019-09-19 11:21:11 +010040 Decoder<float> &decoder = *decoderPtr;
41
Finn Williamsb8181f72021-04-07 10:23:21 +010042 const TensorInfo &outputTensorInfo = GetTensorInfo(outputs[0]);
Nikhil Raj68c2c902019-09-19 11:21:11 +010043
Mike Kelly1f140f72021-04-06 12:25:55 +010044 if (outputTensorInfo.GetDataType() == armnn::DataType::Signed32) {
Finn Williams01097942021-04-26 12:06:34 +010045 int32_t *output = GetOutputTensorData<int32_t>(outputs[0]);
Inki Daed4619e22020-09-10 15:33:54 +090046 ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
47 m_Data.m_Parameters.m_Axis);
48 } else {
Finn Williams01097942021-04-26 12:06:34 +010049 int64_t *output = GetOutputTensorData<int64_t>(outputs[0]);
Inki Daed4619e22020-09-10 15:33:54 +090050 ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
51 m_Data.m_Parameters.m_Axis);
52 }
Nikhil Raj68c2c902019-09-19 11:21:11 +010053}
54
55} //namespace armnn