blob: 7c8fb76895e78def1a8f788b72296c003c025269 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/TypesUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +00009#include <armnn/backends/WorkloadData.hpp>
Narumol Prangnawarate5339e72021-07-28 17:33:28 +010010
11#include "Encoders.hpp"
12#include "Decoders.hpp"
13
14namespace armnn
15{
16
17void LstmImpl(const LstmDescriptor& descriptor,
18 const TensorInfo& inputInfo,
19 const TensorInfo& outputInfo,
20 const TensorShape& inputToOutputWeightsShape,
21 const TensorShape& recurrentToOutputWeightsShape,
22 std::unique_ptr<Decoder<float>>& inputData,
23 std::unique_ptr<Decoder<float>>& outputStateIn,
24 std::unique_ptr<Decoder<float>>& cellStateIn,
25 std::unique_ptr<Encoder<float>>& outputStateOut,
26 std::unique_ptr<Encoder<float>>& cellStateOut,
27 std::unique_ptr<Encoder<float>>& output,
28 std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
29 std::unique_ptr<Decoder<float>>& outputDecoder,
30 std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
31 std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
32 std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
33 std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
34 std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
35 std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
36 std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
37 std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
38 std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
39 std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
40 std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
41 std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
42 std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
43 std::unique_ptr<Decoder<float>>& cellBiasTensor,
44 std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
45 std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
46 std::unique_ptr<Decoder<float>>& projectionBiasTensor,
47 std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
48 std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
49 std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
50 std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
51 std::unique_ptr<Encoder<float>>& inputGateScratch,
52 std::unique_ptr<Encoder<float>>& cellScratch,
53 std::unique_ptr<Encoder<float>>& forgetGateScratch,
54 std::unique_ptr<Encoder<float>>& outputGateScratch,
55 std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
56 std::unique_ptr<Decoder<float>>& cellScratchDecoder,
57 std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
58 std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
59 float layerNormEpsilon);
60
61} //namespace armnn