blob: b56ae3ff52908fbe06566bfe88c1bc0cf551188f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "QuantizedLstmLayer.hpp"
6
7#include "LayerCloneBase.hpp"
8
Matthew Bentham39ef3e52020-01-20 10:09:09 +00009#include <armnn/QuantizedLstmParams.hpp>
James Conroyee18dc82019-07-17 11:27:46 +010010#include <armnn/TypesUtils.hpp>
11#include <backendsCommon/CpuTensorHandle.hpp>
12#include <backendsCommon/WorkloadFactory.hpp>
13
14namespace armnn
15{
16
17QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
18 : Layer(3, 2, LayerType::QuantizedLstm, name)
19{
20}
21
Derek Lamberti94a88d22019-12-10 21:12:59 +000022std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
James Conroyee18dc82019-07-17 11:27:46 +010023{
24 QuantizedLstmQueueDescriptor descriptor;
25
26 // QuantizedLstmLayer parameters - there are no optional params
27 descriptor.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights.get();
28 descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
29 descriptor.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights.get();
30 descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();
31
32 descriptor.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
33 descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
34 descriptor.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
35 descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();
36
37 descriptor.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias.get();
38 descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
39 descriptor.m_CellBias = m_QuantizedLstmParameters.m_CellBias.get();
40 descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();
41
Derek Lamberti94a88d22019-12-10 21:12:59 +000042 return factory.CreateQuantizedLstm(descriptor, PrepInfoAndDesc(descriptor));
James Conroyee18dc82019-07-17 11:27:46 +010043}
44
45QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
46{
47 auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());
48
49 layer->m_QuantizedLstmParameters.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights ?
50 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToInputWeights) : nullptr;
51 layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights ?
52 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToForgetWeights) : nullptr;
53 layer->m_QuantizedLstmParameters.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights ?
54 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToCellWeights) : nullptr;
55 layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights ?
56 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToOutputWeights) : nullptr;
57
58 layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights ?
59 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToInputWeights) : nullptr;
60 layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights
61 ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToForgetWeights) : nullptr;
62 layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights ?
63 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToCellWeights) : nullptr;
64 layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights
65 ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToOutputWeights) : nullptr;
66
67 layer->m_QuantizedLstmParameters.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias ?
68 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputGateBias) : nullptr;
69 layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias ?
70 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_ForgetGateBias) : nullptr;
71 layer->m_QuantizedLstmParameters.m_CellBias = m_QuantizedLstmParameters.m_CellBias ?
72 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_CellBias) : nullptr;
73 layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias ?
74 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_OutputGateBias) : nullptr;
75
76 return std::move(layer);
77}
78
79std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
80{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010081 ARMNN_ASSERT(inputShapes.size() == 3);
James Conroyee18dc82019-07-17 11:27:46 +010082
83 // Get input values for validation
84 unsigned int numBatches = inputShapes[0][0];
85 unsigned int outputSize = inputShapes[1][1];
86
87 std::vector<TensorShape> outShapes;
88 outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
89 outShapes.push_back(TensorShape({numBatches, outputSize})); // output
90
91 return outShapes;
92}
93
94void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
95{
96 VerifyLayerConnections(3, CHECK_LOCATION());
97
98 auto inferredShapes = InferOutputShapes(
99 {
100 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
101 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
102 GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape() // previousOutputIn
103 });
104
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100105 ARMNN_ASSERT(inferredShapes.size() == 2);
James Conroyee18dc82019-07-17 11:27:46 +0100106
107 // Check weights and bias for nullptr
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100108 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100109 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100110 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100111 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100112 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100113 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100114 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100115 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");
116
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100117 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100118 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100119 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100120 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100121 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100122 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100123 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100124 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");
125
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100126 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100127 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100128 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100129 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100130 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100131 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100132 ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
James Conroyee18dc82019-07-17 11:27:46 +0100133 "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");
134
135 // Check output TensorShape(s) match inferred shape
136 ConditionalThrowIfNotEqual<LayerValidationException>(
137 "QuantizedLstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
138 GetOutputSlot(0).GetTensorInfo().GetShape(),
139 inferredShapes[0]);
140
141 ConditionalThrowIfNotEqual<LayerValidationException>(
142 "QuantizedLstmLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
143 GetOutputSlot(1).GetTensorInfo().GetShape(),
144 inferredShapes[1]);
145}
146
147Layer::ConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef()
148{
149 return
150 {
151 m_QuantizedLstmParameters.m_InputToInputWeights,
152 m_QuantizedLstmParameters.m_InputToForgetWeights,
153 m_QuantizedLstmParameters.m_InputToCellWeights,
154 m_QuantizedLstmParameters.m_InputToOutputWeights,
155
156 m_QuantizedLstmParameters.m_RecurrentToInputWeights,
157 m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
158 m_QuantizedLstmParameters.m_RecurrentToCellWeights,
159 m_QuantizedLstmParameters.m_RecurrentToOutputWeights,
160
161 m_QuantizedLstmParameters.m_InputGateBias,
162 m_QuantizedLstmParameters.m_ForgetGateBias,
163 m_QuantizedLstmParameters.m_CellBias,
164 m_QuantizedLstmParameters.m_OutputGateBias
165 };
166}
167
168void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
169{
170 QuantizedLstmInputParams inputParams;
171
172 // InputToX weight tensors
173 ConstTensor inputToInputWeightsTensor;
174 if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
175 {
176 ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
177 m_QuantizedLstmParameters.m_InputToInputWeights->Map(true));
178 inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
179 inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
180 }
181
182 ConstTensor inputToForgetWeightsTensor;
183 if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
184 {
185 ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
186 m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true));
187 inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
188 inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
189 }
190
191 ConstTensor inputToCellWeightsTensor;
192 if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
193 {
194 ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
195 m_QuantizedLstmParameters.m_InputToCellWeights->Map(true));
196 inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
197 inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
198 }
199
200 ConstTensor inputToOutputWeightsTensor;
201 if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
202 {
203 ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
204 m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true));
205 inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
206 inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
207 }
208
209 // RecurrentToX weight tensors
210 ConstTensor recurrentToInputWeightsTensor;
211 if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
212 {
213 ConstTensor recurrentToInputWeightsTensorCopy(
214 m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
215 m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true));
216 recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
217 inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
218 }
219
220 ConstTensor recurrentToForgetWeightsTensor;
221 if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
222 {
223 ConstTensor recurrentToForgetWeightsTensorCopy(
224 m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
225 m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true));
226 recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
227 inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
228 }
229
230 ConstTensor recurrentToCellWeightsTensor;
231 if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
232 {
233 ConstTensor recurrentToCellWeightsTensorCopy(
234 m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
235 m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true));
236 recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
237 inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
238 }
239
240 ConstTensor recurrentToOutputWeightsTensor;
241 if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
242 {
243 ConstTensor recurrentToOutputWeightsTensorCopy(
244 m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
245 m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true));
246 recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
247 inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
248 }
249
250 // Bias tensors
251 ConstTensor inputGateBiasTensor;
252 if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
253 {
254 ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
255 m_QuantizedLstmParameters.m_InputGateBias->Map(true));
256 inputGateBiasTensor = inputGateBiasTensorCopy;
257 inputParams.m_InputGateBias = &inputGateBiasTensor;
258 }
259
260 ConstTensor forgetGateBiasTensor;
261 if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
262 {
263 ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
264 m_QuantizedLstmParameters.m_ForgetGateBias->Map(true));
265 forgetGateBiasTensor = forgetGateBiasTensorCopy;
266 inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
267 }
268
269 ConstTensor cellBiasTensor;
270 if (m_QuantizedLstmParameters.m_CellBias != nullptr)
271 {
272 ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
273 m_QuantizedLstmParameters.m_CellBias->Map(true));
274 cellBiasTensor = cellBiasTensorCopy;
275 inputParams.m_CellBias = &cellBiasTensor;
276 }
277
278 ConstTensor outputGateBiasTensor;
279 if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
280 {
281 ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
282 m_QuantizedLstmParameters.m_OutputGateBias->Map(true));
283 outputGateBiasTensor = outputGateBiasCopy;
284 inputParams.m_OutputGateBias = &outputGateBiasTensor;
285 }
286
287 visitor.VisitQuantizedLstmLayer(this, inputParams, GetName());
288}
289
290} // namespace armnn