blob: 10c2ecbd19a6e2c27cd9bfdc1b68f1a23a309816 [file] [log] [blame]
Cathal Corbettfd5bec42022-03-03 15:13:23 +00001//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/Descriptors.hpp>
9#include <armnn/LstmParams.hpp>
10#include <armnn/backends/Workload.hpp>
11#include <armnn/backends/WorkloadData.hpp>
12
13#include "arm_compute/graph/Tensor.h"
14#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"
15#include "arm_compute/runtime/NEON/functions/NEPermute.h"
16#include "arm_compute/runtime/NEON/functions/NESplit.h"
17#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
18
19namespace armnn
20{
21
22class NeonUnidirectionalSequenceLstmFloatWorkload : public FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor>
23{
24public:
25 NeonUnidirectionalSequenceLstmFloatWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
26 const WorkloadInfo& info);
27 virtual void Execute() const override;
28
29private:
30
31 //
32 // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
33 //
34 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
35 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
36 mutable std::vector<std::unique_ptr<arm_compute::NELSTMLayer>> m_Layers;
37 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
38 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
39
40 //
41 // ACL LSTM arm_compute::Tensors.
42 //
43 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
44 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
45 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
46 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
47 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
48 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
51 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
52 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
53 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
54 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
55 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
56 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
57 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
58 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
59 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
60
61 std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
62
63 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
64 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
65 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
66 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
67
68 //
69 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
70 // Required to perform splitting, concatenation and permutations.
71 //
72 arm_compute::Tensor m_PermuteFirstOut;
73 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
74 std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
75 std::vector<arm_compute::ITensor*> m_SplitterOutputs;
76 std::vector<const arm_compute::ITensor*> m_ConcatInputs;
77 arm_compute::Tensor concat_out;
78
79 void FreeUnusedTensors();
80};
81
82arm_compute::Status
83NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
84 const TensorInfo& outputStateIn,
85 const TensorInfo& cellStateIn,
86 const TensorInfo& output,
87 const Optional<TensorInfo>& hiddenStateOutput,
88 const Optional<TensorInfo>& cellStateOutput,
89 const UnidirectionalSequenceLstmDescriptor& descriptor,
90 const LstmInputParamsInfo& paramsInfo);
91
92} //namespace armnn