//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefQLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))

    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))

    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))

    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))

    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))

    , m_InputLayerNormWeightsTensor   (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefQLstmWorkload::Execute() const
{
    // This is a port of the QLSTM::Execute() method in the Android code base.
    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
    // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
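    //
    // For reference, a rough sketch of the cell in conventional LSTM notation (terms in [] are optional features):
    //     i = sigmoid(Wi*x + Ri*hPrev [+ pi .* cPrev])   // input gate, skipped when CIFG is enabled
    //     f = sigmoid(Wf*x + Rf*hPrev [+ pf .* cPrev])   // forget gate
    //     g = tanh   (Wc*x + Rc*hPrev)                   // cell (modulation) gate
    //     o = sigmoid(Wo*x + Ro*hPrev [+ po .* c])       // output gate, peephole uses the updated cell state
    //     c = f .* cPrev + i .* g                        // i is replaced by (1 - f) when CIFG is enabled
    //     h = o .* tanh(c)                               // optionally followed by the projection handled below
    // When layer normalization is enabled, each gate pre-activation is normalized and its gate bias is added
    // before the activation function is applied.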
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo         = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(m_Data.m_Inputs[1]);
    const TensorInfo& cellStateInInfo   = GetTensorInfo(m_Data.m_Inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(m_Data.m_Outputs[0]);
    const TensorInfo& cellStateOutInfo   = GetTensorInfo(m_Data.m_Outputs[1]);
    const TensorInfo& outputInfo         = GetTensorInfo(m_Data.m_Outputs[2]);

    const TensorShape& inputShape         = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape   = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize  = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits   = cellStateInShape[1];
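
    // The state and input tensors are assumed to be 2D: input [numBatches, inputSize],
    // outputStateIn [numBatches, outputSize], cellStateIn [numBatches, numUnits].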

    // Optional param settings
    const bool cifgEnabled       = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled   = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled  = m_Data.m_Parameters.m_LayerNormEnabled;
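    // CIFG couples the input and forget gates (the input gate becomes 1 - forget gate), peephole adds
    // cell-state terms to the gate pre-activations, projection applies an extra matmul (and optional clip)
    // to the hidden state, and layer norm normalizes each gate pre-activation before its activation.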

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, m_Data.m_Inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, m_Data.m_Inputs[2]->Map());

    // Output decoders
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());

    // Output encoders
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());

    // Weights decoders
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());

    // Optional CIFG params
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    std::vector<int32_t> hiddenStateData(stateTensorSize);
    std::vector<int16_t> outputInt16Data(numBatches * outputSize);

    armnn::TensorInfo inputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
    armnn::TensorInfo outputInt16Info({numBatches , outputSize},
                                      armnn::DataType::QSymmS16,
                                      outputInfo.GetQuantizationScale(),
                                      outputInfo.GetQuantizationOffset());
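
    // The gate TensorInfos above exist only to carry quantization parameters for the scratch buffers below;
    // whenever a scale is changed, the corresponding Encoder/Decoder is re-created so that the raw int16
    // data is read or written at the new scale.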

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
    std::unique_ptr<Decoder<float>> outputInt16Decoder =
            MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
    std::unique_ptr<Encoder<float>> outputInt16Encoder =
            MakeEncoder<float>(outputInt16Info, outputInt16Data.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(
                m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(
                    m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<void>());

            // Bias only used if layer norm enabled
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);
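
    // Zero-initialisation is required because the gate buffers are filled below with
    // MatrixBatchVectorMultiplyAccumulate, which accumulates into (rather than overwrites) its output.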

    // Input weights * Input
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
            numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
            numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
            numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
            numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
            numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
            numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
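
    // At this point each gate scratch buffer holds its pre-activation value:
    // (input-to-gate weights * input) + (recurrent-to-gate weights * outputStateIn).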

    // Input gate.
    if (!cifgEnabled)
    {
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

        if (layerNormEnabled)
        {
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }
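
    // When CIFG is enabled no explicit input gate is computed; (1 - forgetGate) is used in its place
    // during the cell state update further down.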

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Dequantize layer norm output to (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
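
    // forgetGate now holds sigmoid(pre-activation), re-quantized to the cell state output scale.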

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
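
    // Cell state update: cellStateOut = forgetGate .* cellStateIn + inputGate .* cellGate,
    // where inputGate is replaced by (1 - forgetGate) when CIFG is enabled.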

    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

    // Hidden state tanH
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);
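
    // Note: the tanh of the new cell state is written into the cellGate scratch buffer, which is reused
    // here as temporary storage before the final element-wise product with the output gate.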

    // Final hidden state output
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
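        // output = projectionWeights * hiddenState (+ projectionBias), clipped to +/-m_ProjectionClip when set.
        // The matmul accumulates into the int16 scratch buffer to avoid overflowing the narrower output type.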
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
                                            numBatches, *outputInt16Encoder);

        CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
        }
    }
    else
    {
        // Output has same quantization scale as hidden state if projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // output == outputStateOut
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}

} //namespace armnn