//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/Descriptors.hpp>
9#include <armnn/LstmParams.hpp>
10#include <armnn/backends/Workload.hpp>
11#include <armnn/backends/WorkloadData.hpp>
12
13#include "arm_compute/graph/Tensor.h"
14#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"
15#include "arm_compute/runtime/NEON/functions/NEPermute.h"
16#include "arm_compute/runtime/NEON/functions/NESplit.h"
17#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
18
19namespace armnn
20{
21
22class NeonUnidirectionalSequenceLstmFloatWorkload : public FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor>
23{
24public:
25 NeonUnidirectionalSequenceLstmFloatWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
26 const WorkloadInfo& info);
27 virtual void Execute() const override;
28
29private:
30
31 //
32 // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
33 //
Cathal Corbett4952a3e2022-03-03 15:14:18 +000034
35 // permutation for input (only used when input is batch major)
Cathal Corbettfd5bec42022-03-03 15:13:23 +000036 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
37 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
38 mutable std::vector<std::unique_ptr<arm_compute::NELSTMLayer>> m_Layers;
39 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
Cathal Corbett4952a3e2022-03-03 15:14:18 +000040 // permutation for output (only used when input is batch major)
Cathal Corbettfd5bec42022-03-03 15:13:23 +000041 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
42
43 //
44 // ACL LSTM arm_compute::Tensors.
45 //
46 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
47 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
48 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
49 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
51 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
52 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
53 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
54 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
55 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
56 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
57 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
58 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
59 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
60 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
61 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
62 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
63
64 std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
65
66 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
67 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
68 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
69 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
70
71 //
72 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
73 // Required to perform splitting, concatenation and permutations.
74 //
75 arm_compute::Tensor m_PermuteFirstOut;
76 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
77 std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
78 std::vector<arm_compute::ITensor*> m_SplitterOutputs;
79 std::vector<const arm_compute::ITensor*> m_ConcatInputs;
80 arm_compute::Tensor concat_out;
81
82 void FreeUnusedTensors();
83};
84
85arm_compute::Status
86NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
87 const TensorInfo& outputStateIn,
88 const TensorInfo& cellStateIn,
89 const TensorInfo& output,
90 const Optional<TensorInfo>& hiddenStateOutput,
91 const Optional<TensorInfo>& cellStateOutput,
92 const UnidirectionalSequenceLstmDescriptor& descriptor,
93 const LstmInputParamsInfo& paramsInfo);
94
95} //namespace armnn