blob: 152d19cc044418e6ec56f8512b844b4dde625f9d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "NeonSoftmaxFloatWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007
Matthew Benthamd80a7122019-01-08 17:52:37 +00008#include "NeonWorkloadUtils.hpp"
9
Narumol Prangnawarat65d30962019-03-14 11:55:03 +000010#include <aclCommon/ArmComputeUtils.hpp>
Matthew Benthamd80a7122019-01-08 17:52:37 +000011#include <arm_compute/runtime/NEON/functions/NESoftmaxLayer.h>
12
telsoa014fcda012018-03-09 14:13:49 +000013namespace armnn
14{
surmeh013537c2c2018-05-18 16:31:43 +010015
arovir019e53a352018-08-31 15:26:35 +010016NeonSoftmaxFloatWorkload::NeonSoftmaxFloatWorkload(const SoftmaxQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010017 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
telsoa01c577f2c2018-08-31 09:22:23 +010018 : FloatWorkload<SoftmaxQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000019{
arovir019e53a352018-08-31 15:26:35 +010020 m_Data.ValidateInputsOutputs("NeonSoftmaxFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000021
telsoa01c577f2c2018-08-31 09:22:23 +010022 // The ArmCompute softmax layer uses 2D input/output tensors, so flatten the first three dimensions.
Derek Lambertic81855f2019-06-13 17:34:19 +010023 arm_compute::ITensor& input = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
24 arm_compute::ITensor& output = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
telsoa014fcda012018-03-09 14:13:49 +000025
Matthew Benthamd80a7122019-01-08 17:52:37 +000026 auto layer = std::make_unique<arm_compute::NESoftmaxLayer>(memoryManager);
Colm Donelanc3c5fc22019-08-15 16:03:17 +010027 unsigned int aclAxis = ComputeSoftmaxAclAxis(m_Data.m_Parameters, info.m_InputTensorInfos[0]);
Narumol Prangnawarat65d30962019-03-14 11:55:03 +000028 layer->configure(&input, &output, m_Data.m_Parameters.m_Beta, aclAxis);
Matthew Benthamd80a7122019-01-08 17:52:37 +000029 m_SoftmaxLayer.reset(layer.release());
telsoa014fcda012018-03-09 14:13:49 +000030}
31
arovir019e53a352018-08-31 15:26:35 +010032void NeonSoftmaxFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000033{
arovir019e53a352018-08-31 15:26:35 +010034 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonSoftmaxFloatWorkload_Execute");
Matthew Benthamd80a7122019-01-08 17:52:37 +000035 m_SoftmaxLayer->run();
telsoa014fcda012018-03-09 14:13:49 +000036}
surmeh013537c2c2018-05-18 16:31:43 +010037
telsoa014fcda012018-03-09 14:13:49 +000038} //namespace armnn
39