blob: 3809ea875fdb0fd8cd47bd90829eca5854c71a35 [file] [log] [blame]
Narumol Prangnawarat8ed39ae2021-07-15 16:16:25 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "LayerWithParameters.hpp"
8
9namespace armnn
10{
11
/// Retained forward declaration. It is not referenced by the declarations in this header,
/// but including files may depend on it, so it is kept for backward compatibility.
class ScopedTensorHandle;
/// Forward declaration of the handle type that every parameter struct below holds through
/// std::shared_ptr. An incomplete type is sufficient for declaring shared_ptr members, and
/// declaring it here means this header does not depend on a transitive include to name it.
class ConstTensorHandle;
13
14struct LstmOptLayerNormParameters
15{
16 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
17 std::shared_ptr<ConstTensorHandle> m_InputLayerNormWeights;
18 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
19 std::shared_ptr<ConstTensorHandle> m_ForgetLayerNormWeights;
20 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
21 std::shared_ptr<ConstTensorHandle> m_CellLayerNormWeights;
22 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
23 std::shared_ptr<ConstTensorHandle> m_OutputLayerNormWeights;
24};
25
26struct LstmOptCifgParameters
27{
28 /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
29 std::shared_ptr<ConstTensorHandle> m_InputToInputWeights;
30 /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
31 std::shared_ptr<ConstTensorHandle> m_RecurrentToInputWeights;
32 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
33 std::shared_ptr<ConstTensorHandle> m_InputGateBias;
34};
35
36struct LstmOptProjectionParameters
37{
38 /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
39 std::shared_ptr<ConstTensorHandle> m_ProjectionWeights;
40 /// A unique pointer to represent 1D weights tensor with dimensions [output_size].
41 std::shared_ptr<ConstTensorHandle> m_ProjectionBias;
42};
43
44struct LstmOptPeepholeParameters
45{
46 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
47 std::shared_ptr<ConstTensorHandle> m_CellToInputWeights;
48 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
49 std::shared_ptr<ConstTensorHandle> m_CellToForgetWeights;
50 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
51 std::shared_ptr<ConstTensorHandle> m_CellToOutputWeights;
52};
53
54struct LstmBasicParameters
55{
56 /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
57 std::shared_ptr<ConstTensorHandle> m_InputToForgetWeights;
58 /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
59 std::shared_ptr<ConstTensorHandle> m_InputToCellWeights;
60 /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
61 std::shared_ptr<ConstTensorHandle> m_InputToOutputWeights;
62 /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
63 std::shared_ptr<ConstTensorHandle> m_RecurrentToForgetWeights;
64 /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
65 std::shared_ptr<ConstTensorHandle> m_RecurrentToCellWeights;
66 /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
67 std::shared_ptr<ConstTensorHandle> m_RecurrentToOutputWeights;
68 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
69 std::shared_ptr<ConstTensorHandle> m_ForgetGateBias;
70 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
71 std::shared_ptr<ConstTensorHandle> m_CellBias;
72 /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
73 std::shared_ptr<ConstTensorHandle> m_OutputGateBias;
74};
75
76} // namespace