blob: 48cf7dc7e4ea2df92889f8f08995ccf2a36f89e0 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>
#include <armnn/backends/Workload.hpp>
#include <armnn/backends/WorkloadData.hpp>

#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"
#include "arm_compute/runtime/NEON/functions/NEPermute.h"
#include "arm_compute/runtime/NEON/functions/NESplit.h"
#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"

#include <memory>
#include <vector>
17
18namespace armnn
19{
20
21class NeonUnidirectionalSequenceLstmFloatWorkload : public FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor>
22{
23public:
24 NeonUnidirectionalSequenceLstmFloatWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
25 const WorkloadInfo& info);
26 virtual void Execute() const override;
27
28private:
29
30 //
31 // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
32 //
Cathal Corbett4952a3e2022-03-03 15:14:18 +000033
34 // permutation for input (only used when input is batch major)
Cathal Corbettfd5bec42022-03-03 15:13:23 +000035 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
36 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
37 mutable std::vector<std::unique_ptr<arm_compute::NELSTMLayer>> m_Layers;
38 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
Cathal Corbett4952a3e2022-03-03 15:14:18 +000039 // permutation for output (only used when input is batch major)
Cathal Corbettfd5bec42022-03-03 15:13:23 +000040 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
41
42 //
43 // ACL LSTM arm_compute::Tensors.
44 //
45 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
46 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
47 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
48 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
51 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
52 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
53 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
54 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
55 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
56 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
57 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
58 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
59 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
60 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
61 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
62
63 std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
64
65 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
66 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
67 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
68 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
69
70 //
71 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
72 // Required to perform splitting, concatenation and permutations.
73 //
74 arm_compute::Tensor m_PermuteFirstOut;
75 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
76 std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
77 std::vector<arm_compute::ITensor*> m_SplitterOutputs;
78 std::vector<const arm_compute::ITensor*> m_ConcatInputs;
79 arm_compute::Tensor concat_out;
80
81 void FreeUnusedTensors();
82};
83
84arm_compute::Status
85NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
86 const TensorInfo& outputStateIn,
87 const TensorInfo& cellStateIn,
Mike Kelly12994962022-04-21 11:57:09 +010088 const TensorInfo& outputStateOut,
89 const TensorInfo& cellStateOut,
Cathal Corbettfd5bec42022-03-03 15:13:23 +000090 const TensorInfo& output,
Cathal Corbettfd5bec42022-03-03 15:13:23 +000091 const UnidirectionalSequenceLstmDescriptor& descriptor,
92 const LstmInputParamsInfo& paramsInfo);
93
94} //namespace armnn