| // |
| // Copyright © 2017 Arm Ltd. All rights reserved. |
| // SPDX-License-Identifier: MIT |
| // |
| #pragma once |
| |
| #include "TensorFwd.hpp" |
| #include "Exceptions.hpp" |
| |
| namespace armnn |
| { |
| |
| struct LstmInputParams |
| { |
| LstmInputParams() |
| : m_InputToInputWeights(nullptr) |
| , m_InputToForgetWeights(nullptr) |
| , m_InputToCellWeights(nullptr) |
| , m_InputToOutputWeights(nullptr) |
| , m_RecurrentToInputWeights(nullptr) |
| , m_RecurrentToForgetWeights(nullptr) |
| , m_RecurrentToCellWeights(nullptr) |
| , m_RecurrentToOutputWeights(nullptr) |
| , m_CellToInputWeights(nullptr) |
| , m_CellToForgetWeights(nullptr) |
| , m_CellToOutputWeights(nullptr) |
| , m_InputGateBias(nullptr) |
| , m_ForgetGateBias(nullptr) |
| , m_CellBias(nullptr) |
| , m_OutputGateBias(nullptr) |
| , m_ProjectionWeights(nullptr) |
| , m_ProjectionBias(nullptr) |
| , m_InputLayerNormWeights(nullptr) |
| , m_ForgetLayerNormWeights(nullptr) |
| , m_CellLayerNormWeights(nullptr) |
| , m_OutputLayerNormWeights(nullptr) |
| { |
| } |
| |
| const ConstTensor* m_InputToInputWeights; |
| const ConstTensor* m_InputToForgetWeights; |
| const ConstTensor* m_InputToCellWeights; |
| const ConstTensor* m_InputToOutputWeights; |
| const ConstTensor* m_RecurrentToInputWeights; |
| const ConstTensor* m_RecurrentToForgetWeights; |
| const ConstTensor* m_RecurrentToCellWeights; |
| const ConstTensor* m_RecurrentToOutputWeights; |
| const ConstTensor* m_CellToInputWeights; |
| const ConstTensor* m_CellToForgetWeights; |
| const ConstTensor* m_CellToOutputWeights; |
| const ConstTensor* m_InputGateBias; |
| const ConstTensor* m_ForgetGateBias; |
| const ConstTensor* m_CellBias; |
| const ConstTensor* m_OutputGateBias; |
| const ConstTensor* m_ProjectionWeights; |
| const ConstTensor* m_ProjectionBias; |
| const ConstTensor* m_InputLayerNormWeights; |
| const ConstTensor* m_ForgetLayerNormWeights; |
| const ConstTensor* m_CellLayerNormWeights; |
| const ConstTensor* m_OutputLayerNormWeights; |
| }; |
| |
| struct LstmInputParamsInfo |
| { |
| LstmInputParamsInfo() |
| : m_InputToInputWeights(nullptr) |
| , m_InputToForgetWeights(nullptr) |
| , m_InputToCellWeights(nullptr) |
| , m_InputToOutputWeights(nullptr) |
| , m_RecurrentToInputWeights(nullptr) |
| , m_RecurrentToForgetWeights(nullptr) |
| , m_RecurrentToCellWeights(nullptr) |
| , m_RecurrentToOutputWeights(nullptr) |
| , m_CellToInputWeights(nullptr) |
| , m_CellToForgetWeights(nullptr) |
| , m_CellToOutputWeights(nullptr) |
| , m_InputGateBias(nullptr) |
| , m_ForgetGateBias(nullptr) |
| , m_CellBias(nullptr) |
| , m_OutputGateBias(nullptr) |
| , m_ProjectionWeights(nullptr) |
| , m_ProjectionBias(nullptr) |
| , m_InputLayerNormWeights(nullptr) |
| , m_ForgetLayerNormWeights(nullptr) |
| , m_CellLayerNormWeights(nullptr) |
| , m_OutputLayerNormWeights(nullptr) |
| { |
| } |
| const TensorInfo* m_InputToInputWeights; |
| const TensorInfo* m_InputToForgetWeights; |
| const TensorInfo* m_InputToCellWeights; |
| const TensorInfo* m_InputToOutputWeights; |
| const TensorInfo* m_RecurrentToInputWeights; |
| const TensorInfo* m_RecurrentToForgetWeights; |
| const TensorInfo* m_RecurrentToCellWeights; |
| const TensorInfo* m_RecurrentToOutputWeights; |
| const TensorInfo* m_CellToInputWeights; |
| const TensorInfo* m_CellToForgetWeights; |
| const TensorInfo* m_CellToOutputWeights; |
| const TensorInfo* m_InputGateBias; |
| const TensorInfo* m_ForgetGateBias; |
| const TensorInfo* m_CellBias; |
| const TensorInfo* m_OutputGateBias; |
| const TensorInfo* m_ProjectionWeights; |
| const TensorInfo* m_ProjectionBias; |
| const TensorInfo* m_InputLayerNormWeights; |
| const TensorInfo* m_ForgetLayerNormWeights; |
| const TensorInfo* m_CellLayerNormWeights; |
| const TensorInfo* m_OutputLayerNormWeights; |
| |
| const TensorInfo& Deref(const TensorInfo* tensorInfo) const |
| { |
| if (tensorInfo != nullptr) |
| { |
| const TensorInfo &temp = *tensorInfo; |
| return temp; |
| } |
| throw InvalidArgumentException("Can't dereference a null pointer"); |
| } |
| |
| const TensorInfo& GetInputToInputWeights() const |
| { |
| return Deref(m_InputToInputWeights); |
| } |
| const TensorInfo& GetInputToForgetWeights() const |
| { |
| return Deref(m_InputToForgetWeights); |
| } |
| const TensorInfo& GetInputToCellWeights() const |
| { |
| return Deref(m_InputToCellWeights); |
| } |
| const TensorInfo& GetInputToOutputWeights() const |
| { |
| return Deref(m_InputToOutputWeights); |
| } |
| const TensorInfo& GetRecurrentToInputWeights() const |
| { |
| return Deref(m_RecurrentToInputWeights); |
| } |
| const TensorInfo& GetRecurrentToForgetWeights() const |
| { |
| return Deref(m_RecurrentToForgetWeights); |
| } |
| const TensorInfo& GetRecurrentToCellWeights() const |
| { |
| return Deref(m_RecurrentToCellWeights); |
| } |
| const TensorInfo& GetRecurrentToOutputWeights() const |
| { |
| return Deref(m_RecurrentToOutputWeights); |
| } |
| const TensorInfo& GetCellToInputWeights() const |
| { |
| return Deref(m_CellToInputWeights); |
| } |
| const TensorInfo& GetCellToForgetWeights() const |
| { |
| return Deref(m_CellToForgetWeights); |
| } |
| const TensorInfo& GetCellToOutputWeights() const |
| { |
| return Deref(m_CellToOutputWeights); |
| } |
| const TensorInfo& GetInputGateBias() const |
| { |
| return Deref(m_InputGateBias); |
| } |
| const TensorInfo& GetForgetGateBias() const |
| { |
| return Deref(m_ForgetGateBias); |
| } |
| const TensorInfo& GetCellBias() const |
| { |
| return Deref(m_CellBias); |
| } |
| const TensorInfo& GetOutputGateBias() const |
| { |
| return Deref(m_OutputGateBias); |
| } |
| const TensorInfo& GetProjectionWeights() const |
| { |
| return Deref(m_ProjectionWeights); |
| } |
| const TensorInfo& GetProjectionBias() const |
| { |
| return Deref(m_ProjectionBias); |
| } |
| const TensorInfo& GetInputLayerNormWeights() const |
| { |
| return Deref(m_InputLayerNormWeights); |
| } |
| const TensorInfo& GetForgetLayerNormWeights() const |
| { |
| return Deref(m_ForgetLayerNormWeights); |
| } |
| const TensorInfo& GetCellLayerNormWeights() const |
| { |
| return Deref(m_CellLayerNormWeights); |
| } |
| const TensorInfo& GetOutputLayerNormWeights() const |
| { |
| return Deref(m_OutputLayerNormWeights); |
| } |
| }; |
| |
| } // namespace armnn |
| |