James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1 | // |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 2 | // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefQLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
| 10 | #include "LstmUtils.hpp" |
| 11 | #include "RefWorkloadUtils.hpp" |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
Finn Williams | 73c547d | 2022-02-15 20:47:34 +0000 | [diff] [blame] | 17 | : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info) |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 18 | , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) |
| 19 | , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) |
| 20 | , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) |
| 21 | , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 22 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 23 | , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 24 | , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 25 | , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 26 | , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 27 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 28 | , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) |
| 29 | , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) |
| 30 | , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 31 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 32 | , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) |
| 33 | , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) |
| 34 | , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) |
| 35 | , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 36 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 37 | , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) |
| 38 | , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 39 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 40 | , m_InputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) |
| 41 | , m_ForgetLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) |
| 42 | , m_CellLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) |
| 43 | , m_OutputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 44 | {} |
| 45 | |
| 46 | void RefQLstmWorkload::Execute() const |
| 47 | { |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 48 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 49 | } |
| 50 | |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 51 | void RefQLstmWorkload::ExecuteAsync(ExecutionData& executionData) |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 52 | { |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 53 | WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); |
| 54 | Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 55 | } |
| 56 | |
| 57 | void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const |
| 58 | { |
| 59 | // This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) |
| 60 | // method in the Android code base |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 61 | // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all |
| 62 | // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp. |
| 63 | // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp |
| 64 | const DataType& internalType = armnn::DataType::QSymmS16; |
| 65 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 66 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 67 | const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]); |
| 68 | const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 69 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 70 | const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]); |
| 71 | const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]); |
| 72 | const TensorInfo& outputInfo = GetTensorInfo(outputs[2]); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 73 | |
| 74 | const TensorShape& inputShape = inputInfo.GetShape(); |
| 75 | const TensorShape& outputStateInShape = outputStateInInfo.GetShape(); |
| 76 | const TensorShape& cellStateInShape = cellStateInInfo.GetShape(); |
| 77 | |
| 78 | // Infer numBatches, inputSize, outputSize and numUnits |
| 79 | const uint32_t numBatches = inputShape[0]; |
| 80 | const uint32_t inputSize = inputShape[1]; |
| 81 | const uint32_t outputSize = outputStateInShape[1]; |
| 82 | const uint32_t numUnits = cellStateInShape[1]; |
| 83 | |
| 84 | // Optional param settings |
| 85 | const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled; |
| 86 | const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled; |
| 87 | const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled; |
| 88 | const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled; |
| 89 | |
| 90 | // Input decoders |
| 91 | std::unique_ptr<Decoder<float>> inputDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 92 | MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 93 | std::unique_ptr<Decoder<float>> outputStateInDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 94 | MakeDecoder<float>(outputStateInInfo, inputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 95 | std::unique_ptr<Decoder<float>> cellStateInDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 96 | MakeDecoder<float>(cellStateInInfo, inputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 97 | |
| 98 | // Output decoders |
| 99 | std::unique_ptr<Decoder<float>> outputStateOutDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 100 | MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 101 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 102 | MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 103 | std::unique_ptr<Decoder<float>> outputDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 104 | MakeDecoder<float>(outputInfo, outputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 105 | |
| 106 | // Output encoders |
| 107 | std::unique_ptr<Encoder<float>> outputStateOutEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 108 | MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 109 | std::unique_ptr<Encoder<float>> cellStateOutEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 110 | MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 111 | std::unique_ptr<Encoder<float>> outputEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 112 | MakeEncoder<float>(outputInfo, outputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 113 | |
| 114 | // Weights decoders |
| 115 | std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 116 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 117 | std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 118 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 119 | std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 120 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 121 | |
| 122 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 123 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), |
| 124 | m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 125 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 126 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 127 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 128 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), |
| 129 | m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 130 | |
| 131 | // Optional CIFG params |
| 132 | std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder; |
| 133 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder; |
| 134 | std::unique_ptr<Decoder<float>> inputGateBiasDecoder; |
| 135 | |
| 136 | // Optional Peephole params |
| 137 | std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder; |
| 138 | std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder; |
| 139 | std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder; |
| 140 | |
| 141 | // Optional Projection params |
| 142 | std::unique_ptr<Decoder<float>> projectionWeightsDecoder; |
| 143 | std::unique_ptr<Decoder<float>> projectionBiasDecoder; |
| 144 | |
| 145 | // Optional Layer Norm params |
| 146 | std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder; |
| 147 | std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder; |
| 148 | std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder; |
| 149 | std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder; |
| 150 | |
| 151 | // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024) |
| 152 | std::unique_ptr<Decoder<float>> forgetGateBiasDecoder; |
| 153 | std::unique_ptr<Decoder<float>> cellGateBiasDecoder; |
| 154 | std::unique_ptr<Decoder<float>> outputGateBiasDecoder; |
| 155 | |
| 156 | // Int16 vectors for internal state data (to be decoded/encoded) |
| 157 | const uint32_t stateTensorSize = numBatches * numUnits; |
| 158 | std::vector<int16_t> inputGateData(stateTensorSize); |
| 159 | std::vector<int16_t> cellGateData(stateTensorSize); |
| 160 | std::vector<int16_t> forgetGateData(stateTensorSize); |
| 161 | std::vector<int16_t> outputGateData(stateTensorSize); |
| 162 | std::vector<int32_t> hiddenStateData(stateTensorSize); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 163 | std::vector<int16_t> outputInt16Data(numBatches * outputSize); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 164 | |
| 165 | armnn::TensorInfo inputGateInfo( |
| 166 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0); |
| 167 | armnn::TensorInfo cellGateInfo( |
| 168 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0); |
| 169 | armnn::TensorInfo forgetGateInfo( |
| 170 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0); |
| 171 | armnn::TensorInfo outputGateInfo( |
| 172 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0); |
| 173 | armnn::TensorInfo hiddenStateInfo({numBatches, numUnits}, |
| 174 | armnn::DataType::QAsymmS8, |
| 175 | m_Data.m_Parameters.m_HiddenStateScale, |
| 176 | m_Data.m_Parameters.m_HiddenStateZeroPoint); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 177 | armnn::TensorInfo outputInt16Info({numBatches , outputSize}, |
| 178 | armnn::DataType::QSymmS16, |
| 179 | outputInfo.GetQuantizationScale(), |
| 180 | outputInfo.GetQuantizationOffset()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 181 | |
| 182 | // Decoders/Encoders for internal states |
| 183 | std::unique_ptr<Decoder<float>> inputGateDecoder = |
| 184 | MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 185 | std::unique_ptr<Decoder<float>> cellGateDecoder = |
| 186 | MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 187 | std::unique_ptr<Decoder<float>> forgetGateDecoder = |
| 188 | MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 189 | std::unique_ptr<Decoder<float>> outputGateDecoder = |
| 190 | MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 191 | std::unique_ptr<Decoder<float>> hiddenStateDecoder = |
| 192 | MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data()); |
| 193 | |
| 194 | std::unique_ptr<Encoder<float>> inputGateEncoder = |
| 195 | MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 196 | std::unique_ptr<Encoder<float>> cellGateEncoder = |
| 197 | MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 198 | std::unique_ptr<Encoder<float>> forgetGateEncoder = |
| 199 | MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 200 | std::unique_ptr<Encoder<float>> outputGateEncoder = |
| 201 | MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 202 | std::unique_ptr<Encoder<float>> hiddenStateEncoder = |
| 203 | MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data()); |
| 204 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 205 | // Int16 used to accumulate output to prevent overflowing (after Projection MatMul) |
| 206 | std::unique_ptr<Decoder<float>> outputInt16Decoder = |
| 207 | MakeDecoder<float>(outputInt16Info, outputInt16Data.data()); |
| 208 | std::unique_ptr<Encoder<float>> outputInt16Encoder = |
| 209 | MakeEncoder<float>(outputInt16Info, outputInt16Data.data()); |
| 210 | |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 211 | // Create decoders for optional params if they are enabled |
| 212 | if (!cifgEnabled) |
| 213 | { |
| 214 | inputToInputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 215 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); |
| 216 | recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(), |
| 217 | m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 218 | } |
| 219 | |
| 220 | if (peepholeEnabled) |
| 221 | { |
| 222 | if (!cifgEnabled) |
| 223 | { |
| 224 | cellToInputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 225 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 226 | } |
| 227 | cellToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 228 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 229 | cellToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 230 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 231 | } |
| 232 | |
| 233 | if (projectionEnabled) |
| 234 | { |
| 235 | projectionWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 236 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 237 | if (m_ProjectionBiasTensor) |
| 238 | { |
| 239 | projectionBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 240 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 241 | } |
| 242 | } |
| 243 | |
| 244 | if (layerNormEnabled) |
| 245 | { |
| 246 | if (!cifgEnabled) |
| 247 | { |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 248 | inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(), |
| 249 | m_InputLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 250 | |
| 251 | // Bias only used if layer norm enabled |
| 252 | armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 253 | m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 254 | inputGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 255 | inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 256 | } |
| 257 | |
| 258 | forgetLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 259 | m_ForgetLayerNormWeightsTensor->GetTensorInfo(), |
| 260 | m_ForgetLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 261 | cellLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 262 | m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 263 | outputLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 264 | m_OutputLayerNormWeightsTensor->GetTensorInfo(), |
| 265 | m_OutputLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 266 | |
| 267 | // Bias only used if layer norm enabled |
| 268 | armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 269 | m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 270 | forgetGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 271 | forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 272 | |
| 273 | armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 274 | m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 275 | cellGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 276 | cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 277 | |
| 278 | armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 279 | m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 280 | outputGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 281 | outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 282 | } |
| 283 | |
| 284 | // Initialize internal state tensors with zeroes. |
| 285 | if (!cifgEnabled) |
| 286 | { |
| 287 | ZeroVector(*inputGateEncoder, stateTensorSize); |
| 288 | } |
| 289 | ZeroVector(*forgetGateEncoder, stateTensorSize); |
| 290 | ZeroVector(*cellGateEncoder, stateTensorSize); |
| 291 | ZeroVector(*outputGateEncoder, stateTensorSize); |
| 292 | ZeroVector(*hiddenStateEncoder, stateTensorSize); |
| 293 | |
| 294 | // Input weights * Input |
| 295 | if (!cifgEnabled) |
| 296 | { |
| 297 | MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder, |
| 298 | numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder); |
| 299 | } |
| 300 | |
| 301 | MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder, |
| 302 | numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder); |
| 303 | |
| 304 | MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder, |
| 305 | numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder); |
| 306 | |
| 307 | MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder, |
| 308 | numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder); |
| 309 | |
| 310 | // Recurrent weights * OutputStateIn |
| 311 | if (!cifgEnabled) |
| 312 | { |
| 313 | MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder, |
| 314 | numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder); |
| 315 | } |
| 316 | |
| 317 | MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder, |
| 318 | numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder); |
| 319 | |
| 320 | MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder, |
| 321 | numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder); |
| 322 | |
| 323 | MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder, |
| 324 | numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder); |
| 325 | |
| 326 | // Input gate. |
| 327 | if (!cifgEnabled) |
| 328 | { |
| 329 | if (peepholeEnabled) |
| 330 | { |
| 331 | VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder, |
| 332 | numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder); |
| 333 | } |
| 334 | |
| 335 | if (layerNormEnabled) |
| 336 | { |
| 337 | inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 338 | m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 339 | 1024); |
| 340 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 341 | |
| 342 | MeanStddevNormalization(*inputGateDecoder, |
| 343 | *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 344 | |
| 345 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 346 | |
| 347 | VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder, |
| 348 | numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); |
| 349 | |
| 350 | inputGateInfo.SetQuantizationScale(1.f / 4096); |
| 351 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 352 | |
| 353 | VectorBatchVectorAdd(*inputGateBiasDecoder, |
| 354 | numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); |
| 355 | |
| 356 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 357 | } |
| 358 | |
| 359 | inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 360 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 361 | |
| 362 | // Input gate sigmoid |
| 363 | Activation(*inputGateDecoder, *inputGateEncoder, |
| 364 | TensorInfo({numUnits, numBatches}, internalType), |
| 365 | ActivationFunction::Sigmoid, 0, 0); |
| 366 | |
| 367 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 368 | } |
| 369 | |
| 370 | // Forget gate |
| 371 | if (peepholeEnabled) |
| 372 | { |
| 373 | VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits, |
| 374 | *cellStateInDecoder, numBatches, *forgetGateEncoder); |
| 375 | } |
| 376 | |
| 377 | if (layerNormEnabled) |
| 378 | { |
| 379 | // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024 |
| 380 | forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 381 | m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 382 | 1024); |
| 383 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 384 | |
| 385 | |
| 386 | |
| 387 | MeanStddevNormalization(*forgetGateDecoder, |
| 388 | *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 389 | |
| 390 | |
| 391 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 392 | |
| 393 | VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder, |
| 394 | numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); |
| 395 | |
| 396 | |
| 397 | // Dequantize layer norm output to (1 / 4096) |
| 398 | forgetGateInfo.SetQuantizationScale(1.f / 4096); |
| 399 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 400 | |
| 401 | VectorBatchVectorAdd(*forgetGateBiasDecoder, |
| 402 | numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); |
| 403 | |
| 404 | |
| 405 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 406 | } |
| 407 | |
| 408 | forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 409 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 410 | |
| 411 | // Forget gate sigmoid |
| 412 | Activation(*forgetGateDecoder, *forgetGateEncoder, |
| 413 | TensorInfo({numUnits, numBatches}, internalType), |
| 414 | ActivationFunction::Sigmoid, 0, 0); |
| 415 | |
| 416 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 417 | |
| 418 | // Cell (Modulation) gate |
| 419 | if (layerNormEnabled) |
| 420 | { |
| 421 | cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 422 | m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 423 | 1024); |
| 424 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 425 | |
| 426 | MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 427 | |
| 428 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 429 | |
| 430 | VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder, |
| 431 | numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); |
| 432 | |
| 433 | cellGateInfo.SetQuantizationScale(1.f / 4096); |
| 434 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 435 | |
| 436 | VectorBatchVectorAdd(*cellGateBiasDecoder, |
| 437 | numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); |
| 438 | |
| 439 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 440 | } |
| 441 | |
| 442 | cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 443 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 444 | |
| 445 | // Cell (Modulation) gate tanH |
| 446 | Activation(*cellGateDecoder, *cellGateEncoder, |
| 447 | TensorInfo({numUnits, numBatches}, internalType), |
| 448 | ActivationFunction::TanH, 1.0f, 1.0f); |
| 449 | |
| 450 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 451 | |
| 452 | VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder); |
| 453 | |
| 454 | if (cifgEnabled) |
| 455 | { |
| 456 | Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder); |
| 457 | VectorVectorCwiseProductAccumulate( |
| 458 | *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder); |
| 459 | } |
| 460 | else |
| 461 | { |
| 462 | VectorVectorCwiseProductAccumulate( |
| 463 | *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder); |
| 464 | } |
| 465 | |
| 466 | // Final cell state out calculated here |
| 467 | if (m_Data.m_Parameters.m_CellClip > 0.0) |
| 468 | { |
| 469 | ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder); |
| 470 | } |
| 471 | |
| 472 | // Output gate. |
| 473 | if (peepholeEnabled) |
| 474 | { |
| 475 | VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder, |
| 476 | numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder); |
| 477 | } |
| 478 | |
| 479 | if (layerNormEnabled) |
| 480 | { |
| 481 | outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 482 | m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 483 | 1024); |
| 484 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 485 | |
| 486 | MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 487 | |
| 488 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 489 | |
| 490 | VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder, |
| 491 | numBatches, *outputGateEncoder); |
| 492 | |
| 493 | outputGateInfo.SetQuantizationScale(1.f / 4096); |
| 494 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 495 | |
| 496 | VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder); |
| 497 | |
| 498 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 499 | } |
| 500 | |
| 501 | outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 502 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 503 | |
| 504 | // Output gate sigmoid |
| 505 | Activation(*outputGateDecoder, *outputGateEncoder, |
| 506 | TensorInfo({numUnits, numBatches}, internalType), |
| 507 | ActivationFunction::Sigmoid, 0, 0); |
| 508 | |
| 509 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 510 | |
| 511 | // Hidden state tanH |
| 512 | Activation(*cellStateOutDecoder, *cellGateEncoder, |
| 513 | TensorInfo({numUnits, numBatches}, internalType), |
| 514 | ActivationFunction::TanH, 1.0f, 1.0f); |
| 515 | |
| 516 | // Final hidden state output |
| 517 | VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder); |
| 518 | |
| 519 | // Projection |
| 520 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 521 | { |
| 522 | if (m_ProjectionBiasTensor) |
| 523 | { |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 524 | VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 525 | } |
| 526 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 527 | MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder, |
| 528 | numBatches, *outputInt16Encoder); |
| 529 | |
| 530 | CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 531 | |
| 532 | if (m_Data.m_Parameters.m_ProjectionClip > 0.0) |
| 533 | { |
| 534 | ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder); |
| 535 | } |
| 536 | } |
| 537 | else |
| 538 | { |
| 539 | // Output has same quantization scale as hidden state if projection is disabled |
| 540 | CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder); |
| 541 | } |
| 542 | |
| 543 | // output == outputStateOut |
| 544 | CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder); |
| 545 | } |
| 546 | |
| 547 | } //namespace armnn |