blob: d9c78bbd43ea4342c5d5a7a2f360d6f1a01ff782 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
arovir019e53a352018-08-31 15:26:35 +01006#include "NeonSoftmaxFloatWorkload.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007
Matthew Benthamd80a7122019-01-08 17:52:37 +00008#include "NeonWorkloadUtils.hpp"
9
10#include <arm_compute/runtime/NEON/functions/NESoftmaxLayer.h>
11
telsoa014fcda012018-03-09 14:13:49 +000012namespace armnn
13{
surmeh013537c2c2018-05-18 16:31:43 +010014
arovir019e53a352018-08-31 15:26:35 +010015NeonSoftmaxFloatWorkload::NeonSoftmaxFloatWorkload(const SoftmaxQueueDescriptor& descriptor,
surmeh013537c2c2018-05-18 16:31:43 +010016 const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
telsoa01c577f2c2018-08-31 09:22:23 +010017 : FloatWorkload<SoftmaxQueueDescriptor>(descriptor, info)
telsoa014fcda012018-03-09 14:13:49 +000018{
arovir019e53a352018-08-31 15:26:35 +010019 m_Data.ValidateInputsOutputs("NeonSoftmaxFloatWorkload", 1, 1);
telsoa014fcda012018-03-09 14:13:49 +000020
telsoa01c577f2c2018-08-31 09:22:23 +010021 // The ArmCompute softmax layer uses 2D input/output tensors, so flatten the first three dimensions.
telsoa014fcda012018-03-09 14:13:49 +000022 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
23 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
24
Matthew Benthamd80a7122019-01-08 17:52:37 +000025 auto layer = std::make_unique<arm_compute::NESoftmaxLayer>(memoryManager);
26 layer->configure(&input, &output, m_Data.m_Parameters.m_Beta);
27 m_SoftmaxLayer.reset(layer.release());
telsoa014fcda012018-03-09 14:13:49 +000028}
29
arovir019e53a352018-08-31 15:26:35 +010030void NeonSoftmaxFloatWorkload::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000031{
arovir019e53a352018-08-31 15:26:35 +010032 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonSoftmaxFloatWorkload_Execute");
Matthew Benthamd80a7122019-01-08 17:52:37 +000033 m_SoftmaxLayer->run();
telsoa014fcda012018-03-09 14:13:49 +000034}
surmeh013537c2c2018-05-18 16:31:43 +010035
telsoa014fcda012018-03-09 14:13:49 +000036} //namespace armnn
37