blob: 1353a06d9ff9f7f1ad693ca689f67ab90ec54d1d [file] [log] [blame]
James Conroyee18dc82019-07-17 11:27:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <Layer.hpp>
8
9namespace armnn
10{
11
12class ScopedCpuTensorHandle;
13
14struct QuantizedLstmParameters
15{
16 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
17 std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeights;
18 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
19 std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeights;
20 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
21 std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeights;
22 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
23 std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeights;
24
25 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
26 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeights;
27 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
28 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeights;
29 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
30 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeights;
31 /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
32 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeights;
33
34 /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
35 std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBias;
36 /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
37 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBias;
38 /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
39 std::unique_ptr<ScopedCpuTensorHandle> m_CellBias;
40 /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
41 std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBias;
42};
43
44/// This layer represents a QuantizedLstm operation.
45class QuantizedLstmLayer : public Layer
46{
47public:
48
49 QuantizedLstmParameters m_QuantizedLstmParameters;
50
51 /// Makes a workload for the QuantizedLstm type.
52 /// @param [in] graph The graph where this layer can be found.
53 /// @param [in] factory The workload factory which will create the workload.
54 /// @return A pointer to the created workload, or nullptr if not created.
Derek Lamberti94a88d22019-12-10 21:12:59 +000055 virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override;
James Conroyee18dc82019-07-17 11:27:46 +010056
57 /// Creates a dynamically-allocated copy of this layer.
58 /// @param [in] graph The graph into which this layer is being cloned.
59 QuantizedLstmLayer* Clone(Graph& graph) const override;
60
61 /// Check if the input tensor shape(s)
62 /// will lead to a valid configuration of @ref QuantizedLstmLayer.
Teresa Charlincdc01492020-06-09 18:00:20 +010063 /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated.
64 void ValidateTensorShapesFromInputs(
65 ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly) override;
James Conroyee18dc82019-07-17 11:27:46 +010066
67 /// By default returns inputShapes if the number of inputs are equal to number of outputs,
68 /// otherwise infers the output shapes from given input shapes and layer properties.
69 /// @param [in] inputShapes The input shapes layer has.
70 /// @return A vector to the inferred output shape.
71 std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
72
73 void Accept(ILayerVisitor& visitor) const override;
74
75protected:
76 /// Constructor to create a QuantizedLstmLayer.
77 /// @param [in] name Optional name for the layer.
78 QuantizedLstmLayer(const char* name);
79
80 /// Default destructor
81 ~QuantizedLstmLayer() = default;
82
83 /// Retrieve the handles to the constant values stored by the layer.
84 /// @return A vector of the constant tensors stored by this layer.
85 Layer::ConstantTensors GetConstantTensorsByRef() override;
86};
87
88} // namespace armnn