James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2020 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefQLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
| 10 | #include "LstmUtils.hpp" |
| 11 | #include "RefWorkloadUtils.hpp" |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
| 17 | : BaseWorkload<QLstmQueueDescriptor>(descriptor, info) |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 18 | , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) |
| 19 | , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) |
| 20 | , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) |
| 21 | , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 22 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 23 | , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 24 | , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 25 | , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 26 | , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 27 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 28 | , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) |
| 29 | , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) |
| 30 | , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 31 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 32 | , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) |
| 33 | , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) |
| 34 | , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) |
| 35 | , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 36 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 37 | , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) |
| 38 | , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 39 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 40 | , m_InputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) |
| 41 | , m_ForgetLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) |
| 42 | , m_CellLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) |
| 43 | , m_OutputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 44 | {} |
| 45 | |
| 46 | void RefQLstmWorkload::Execute() const |
| 47 | { |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 48 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 49 | } |
| 50 | |
| 51 | void RefQLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) |
| 52 | { |
| 53 | Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); |
| 54 | } |
| 55 | |
| 56 | void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const |
| 57 | { |
| 58 | // This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) |
| 59 | // method in the Android code base |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 60 | // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all |
| 61 | // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp. |
| 62 | // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp |
| 63 | const DataType& internalType = armnn::DataType::QSymmS16; |
| 64 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 65 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 66 | const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]); |
| 67 | const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 68 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 69 | const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]); |
| 70 | const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]); |
| 71 | const TensorInfo& outputInfo = GetTensorInfo(outputs[2]); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 72 | |
| 73 | const TensorShape& inputShape = inputInfo.GetShape(); |
| 74 | const TensorShape& outputStateInShape = outputStateInInfo.GetShape(); |
| 75 | const TensorShape& cellStateInShape = cellStateInInfo.GetShape(); |
| 76 | |
| 77 | // Infer numBatches, inputSize, outputSize and numUnits |
| 78 | const uint32_t numBatches = inputShape[0]; |
| 79 | const uint32_t inputSize = inputShape[1]; |
| 80 | const uint32_t outputSize = outputStateInShape[1]; |
| 81 | const uint32_t numUnits = cellStateInShape[1]; |
| 82 | |
| 83 | // Optional param settings |
| 84 | const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled; |
| 85 | const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled; |
| 86 | const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled; |
| 87 | const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled; |
| 88 | |
| 89 | // Input decoders |
| 90 | std::unique_ptr<Decoder<float>> inputDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 91 | MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 92 | std::unique_ptr<Decoder<float>> outputStateInDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 93 | MakeDecoder<float>(outputStateInInfo, inputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 94 | std::unique_ptr<Decoder<float>> cellStateInDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 95 | MakeDecoder<float>(cellStateInInfo, inputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 96 | |
| 97 | // Output decoders |
| 98 | std::unique_ptr<Decoder<float>> outputStateOutDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 99 | MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 100 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 101 | MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 102 | std::unique_ptr<Decoder<float>> outputDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 103 | MakeDecoder<float>(outputInfo, outputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 104 | |
| 105 | // Output encoders |
| 106 | std::unique_ptr<Encoder<float>> outputStateOutEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 107 | MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 108 | std::unique_ptr<Encoder<float>> cellStateOutEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 109 | MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 110 | std::unique_ptr<Encoder<float>> outputEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 111 | MakeEncoder<float>(outputInfo, outputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 112 | |
| 113 | // Weights decoders |
| 114 | std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 115 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 116 | std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 117 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 118 | std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 119 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 120 | |
| 121 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 122 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), |
| 123 | m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 124 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 125 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 126 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 127 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), |
| 128 | m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 129 | |
| 130 | // Optional CIFG params |
| 131 | std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder; |
| 132 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder; |
| 133 | std::unique_ptr<Decoder<float>> inputGateBiasDecoder; |
| 134 | |
| 135 | // Optional Peephole params |
| 136 | std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder; |
| 137 | std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder; |
| 138 | std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder; |
| 139 | |
| 140 | // Optional Projection params |
| 141 | std::unique_ptr<Decoder<float>> projectionWeightsDecoder; |
| 142 | std::unique_ptr<Decoder<float>> projectionBiasDecoder; |
| 143 | |
| 144 | // Optional Layer Norm params |
| 145 | std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder; |
| 146 | std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder; |
| 147 | std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder; |
| 148 | std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder; |
| 149 | |
| 150 | // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024) |
| 151 | std::unique_ptr<Decoder<float>> forgetGateBiasDecoder; |
| 152 | std::unique_ptr<Decoder<float>> cellGateBiasDecoder; |
| 153 | std::unique_ptr<Decoder<float>> outputGateBiasDecoder; |
| 154 | |
| 155 | // Int16 vectors for internal state data (to be decoded/encoded) |
| 156 | const uint32_t stateTensorSize = numBatches * numUnits; |
| 157 | std::vector<int16_t> inputGateData(stateTensorSize); |
| 158 | std::vector<int16_t> cellGateData(stateTensorSize); |
| 159 | std::vector<int16_t> forgetGateData(stateTensorSize); |
| 160 | std::vector<int16_t> outputGateData(stateTensorSize); |
| 161 | std::vector<int32_t> hiddenStateData(stateTensorSize); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 162 | std::vector<int16_t> outputInt16Data(numBatches * outputSize); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 163 | |
| 164 | armnn::TensorInfo inputGateInfo( |
| 165 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0); |
| 166 | armnn::TensorInfo cellGateInfo( |
| 167 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0); |
| 168 | armnn::TensorInfo forgetGateInfo( |
| 169 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0); |
| 170 | armnn::TensorInfo outputGateInfo( |
| 171 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0); |
| 172 | armnn::TensorInfo hiddenStateInfo({numBatches, numUnits}, |
| 173 | armnn::DataType::QAsymmS8, |
| 174 | m_Data.m_Parameters.m_HiddenStateScale, |
| 175 | m_Data.m_Parameters.m_HiddenStateZeroPoint); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 176 | armnn::TensorInfo outputInt16Info({numBatches , outputSize}, |
| 177 | armnn::DataType::QSymmS16, |
| 178 | outputInfo.GetQuantizationScale(), |
| 179 | outputInfo.GetQuantizationOffset()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 180 | |
| 181 | // Decoders/Encoders for internal states |
| 182 | std::unique_ptr<Decoder<float>> inputGateDecoder = |
| 183 | MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 184 | std::unique_ptr<Decoder<float>> cellGateDecoder = |
| 185 | MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 186 | std::unique_ptr<Decoder<float>> forgetGateDecoder = |
| 187 | MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 188 | std::unique_ptr<Decoder<float>> outputGateDecoder = |
| 189 | MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 190 | std::unique_ptr<Decoder<float>> hiddenStateDecoder = |
| 191 | MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data()); |
| 192 | |
| 193 | std::unique_ptr<Encoder<float>> inputGateEncoder = |
| 194 | MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 195 | std::unique_ptr<Encoder<float>> cellGateEncoder = |
| 196 | MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 197 | std::unique_ptr<Encoder<float>> forgetGateEncoder = |
| 198 | MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 199 | std::unique_ptr<Encoder<float>> outputGateEncoder = |
| 200 | MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 201 | std::unique_ptr<Encoder<float>> hiddenStateEncoder = |
| 202 | MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data()); |
| 203 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 204 | // Int16 used to accumulate output to prevent overflowing (after Projection MatMul) |
| 205 | std::unique_ptr<Decoder<float>> outputInt16Decoder = |
| 206 | MakeDecoder<float>(outputInt16Info, outputInt16Data.data()); |
| 207 | std::unique_ptr<Encoder<float>> outputInt16Encoder = |
| 208 | MakeEncoder<float>(outputInt16Info, outputInt16Data.data()); |
| 209 | |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 210 | // Create decoders for optional params if they are enabled |
| 211 | if (!cifgEnabled) |
| 212 | { |
| 213 | inputToInputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 214 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); |
| 215 | recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(), |
| 216 | m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 217 | } |
| 218 | |
| 219 | if (peepholeEnabled) |
| 220 | { |
| 221 | if (!cifgEnabled) |
| 222 | { |
| 223 | cellToInputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 224 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 225 | } |
| 226 | cellToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 227 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 228 | cellToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 229 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 230 | } |
| 231 | |
| 232 | if (projectionEnabled) |
| 233 | { |
| 234 | projectionWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 235 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 236 | if (m_ProjectionBiasTensor) |
| 237 | { |
| 238 | projectionBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 239 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 240 | } |
| 241 | } |
| 242 | |
| 243 | if (layerNormEnabled) |
| 244 | { |
| 245 | if (!cifgEnabled) |
| 246 | { |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 247 | inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(), |
| 248 | m_InputLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 249 | |
| 250 | // Bias only used if layer norm enabled |
| 251 | armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 252 | m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 253 | inputGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 254 | inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 255 | } |
| 256 | |
| 257 | forgetLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 258 | m_ForgetLayerNormWeightsTensor->GetTensorInfo(), |
| 259 | m_ForgetLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 260 | cellLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 261 | m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 262 | outputLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 263 | m_OutputLayerNormWeightsTensor->GetTensorInfo(), |
| 264 | m_OutputLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 265 | |
| 266 | // Bias only used if layer norm enabled |
| 267 | armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 268 | m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 269 | forgetGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 270 | forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 271 | |
| 272 | armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 273 | m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 274 | cellGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 275 | cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 276 | |
| 277 | armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 278 | m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 279 | outputGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 280 | outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 281 | } |
| 282 | |
| 283 | // Initialize internal state tensors with zeroes. |
| 284 | if (!cifgEnabled) |
| 285 | { |
| 286 | ZeroVector(*inputGateEncoder, stateTensorSize); |
| 287 | } |
| 288 | ZeroVector(*forgetGateEncoder, stateTensorSize); |
| 289 | ZeroVector(*cellGateEncoder, stateTensorSize); |
| 290 | ZeroVector(*outputGateEncoder, stateTensorSize); |
| 291 | ZeroVector(*hiddenStateEncoder, stateTensorSize); |
| 292 | |
| 293 | // Input weights * Input |
| 294 | if (!cifgEnabled) |
| 295 | { |
| 296 | MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder, |
| 297 | numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder); |
| 298 | } |
| 299 | |
| 300 | MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder, |
| 301 | numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder); |
| 302 | |
| 303 | MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder, |
| 304 | numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder); |
| 305 | |
| 306 | MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder, |
| 307 | numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder); |
| 308 | |
| 309 | // Recurrent weights * OutputStateIn |
| 310 | if (!cifgEnabled) |
| 311 | { |
| 312 | MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder, |
| 313 | numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder); |
| 314 | } |
| 315 | |
| 316 | MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder, |
| 317 | numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder); |
| 318 | |
| 319 | MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder, |
| 320 | numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder); |
| 321 | |
| 322 | MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder, |
| 323 | numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder); |
| 324 | |
| 325 | // Input gate. |
| 326 | if (!cifgEnabled) |
| 327 | { |
| 328 | if (peepholeEnabled) |
| 329 | { |
| 330 | VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder, |
| 331 | numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder); |
| 332 | } |
| 333 | |
| 334 | if (layerNormEnabled) |
| 335 | { |
| 336 | inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 337 | m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 338 | 1024); |
| 339 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 340 | |
| 341 | MeanStddevNormalization(*inputGateDecoder, |
| 342 | *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 343 | |
| 344 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 345 | |
| 346 | VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder, |
| 347 | numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); |
| 348 | |
| 349 | inputGateInfo.SetQuantizationScale(1.f / 4096); |
| 350 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 351 | |
| 352 | VectorBatchVectorAdd(*inputGateBiasDecoder, |
| 353 | numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); |
| 354 | |
| 355 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 356 | } |
| 357 | |
| 358 | inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 359 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 360 | |
| 361 | // Input gate sigmoid |
| 362 | Activation(*inputGateDecoder, *inputGateEncoder, |
| 363 | TensorInfo({numUnits, numBatches}, internalType), |
| 364 | ActivationFunction::Sigmoid, 0, 0); |
| 365 | |
| 366 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 367 | } |
| 368 | |
| 369 | // Forget gate |
| 370 | if (peepholeEnabled) |
| 371 | { |
| 372 | VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits, |
| 373 | *cellStateInDecoder, numBatches, *forgetGateEncoder); |
| 374 | } |
| 375 | |
| 376 | if (layerNormEnabled) |
| 377 | { |
| 378 | // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024 |
| 379 | forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 380 | m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 381 | 1024); |
| 382 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 383 | |
| 384 | |
| 385 | |
| 386 | MeanStddevNormalization(*forgetGateDecoder, |
| 387 | *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 388 | |
| 389 | |
| 390 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 391 | |
| 392 | VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder, |
| 393 | numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); |
| 394 | |
| 395 | |
| 396 | // Dequantize layer norm output to (1 / 4096) |
| 397 | forgetGateInfo.SetQuantizationScale(1.f / 4096); |
| 398 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 399 | |
| 400 | VectorBatchVectorAdd(*forgetGateBiasDecoder, |
| 401 | numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); |
| 402 | |
| 403 | |
| 404 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 405 | } |
| 406 | |
| 407 | forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 408 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 409 | |
| 410 | // Forget gate sigmoid |
| 411 | Activation(*forgetGateDecoder, *forgetGateEncoder, |
| 412 | TensorInfo({numUnits, numBatches}, internalType), |
| 413 | ActivationFunction::Sigmoid, 0, 0); |
| 414 | |
| 415 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 416 | |
| 417 | // Cell (Modulation) gate |
| 418 | if (layerNormEnabled) |
| 419 | { |
| 420 | cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 421 | m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 422 | 1024); |
| 423 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 424 | |
| 425 | MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 426 | |
| 427 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 428 | |
| 429 | VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder, |
| 430 | numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); |
| 431 | |
| 432 | cellGateInfo.SetQuantizationScale(1.f / 4096); |
| 433 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 434 | |
| 435 | VectorBatchVectorAdd(*cellGateBiasDecoder, |
| 436 | numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); |
| 437 | |
| 438 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 439 | } |
| 440 | |
| 441 | cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 442 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 443 | |
| 444 | // Cell (Modulation) gate tanH |
| 445 | Activation(*cellGateDecoder, *cellGateEncoder, |
| 446 | TensorInfo({numUnits, numBatches}, internalType), |
| 447 | ActivationFunction::TanH, 1.0f, 1.0f); |
| 448 | |
| 449 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 450 | |
| 451 | VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder); |
| 452 | |
| 453 | if (cifgEnabled) |
| 454 | { |
| 455 | Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder); |
| 456 | VectorVectorCwiseProductAccumulate( |
| 457 | *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder); |
| 458 | } |
| 459 | else |
| 460 | { |
| 461 | VectorVectorCwiseProductAccumulate( |
| 462 | *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder); |
| 463 | } |
| 464 | |
| 465 | // Final cell state out calculated here |
| 466 | if (m_Data.m_Parameters.m_CellClip > 0.0) |
| 467 | { |
| 468 | ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder); |
| 469 | } |
| 470 | |
| 471 | // Output gate. |
| 472 | if (peepholeEnabled) |
| 473 | { |
| 474 | VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder, |
| 475 | numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder); |
| 476 | } |
| 477 | |
| 478 | if (layerNormEnabled) |
| 479 | { |
| 480 | outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 481 | m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 482 | 1024); |
| 483 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 484 | |
| 485 | MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 486 | |
| 487 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 488 | |
| 489 | VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder, |
| 490 | numBatches, *outputGateEncoder); |
| 491 | |
| 492 | outputGateInfo.SetQuantizationScale(1.f / 4096); |
| 493 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 494 | |
| 495 | VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder); |
| 496 | |
| 497 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 498 | } |
| 499 | |
| 500 | outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 501 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 502 | |
| 503 | // Output gate sigmoid |
| 504 | Activation(*outputGateDecoder, *outputGateEncoder, |
| 505 | TensorInfo({numUnits, numBatches}, internalType), |
| 506 | ActivationFunction::Sigmoid, 0, 0); |
| 507 | |
| 508 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 509 | |
| 510 | // Hidden state tanH |
| 511 | Activation(*cellStateOutDecoder, *cellGateEncoder, |
| 512 | TensorInfo({numUnits, numBatches}, internalType), |
| 513 | ActivationFunction::TanH, 1.0f, 1.0f); |
| 514 | |
| 515 | // Final hidden state output |
| 516 | VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder); |
| 517 | |
| 518 | // Projection |
| 519 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 520 | { |
| 521 | if (m_ProjectionBiasTensor) |
| 522 | { |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 523 | VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 524 | } |
| 525 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 526 | MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder, |
| 527 | numBatches, *outputInt16Encoder); |
| 528 | |
| 529 | CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 530 | |
| 531 | if (m_Data.m_Parameters.m_ProjectionClip > 0.0) |
| 532 | { |
| 533 | ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder); |
| 534 | } |
| 535 | } |
| 536 | else |
| 537 | { |
| 538 | // Output has same quantization scale as hidden state if projection is disabled |
| 539 | CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder); |
| 540 | } |
| 541 | |
| 542 | // output == outputStateOut |
| 543 | CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder); |
| 544 | } |
| 545 | |
| 546 | } //namespace armnn |