blob: f0122589a480789152c6e0e6cdce0fa643f06cc3 [file] [log] [blame]
1//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Descriptors.hpp>
9#include <armnn/LstmParams.hpp>
10#include <armnn/backends/Workload.hpp>
11#include <armnn/backends/WorkloadData.hpp>
12#include "NeonBaseWorkload.hpp"
13
14#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
15#include "arm_compute/runtime/NEON/functions/NEPermute.h"
16#include "arm_compute/runtime/NEON/functions/NESplit.h"
17#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
18
19namespace armnn
20{
21
// Neon workload for the UnidirectionalSequenceLstm queue descriptor.
// Derives from NeonBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor> and holds the ACL
// functions and tensors declared below.
// NOTE(review): the member types (permute, splitter, NEQLSTMLayer list, concatenate, permute)
// suggest the input sequence is permuted, split, run through LSTM layers, concatenated and
// permuted back -- confirm against the implementation file.
22class NeonUnidirectionalSequenceLstmWorkload : public NeonBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>
23{
24public:
 // Constructs the workload from its queue descriptor and workload info.
 // NOTE(review): presumably this is where the ACL functions below are created and
 // configured -- confirm in the implementation file.
25 NeonUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
26 const WorkloadInfo& info);
 // Overrides the base class Execute(). The method is const, which is presumably why the
 // ACL function members below are declared mutable -- TODO confirm.
27 virtual void Execute() const override;
28
29private:
30
31 //
32 // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
33 //
34 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
 // Held through the arm_compute::IFunction base type rather than a concrete class;
 // presumably an NESplit (its header is included by this file) -- TODO confirm.
35 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
 // Collection of NEQLSTMLayer functions; presumably one per time step -- TODO confirm.
36 mutable std::vector<std::unique_ptr<arm_compute::NEQLSTMLayer>> m_Layers;
37 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
38 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
39
40 //
41 // ACL LSTM arm_compute::Tensors.
42 //
 // Each tensor is individually heap-allocated and owned through std::unique_ptr.
 // NOTE(review): presumably tensors for optional LSTM features stay null when the
 // feature is disabled -- confirm against the constructor.
43 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
44 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
45 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
46 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
47 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
48 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
51 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
52 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
53 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
54 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
55 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
56 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
57 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
58 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
59 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
60
 // Layer-normalisation weight tensors, one per gate (going by the member names).
61 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
62 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
63 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
64 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
65
66 //
67 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
68 // Required to perform splitting, concatenation and permutations.
69 //
 // These are stored by value, unlike the weight tensors above.
70 arm_compute::Tensor m_PermuteFirstOut;
71 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
72 std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
 // Raw ITensor pointer lists.
 // NOTE(review): presumably non-owning views into m_SplitterOutputsTensors and
 // m_ConcatInputsTensors, in the form the ACL split/concatenate functions take -- confirm.
73 std::vector<arm_compute::ITensor*> m_SplitterOutputs;
74 std::vector<const arm_compute::ITensor*> m_ConcatInputs;
 // NOTE(review): unlike every other data member this one has no m_ prefix. Presumably the
 // output of m_Concat -- confirm; the definition file presumably refers to it by this name,
 // so a rename would have to be made in both places.
75 arm_compute::Tensor concat_out;
76
 // Declared here, defined elsewhere.
 // NOTE(review): the name suggests it releases tensors that are no longer needed -- confirm
 // when and from where it is called.
77 void FreeUnusedTensors();
78};
79
// Validation entry point for the Neon UnidirectionalSequenceLstm workload (declaration only;
// the definition is not in this header).
// Takes the TensorInfo of the input, the incoming and outgoing output-state and cell-state
// tensors and the output, together with the layer descriptor and the LstmInputParamsInfo,
// and returns an arm_compute::Status.
// NOTE(review): presumably the returned status reports whether ACL can run this
// configuration -- confirm against the definition and its callers.
80arm_compute::Status
81NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
82 const TensorInfo& outputStateIn,
83 const TensorInfo& cellStateIn,
84 const TensorInfo& outputStateOut,
85 const TensorInfo& cellStateOut,
86 const TensorInfo& output,
87 const UnidirectionalSequenceLstmDescriptor& descriptor,
88 const LstmInputParamsInfo& paramsInfo);
89
90} //namespace armnn