blob: 6b0ad5c593829affae54567f67c2b7cce95ebe2b [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa01c577f2c2018-08-31 09:22:23 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
#pragma once

#include "TensorFwd.hpp"
#include "Exceptions.hpp"

#include <string>
telsoa01c577f2c2018-08-31 09:22:23 +01009
10namespace armnn
11{
12
13struct LstmInputParams
14{
15 LstmInputParams()
16 : m_InputToInputWeights(nullptr)
17 , m_InputToForgetWeights(nullptr)
18 , m_InputToCellWeights(nullptr)
19 , m_InputToOutputWeights(nullptr)
20 , m_RecurrentToInputWeights(nullptr)
21 , m_RecurrentToForgetWeights(nullptr)
22 , m_RecurrentToCellWeights(nullptr)
23 , m_RecurrentToOutputWeights(nullptr)
24 , m_CellToInputWeights(nullptr)
25 , m_CellToForgetWeights(nullptr)
26 , m_CellToOutputWeights(nullptr)
27 , m_InputGateBias(nullptr)
28 , m_ForgetGateBias(nullptr)
29 , m_CellBias(nullptr)
30 , m_OutputGateBias(nullptr)
31 , m_ProjectionWeights(nullptr)
32 , m_ProjectionBias(nullptr)
Jan Eilers38e05bd2019-06-26 13:10:09 +010033 , m_InputLayerNormWeights(nullptr)
34 , m_ForgetLayerNormWeights(nullptr)
35 , m_CellLayerNormWeights(nullptr)
36 , m_OutputLayerNormWeights(nullptr)
telsoa01c577f2c2018-08-31 09:22:23 +010037 {
38 }
39
40 const ConstTensor* m_InputToInputWeights;
41 const ConstTensor* m_InputToForgetWeights;
42 const ConstTensor* m_InputToCellWeights;
43 const ConstTensor* m_InputToOutputWeights;
44 const ConstTensor* m_RecurrentToInputWeights;
45 const ConstTensor* m_RecurrentToForgetWeights;
46 const ConstTensor* m_RecurrentToCellWeights;
47 const ConstTensor* m_RecurrentToOutputWeights;
48 const ConstTensor* m_CellToInputWeights;
49 const ConstTensor* m_CellToForgetWeights;
50 const ConstTensor* m_CellToOutputWeights;
51 const ConstTensor* m_InputGateBias;
52 const ConstTensor* m_ForgetGateBias;
53 const ConstTensor* m_CellBias;
54 const ConstTensor* m_OutputGateBias;
55 const ConstTensor* m_ProjectionWeights;
56 const ConstTensor* m_ProjectionBias;
Jan Eilers38e05bd2019-06-26 13:10:09 +010057 const ConstTensor* m_InputLayerNormWeights;
58 const ConstTensor* m_ForgetLayerNormWeights;
59 const ConstTensor* m_CellLayerNormWeights;
60 const ConstTensor* m_OutputLayerNormWeights;
telsoa01c577f2c2018-08-31 09:22:23 +010061};
62
Jan Eilersd01a83c2019-07-03 18:20:40 +010063struct LstmInputParamsInfo
64{
65 LstmInputParamsInfo()
66 : m_InputToInputWeights(nullptr)
67 , m_InputToForgetWeights(nullptr)
68 , m_InputToCellWeights(nullptr)
69 , m_InputToOutputWeights(nullptr)
70 , m_RecurrentToInputWeights(nullptr)
71 , m_RecurrentToForgetWeights(nullptr)
72 , m_RecurrentToCellWeights(nullptr)
73 , m_RecurrentToOutputWeights(nullptr)
74 , m_CellToInputWeights(nullptr)
75 , m_CellToForgetWeights(nullptr)
76 , m_CellToOutputWeights(nullptr)
77 , m_InputGateBias(nullptr)
78 , m_ForgetGateBias(nullptr)
79 , m_CellBias(nullptr)
80 , m_OutputGateBias(nullptr)
81 , m_ProjectionWeights(nullptr)
82 , m_ProjectionBias(nullptr)
83 , m_InputLayerNormWeights(nullptr)
84 , m_ForgetLayerNormWeights(nullptr)
85 , m_CellLayerNormWeights(nullptr)
86 , m_OutputLayerNormWeights(nullptr)
87 {
88 }
89 const TensorInfo* m_InputToInputWeights;
90 const TensorInfo* m_InputToForgetWeights;
91 const TensorInfo* m_InputToCellWeights;
92 const TensorInfo* m_InputToOutputWeights;
93 const TensorInfo* m_RecurrentToInputWeights;
94 const TensorInfo* m_RecurrentToForgetWeights;
95 const TensorInfo* m_RecurrentToCellWeights;
96 const TensorInfo* m_RecurrentToOutputWeights;
97 const TensorInfo* m_CellToInputWeights;
98 const TensorInfo* m_CellToForgetWeights;
99 const TensorInfo* m_CellToOutputWeights;
100 const TensorInfo* m_InputGateBias;
101 const TensorInfo* m_ForgetGateBias;
102 const TensorInfo* m_CellBias;
103 const TensorInfo* m_OutputGateBias;
104 const TensorInfo* m_ProjectionWeights;
105 const TensorInfo* m_ProjectionBias;
106 const TensorInfo* m_InputLayerNormWeights;
107 const TensorInfo* m_ForgetLayerNormWeights;
108 const TensorInfo* m_CellLayerNormWeights;
109 const TensorInfo* m_OutputLayerNormWeights;
110
Francis Murtaghbb590b42019-08-14 09:51:36 +0100111 const TensorInfo& Deref(const TensorInfo* tensorInfo) const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100112 {
113 if (tensorInfo != nullptr)
114 {
115 const TensorInfo &temp = *tensorInfo;
116 return temp;
117 }
118 throw InvalidArgumentException("Can't dereference a null pointer");
119 }
120
Francis Murtaghbb590b42019-08-14 09:51:36 +0100121 const TensorInfo& GetInputToInputWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100122 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100123 return Deref(m_InputToInputWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100124 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100125 const TensorInfo& GetInputToForgetWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100126 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100127 return Deref(m_InputToForgetWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100128 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100129 const TensorInfo& GetInputToCellWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100130 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100131 return Deref(m_InputToCellWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100132 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100133 const TensorInfo& GetInputToOutputWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100134 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100135 return Deref(m_InputToOutputWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100136 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100137 const TensorInfo& GetRecurrentToInputWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100138 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100139 return Deref(m_RecurrentToInputWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100140 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100141 const TensorInfo& GetRecurrentToForgetWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100142 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100143 return Deref(m_RecurrentToForgetWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100144 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100145 const TensorInfo& GetRecurrentToCellWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100146 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100147 return Deref(m_RecurrentToCellWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100148 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100149 const TensorInfo& GetRecurrentToOutputWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100150 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100151 return Deref(m_RecurrentToOutputWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100152 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100153 const TensorInfo& GetCellToInputWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100154 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100155 return Deref(m_CellToInputWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100156 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100157 const TensorInfo& GetCellToForgetWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100158 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100159 return Deref(m_CellToForgetWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100160 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100161 const TensorInfo& GetCellToOutputWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100162 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100163 return Deref(m_CellToOutputWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100164 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100165 const TensorInfo& GetInputGateBias() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100166 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100167 return Deref(m_InputGateBias);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100168 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100169 const TensorInfo& GetForgetGateBias() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100170 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100171 return Deref(m_ForgetGateBias);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100172 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100173 const TensorInfo& GetCellBias() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100174 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100175 return Deref(m_CellBias);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100176 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100177 const TensorInfo& GetOutputGateBias() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100178 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100179 return Deref(m_OutputGateBias);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100180 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100181 const TensorInfo& GetProjectionWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100182 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100183 return Deref(m_ProjectionWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100184 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100185 const TensorInfo& GetProjectionBias() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100186 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100187 return Deref(m_ProjectionBias);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100188 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100189 const TensorInfo& GetInputLayerNormWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100190 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100191 return Deref(m_InputLayerNormWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100192 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100193 const TensorInfo& GetForgetLayerNormWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100194 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100195 return Deref(m_ForgetLayerNormWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100196 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100197 const TensorInfo& GetCellLayerNormWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100198 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100199 return Deref(m_CellLayerNormWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100200 }
Francis Murtaghbb590b42019-08-14 09:51:36 +0100201 const TensorInfo& GetOutputLayerNormWeights() const
Jan Eilersd01a83c2019-07-03 18:20:40 +0100202 {
Francis Murtaghbb590b42019-08-14 09:51:36 +0100203 return Deref(m_OutputLayerNormWeights);
Jan Eilersd01a83c2019-07-03 18:20:40 +0100204 }
205};
206
telsoa01c577f2c2018-08-31 09:22:23 +0100207} // namespace armnn
208