// blob: c1fb2bf4aa50bffe11d20d657080b4a826785e8f [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Activation.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
namespace armnn
{
/// @brief Reference implementation of one LSTM evaluation step over a whole
///        batch. This is a porting of the LSTM::Eval() method in the Android
///        code base (android/frameworks/ml/nn/common/operations/LSTM.cpp).
///
/// Conventions for the tensor arguments:
///  - Decoder<float> parameters are read-only views (inputs, weights, biases).
///  - Encoder<float> parameters are writable views (outputs, scratch buffers).
///  - Each Foo / FooDecoder pair (e.g. cellStateOut / cellStateOutDecoder,
///    outputGateScratch / outputGateScratchDecoder) appears to be a
///    write/read view pair over the same underlying buffer: values written
///    through the encoder are immediately read back through the matching
///    decoder -- TODO(review): confirm against the callers.
///  - Optional tensors (input-gate tensors when CIFG is enabled, peephole
///    weights, projection weights/bias, layer-norm weights) are only
///    dereferenced when the corresponding descriptor flag enables them.
///
/// @param descriptor        LSTM options: CIFG, peephole, projection and
///                          layer-norm flags, activation function id, and the
///                          cell/projection clipping thresholds.
/// @param inputInfo         Shape/type of the input; shape is [nBatch, nInput].
/// @param outputInfo        Used only for its data type when building the
///                          TensorInfo passed to the Activation() calls.
/// @param inputToOutputWeightsShape      [nCell, nInput]  (nCell read from dim 0).
/// @param recurrentToOutputWeightsShape  [nCell, nOutput] (nOutput read from dim 1).
/// @param layerNormEpsilon  Epsilon used by MeanStddevNormalization when layer
///                          normalisation is enabled.
void LstmImpl(const LstmDescriptor& descriptor,
const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
const TensorShape& inputToOutputWeightsShape,
const TensorShape& recurrentToOutputWeightsShape,
std::unique_ptr<Decoder<float>>& inputData,
std::unique_ptr<Decoder<float>>& outputStateIn,
std::unique_ptr<Decoder<float>>& cellStateIn,
std::unique_ptr<Encoder<float>>& outputStateOut,
std::unique_ptr<Encoder<float>>& cellStateOut,
std::unique_ptr<Encoder<float>>& output,
std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
std::unique_ptr<Decoder<float>>& outputDecoder,
std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
std::unique_ptr<Decoder<float>>& cellBiasTensor,
std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
std::unique_ptr<Decoder<float>>& projectionBiasTensor,
std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
std::unique_ptr<Encoder<float>>& inputGateScratch,
std::unique_ptr<Encoder<float>>& cellScratch,
std::unique_ptr<Encoder<float>>& forgetGateScratch,
std::unique_ptr<Encoder<float>>& outputGateScratch,
std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
std::unique_ptr<Decoder<float>>& cellScratchDecoder,
std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
float layerNormEpsilon)
{
// This is a porting of the LSTM::Eval() method in the Android code base
// Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
const TensorShape& inputShape = inputInfo.GetShape();
const DataType& outputType = outputInfo.GetDataType();
// Dimensions: input is [nBatch, nInput], input-to-* weights are
// [nCell, nInput], recurrent-to-* weights are [nCell, nOutput].
const uint32_t nBatch = inputShape[0];
const uint32_t nInput = inputShape[1];
const uint32_t nCell = inputToOutputWeightsShape[0];
const uint32_t nOutput = recurrentToOutputWeightsShape[1];
const bool useCifg = descriptor.m_CifgEnabled;
const bool usePeephole = descriptor.m_PeepholeEnabled;
const bool useLayerNorm = descriptor.m_LayerNormEnabled;
if (!useLayerNorm)
{
// Initialize scratch buffers with bias.
// (With layer norm the biases are instead added AFTER normalisation in
// the per-gate blocks below, so the scratch buffers start at zero.)
if (!useCifg)
{
VectorBatchVectorAssign(*inputGateBiasTensor,
nCell, nBatch, *inputGateScratch);
}
VectorBatchVectorAssign(*forgetGateBiasTensor,
nCell, nBatch, *forgetGateScratch);
VectorBatchVectorAssign(*cellBiasTensor,
nCell, nBatch, *cellScratch);
VectorBatchVectorAssign(*outputGateBiasTensor,
nCell, nBatch, *outputGateScratch);
}
else
{
// Initialize scratch buffers with zeroes.
if (!useCifg)
{
ZeroVector(*inputGateScratch, nCell * nBatch);
}
ZeroVector(*forgetGateScratch, nCell * nBatch);
ZeroVector(*cellScratch , nCell * nBatch);
ZeroVector(*outputGateScratch, nCell * nBatch);
}
// For each batch and cell: compute input_weight * input.
// Each call accumulates [nCell x nInput] * [nInput] into the gate scratch.
if (!useCifg)
{
MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
nCell, nInput, *inputData, nBatch, *inputGateScratch);
}
MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
nCell, nInput, *inputData, nBatch, *forgetGateScratch);
MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
nCell, nInput, *inputData, nBatch, *cellScratch);
MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
nCell, nInput, *inputData, nBatch, *outputGateScratch);
// For each batch and cell: compute recurrent_weight * output_state.
if (!useCifg)
{
MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
}
MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
// For each batch and cell: update input gate.
// Skipped entirely under CIFG; the input gate is then derived from the
// forget gate later (see the Sub1Vector call below).
if (!useCifg)
{
if (usePeephole)
{
// Peephole: add cell_to_input_weights (elementwise) * previous cell state.
VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
nCell, *cellStateIn, nBatch, *inputGateScratch);
}
if (useLayerNorm)
{
// Normalise, scale by the layer-norm weights, then add the gate bias
// (deferred from the initialisation step above).
MeanStddevNormalization(*inputGateScratchDecoder,
*inputGateScratch, nCell, nBatch, layerNormEpsilon);
VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
VectorBatchVectorAdd(*inputGateBiasTensor,
nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
}
// Gate non-linearity is always a sigmoid (in-place via the decoder/encoder pair).
Activation(*inputGateScratchDecoder, *inputGateScratch,
TensorInfo({nCell, nBatch}, outputType),
ActivationFunction::Sigmoid, 0, 0);
}
// For each batch and cell: update forget gate.
if (usePeephole)
{
VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
*cellStateIn, nBatch, *forgetGateScratch);
}
if (useLayerNorm)
{
MeanStddevNormalization(*forgetGateScratchDecoder,
*forgetGateScratch, nCell, nBatch, layerNormEpsilon);
VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
VectorBatchVectorAdd(*forgetGateBiasTensor,
nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
}
Activation(*forgetGateScratchDecoder, *forgetGateScratch,
TensorInfo({nCell, nBatch}, outputType),
ActivationFunction::Sigmoid, 0, 0);
// For each batch and cell: update the cell.
if (useLayerNorm)
{
MeanStddevNormalization(*cellScratchDecoder,
*cellScratch, nCell, nBatch, layerNormEpsilon);
VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
nCell, *cellScratchDecoder, nBatch, *cellScratch);
VectorBatchVectorAdd(*cellBiasTensor,
nCell, *cellScratchDecoder, nBatch, *cellScratch);
}
// cellStateOut = forgetGate (elementwise) * previous cell state.
VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
// Map the descriptor's activation-function id onto an ArmNN activation
// (plus its a/b parameters). An id of 0 means "no activation" and the
// candidate cell input is left linear.
ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
float a = 0;
float b = 0;
SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b);
if (descriptor.m_ActivationFunc > 0)
{
Activation(*cellScratchDecoder, *cellScratch,
TensorInfo({nCell, nBatch}, outputType),
armnnActivationFunc, a, b);
}
if (useCifg)
{
// CIFG: the input gate is coupled to the forget gate as (1 - forgetGate);
// forgetGateScratch is overwritten in place with that value.
Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
VectorVectorCwiseProductAccumulate(
*cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
}
else
{
// cellStateOut += inputGate (elementwise) * candidate cell input.
VectorVectorCwiseProductAccumulate(
*cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
}
// Optional cell-state clipping to [-thres, thres]; 0 disables it.
if (descriptor.m_ClippingThresCell > 0.0)
{
ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut);
}
// For each batch and cell: update the output gate.
if (usePeephole)
{
// Output-gate peephole uses the NEW cell state (cellStateOutDecoder),
// unlike the input/forget peepholes which use the previous one.
VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
}
if (useLayerNorm)
{
MeanStddevNormalization(*outputGateScratchDecoder,
*outputGateScratch, nCell, nBatch, layerNormEpsilon);
VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
VectorBatchVectorAdd(*outputGateBiasTensor,
nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
}
Activation(*outputGateScratchDecoder, *outputGateScratch,
TensorInfo({nCell, nBatch}, outputType),
ActivationFunction::Sigmoid, 0, 0);
// Apply the same activation g() to the new cell state, writing the result
// into cellScratch, so the step below computes outputGate * g(cellState).
if (descriptor.m_ActivationFunc > 0)
{
Activation(*cellStateOutDecoder, *cellScratch,
TensorInfo({nCell, nBatch}, outputType),
armnnActivationFunc, a, b);
}
// outputGateScratch = outputGate (elementwise) * g(cellState) -- the
// pre-projection LSTM output.
VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
// For each batch: update the projection and output_state.
if (descriptor.m_ProjectionEnabled)
{
// output = projectionBias (if present) + projectionWeights * gateOutput.
if (projectionBiasTensor)
{
VectorBatchVectorAssign(*projectionBiasTensor,
nOutput, nBatch, *output);
}
MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
// Optional projection clipping to [-thres, thres]; 0 disables it.
if (descriptor.m_ClippingThresProj > 0.0)
{
ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output);
}
}
else
{
// No projection: copy the gate output straight to the output tensor.
// NOTE(review): copies nBatch * nOutput elements from an nBatch * nCell
// scratch -- presumably nCell == nOutput is guaranteed when projection
// is disabled; confirm against the layer's validation.
CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
}
// Mirror the final output into outputStateOut (fed back as outputStateIn
// on the next step).
CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
}
} //namespace armnn