blob: 0f6f837785813cff4a048ceae22447b47dcb9c76 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefSoftmaxWorkload.hpp"
7
8#include "Decoders.hpp"
9#include "Encoders.hpp"
10#include "RefWorkloadUtils.hpp"
11#include "Softmax.hpp"
12
13#include "Profiling.hpp"
14
15#include <vector>
16
17namespace armnn
18{
19
20void RefSoftmaxWorkload::Execute() const
21{
22 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSoftmaxWorkload_Execute");
23
24 const TensorInfo &inputTensorInfo = GetTensorInfo(m_Data.m_Inputs[0]);
25
26 std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
27 Decoder<float> &decoder = *decoderPtr;
28
29 const TensorInfo &outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
30
31 std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
32 Encoder<float> &encoder = *encoderPtr;
33
34 Softmax(decoder,
35 encoder,
36 inputTensorInfo,
Francis Murtagh07f21212019-07-23 09:50:50 +010037 m_Data.m_Parameters.m_Beta,
38 m_Data.m_Parameters.m_Axis);
nikraj01a121de32019-05-29 10:51:05 +010039}
40} //namespace armnn