Laurent Carlier | 749294b | 2020-06-01 09:03:17 +0100 | [diff] [blame] | 1 | // |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 4 | // |
| 5 | #pragma once |
| 6 | |
| 7 | #include "TensorFwd.hpp" |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 8 | #include "Exceptions.hpp" |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 9 | |
| 10 | namespace armnn |
| 11 | { |
| 12 | |
| 13 | struct LstmInputParams |
| 14 | { |
| 15 | LstmInputParams() |
| 16 | : m_InputToInputWeights(nullptr) |
| 17 | , m_InputToForgetWeights(nullptr) |
| 18 | , m_InputToCellWeights(nullptr) |
| 19 | , m_InputToOutputWeights(nullptr) |
| 20 | , m_RecurrentToInputWeights(nullptr) |
| 21 | , m_RecurrentToForgetWeights(nullptr) |
| 22 | , m_RecurrentToCellWeights(nullptr) |
| 23 | , m_RecurrentToOutputWeights(nullptr) |
| 24 | , m_CellToInputWeights(nullptr) |
| 25 | , m_CellToForgetWeights(nullptr) |
| 26 | , m_CellToOutputWeights(nullptr) |
| 27 | , m_InputGateBias(nullptr) |
| 28 | , m_ForgetGateBias(nullptr) |
| 29 | , m_CellBias(nullptr) |
| 30 | , m_OutputGateBias(nullptr) |
| 31 | , m_ProjectionWeights(nullptr) |
| 32 | , m_ProjectionBias(nullptr) |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 33 | , m_InputLayerNormWeights(nullptr) |
| 34 | , m_ForgetLayerNormWeights(nullptr) |
| 35 | , m_CellLayerNormWeights(nullptr) |
| 36 | , m_OutputLayerNormWeights(nullptr) |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 37 | { |
| 38 | } |
| 39 | |
| 40 | const ConstTensor* m_InputToInputWeights; |
| 41 | const ConstTensor* m_InputToForgetWeights; |
| 42 | const ConstTensor* m_InputToCellWeights; |
| 43 | const ConstTensor* m_InputToOutputWeights; |
| 44 | const ConstTensor* m_RecurrentToInputWeights; |
| 45 | const ConstTensor* m_RecurrentToForgetWeights; |
| 46 | const ConstTensor* m_RecurrentToCellWeights; |
| 47 | const ConstTensor* m_RecurrentToOutputWeights; |
| 48 | const ConstTensor* m_CellToInputWeights; |
| 49 | const ConstTensor* m_CellToForgetWeights; |
| 50 | const ConstTensor* m_CellToOutputWeights; |
| 51 | const ConstTensor* m_InputGateBias; |
| 52 | const ConstTensor* m_ForgetGateBias; |
| 53 | const ConstTensor* m_CellBias; |
| 54 | const ConstTensor* m_OutputGateBias; |
| 55 | const ConstTensor* m_ProjectionWeights; |
| 56 | const ConstTensor* m_ProjectionBias; |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 57 | const ConstTensor* m_InputLayerNormWeights; |
| 58 | const ConstTensor* m_ForgetLayerNormWeights; |
| 59 | const ConstTensor* m_CellLayerNormWeights; |
| 60 | const ConstTensor* m_OutputLayerNormWeights; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 61 | }; |
| 62 | |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 63 | struct LstmInputParamsInfo |
| 64 | { |
| 65 | LstmInputParamsInfo() |
| 66 | : m_InputToInputWeights(nullptr) |
| 67 | , m_InputToForgetWeights(nullptr) |
| 68 | , m_InputToCellWeights(nullptr) |
| 69 | , m_InputToOutputWeights(nullptr) |
| 70 | , m_RecurrentToInputWeights(nullptr) |
| 71 | , m_RecurrentToForgetWeights(nullptr) |
| 72 | , m_RecurrentToCellWeights(nullptr) |
| 73 | , m_RecurrentToOutputWeights(nullptr) |
| 74 | , m_CellToInputWeights(nullptr) |
| 75 | , m_CellToForgetWeights(nullptr) |
| 76 | , m_CellToOutputWeights(nullptr) |
| 77 | , m_InputGateBias(nullptr) |
| 78 | , m_ForgetGateBias(nullptr) |
| 79 | , m_CellBias(nullptr) |
| 80 | , m_OutputGateBias(nullptr) |
| 81 | , m_ProjectionWeights(nullptr) |
| 82 | , m_ProjectionBias(nullptr) |
| 83 | , m_InputLayerNormWeights(nullptr) |
| 84 | , m_ForgetLayerNormWeights(nullptr) |
| 85 | , m_CellLayerNormWeights(nullptr) |
| 86 | , m_OutputLayerNormWeights(nullptr) |
| 87 | { |
| 88 | } |
| 89 | const TensorInfo* m_InputToInputWeights; |
| 90 | const TensorInfo* m_InputToForgetWeights; |
| 91 | const TensorInfo* m_InputToCellWeights; |
| 92 | const TensorInfo* m_InputToOutputWeights; |
| 93 | const TensorInfo* m_RecurrentToInputWeights; |
| 94 | const TensorInfo* m_RecurrentToForgetWeights; |
| 95 | const TensorInfo* m_RecurrentToCellWeights; |
| 96 | const TensorInfo* m_RecurrentToOutputWeights; |
| 97 | const TensorInfo* m_CellToInputWeights; |
| 98 | const TensorInfo* m_CellToForgetWeights; |
| 99 | const TensorInfo* m_CellToOutputWeights; |
| 100 | const TensorInfo* m_InputGateBias; |
| 101 | const TensorInfo* m_ForgetGateBias; |
| 102 | const TensorInfo* m_CellBias; |
| 103 | const TensorInfo* m_OutputGateBias; |
| 104 | const TensorInfo* m_ProjectionWeights; |
| 105 | const TensorInfo* m_ProjectionBias; |
| 106 | const TensorInfo* m_InputLayerNormWeights; |
| 107 | const TensorInfo* m_ForgetLayerNormWeights; |
| 108 | const TensorInfo* m_CellLayerNormWeights; |
| 109 | const TensorInfo* m_OutputLayerNormWeights; |
| 110 | |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 111 | const TensorInfo& Deref(const TensorInfo* tensorInfo) const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 112 | { |
| 113 | if (tensorInfo != nullptr) |
| 114 | { |
| 115 | const TensorInfo &temp = *tensorInfo; |
| 116 | return temp; |
| 117 | } |
| 118 | throw InvalidArgumentException("Can't dereference a null pointer"); |
| 119 | } |
| 120 | |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 121 | const TensorInfo& GetInputToInputWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 122 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 123 | return Deref(m_InputToInputWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 124 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 125 | const TensorInfo& GetInputToForgetWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 126 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 127 | return Deref(m_InputToForgetWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 128 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 129 | const TensorInfo& GetInputToCellWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 130 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 131 | return Deref(m_InputToCellWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 132 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 133 | const TensorInfo& GetInputToOutputWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 134 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 135 | return Deref(m_InputToOutputWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 136 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 137 | const TensorInfo& GetRecurrentToInputWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 138 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 139 | return Deref(m_RecurrentToInputWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 140 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 141 | const TensorInfo& GetRecurrentToForgetWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 142 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 143 | return Deref(m_RecurrentToForgetWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 144 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 145 | const TensorInfo& GetRecurrentToCellWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 146 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 147 | return Deref(m_RecurrentToCellWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 148 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 149 | const TensorInfo& GetRecurrentToOutputWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 150 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 151 | return Deref(m_RecurrentToOutputWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 152 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 153 | const TensorInfo& GetCellToInputWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 154 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 155 | return Deref(m_CellToInputWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 156 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 157 | const TensorInfo& GetCellToForgetWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 158 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 159 | return Deref(m_CellToForgetWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 160 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 161 | const TensorInfo& GetCellToOutputWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 162 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 163 | return Deref(m_CellToOutputWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 164 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 165 | const TensorInfo& GetInputGateBias() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 166 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 167 | return Deref(m_InputGateBias); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 168 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 169 | const TensorInfo& GetForgetGateBias() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 170 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 171 | return Deref(m_ForgetGateBias); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 172 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 173 | const TensorInfo& GetCellBias() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 174 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 175 | return Deref(m_CellBias); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 176 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 177 | const TensorInfo& GetOutputGateBias() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 178 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 179 | return Deref(m_OutputGateBias); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 180 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 181 | const TensorInfo& GetProjectionWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 182 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 183 | return Deref(m_ProjectionWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 184 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 185 | const TensorInfo& GetProjectionBias() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 186 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 187 | return Deref(m_ProjectionBias); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 188 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 189 | const TensorInfo& GetInputLayerNormWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 190 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 191 | return Deref(m_InputLayerNormWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 192 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 193 | const TensorInfo& GetForgetLayerNormWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 194 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 195 | return Deref(m_ForgetLayerNormWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 196 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 197 | const TensorInfo& GetCellLayerNormWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 198 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 199 | return Deref(m_CellLayerNormWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 200 | } |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 201 | const TensorInfo& GetOutputLayerNormWeights() const |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 202 | { |
Francis Murtagh | bb590b4 | 2019-08-14 09:51:36 +0100 | [diff] [blame] | 203 | return Deref(m_OutputLayerNormWeights); |
Jan Eilers | d01a83c | 2019-07-03 18:20:40 +0100 | [diff] [blame] | 204 | } |
| 205 | }; |
| 206 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 207 | } // namespace armnn |
| 208 | |