//
// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
| 5 | |
| 6 | #include "RefQLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
| 10 | #include "LstmUtils.hpp" |
| 11 | #include "RefWorkloadUtils.hpp" |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
Finn Williams | 73c547d | 2022-02-15 20:47:34 +0000 | [diff] [blame] | 17 | : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info) |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 18 | , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) |
| 19 | , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) |
| 20 | , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) |
| 21 | , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 22 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 23 | , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 24 | , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 25 | , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 26 | , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 27 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 28 | , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) |
| 29 | , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) |
| 30 | , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 31 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 32 | , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) |
| 33 | , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) |
| 34 | , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) |
| 35 | , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 36 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 37 | , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) |
| 38 | , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 39 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 40 | , m_InputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) |
| 41 | , m_ForgetLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) |
| 42 | , m_CellLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) |
| 43 | , m_OutputLayerNormWeightsTensor (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 44 | {} |
| 45 | |
| 46 | void RefQLstmWorkload::Execute() const |
| 47 | { |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 48 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 49 | } |
| 50 | |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 51 | void RefQLstmWorkload::ExecuteAsync(ExecutionData& executionData) |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 52 | { |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 53 | WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); |
| 54 | Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 55 | } |
| 56 | |
| 57 | void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const |
| 58 | { |
Mike Kelly | 7cbe781 | 2023-07-25 17:37:33 +0100 | [diff] [blame^] | 59 | ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefQLstmWorkload_Execute"); |
| 60 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 61 | // This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) |
| 62 | // method in the Android code base |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 63 | // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all |
| 64 | // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp. |
| 65 | // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp |
| 66 | const DataType& internalType = armnn::DataType::QSymmS16; |
| 67 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 68 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 69 | const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]); |
| 70 | const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 71 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 72 | const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]); |
| 73 | const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]); |
| 74 | const TensorInfo& outputInfo = GetTensorInfo(outputs[2]); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 75 | |
| 76 | const TensorShape& inputShape = inputInfo.GetShape(); |
| 77 | const TensorShape& outputStateInShape = outputStateInInfo.GetShape(); |
| 78 | const TensorShape& cellStateInShape = cellStateInInfo.GetShape(); |
| 79 | |
| 80 | // Infer numBatches, inputSize, outputSize and numUnits |
| 81 | const uint32_t numBatches = inputShape[0]; |
| 82 | const uint32_t inputSize = inputShape[1]; |
| 83 | const uint32_t outputSize = outputStateInShape[1]; |
| 84 | const uint32_t numUnits = cellStateInShape[1]; |
| 85 | |
| 86 | // Optional param settings |
| 87 | const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled; |
| 88 | const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled; |
| 89 | const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled; |
| 90 | const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled; |
| 91 | |
| 92 | // Input decoders |
| 93 | std::unique_ptr<Decoder<float>> inputDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 94 | MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 95 | std::unique_ptr<Decoder<float>> outputStateInDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 96 | MakeDecoder<float>(outputStateInInfo, inputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 97 | std::unique_ptr<Decoder<float>> cellStateInDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 98 | MakeDecoder<float>(cellStateInInfo, inputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 99 | |
| 100 | // Output decoders |
| 101 | std::unique_ptr<Decoder<float>> outputStateOutDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 102 | MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 103 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 104 | MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 105 | std::unique_ptr<Decoder<float>> outputDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 106 | MakeDecoder<float>(outputInfo, outputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 107 | |
| 108 | // Output encoders |
| 109 | std::unique_ptr<Encoder<float>> outputStateOutEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 110 | MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 111 | std::unique_ptr<Encoder<float>> cellStateOutEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 112 | MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 113 | std::unique_ptr<Encoder<float>> outputEncoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 114 | MakeEncoder<float>(outputInfo, outputs[2]->Map()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 115 | |
| 116 | // Weights decoders |
| 117 | std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 118 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 119 | std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 120 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 121 | std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 122 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 123 | |
| 124 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 125 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), |
| 126 | m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 127 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 128 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 129 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 130 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), |
| 131 | m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 132 | |
| 133 | // Optional CIFG params |
| 134 | std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder; |
| 135 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder; |
| 136 | std::unique_ptr<Decoder<float>> inputGateBiasDecoder; |
| 137 | |
| 138 | // Optional Peephole params |
| 139 | std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder; |
| 140 | std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder; |
| 141 | std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder; |
| 142 | |
| 143 | // Optional Projection params |
| 144 | std::unique_ptr<Decoder<float>> projectionWeightsDecoder; |
| 145 | std::unique_ptr<Decoder<float>> projectionBiasDecoder; |
| 146 | |
| 147 | // Optional Layer Norm params |
| 148 | std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder; |
| 149 | std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder; |
| 150 | std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder; |
| 151 | std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder; |
| 152 | |
| 153 | // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024) |
| 154 | std::unique_ptr<Decoder<float>> forgetGateBiasDecoder; |
| 155 | std::unique_ptr<Decoder<float>> cellGateBiasDecoder; |
| 156 | std::unique_ptr<Decoder<float>> outputGateBiasDecoder; |
| 157 | |
| 158 | // Int16 vectors for internal state data (to be decoded/encoded) |
| 159 | const uint32_t stateTensorSize = numBatches * numUnits; |
| 160 | std::vector<int16_t> inputGateData(stateTensorSize); |
| 161 | std::vector<int16_t> cellGateData(stateTensorSize); |
| 162 | std::vector<int16_t> forgetGateData(stateTensorSize); |
| 163 | std::vector<int16_t> outputGateData(stateTensorSize); |
| 164 | std::vector<int32_t> hiddenStateData(stateTensorSize); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 165 | std::vector<int16_t> outputInt16Data(numBatches * outputSize); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 166 | |
| 167 | armnn::TensorInfo inputGateInfo( |
| 168 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0); |
| 169 | armnn::TensorInfo cellGateInfo( |
| 170 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0); |
| 171 | armnn::TensorInfo forgetGateInfo( |
| 172 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0); |
| 173 | armnn::TensorInfo outputGateInfo( |
| 174 | {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0); |
| 175 | armnn::TensorInfo hiddenStateInfo({numBatches, numUnits}, |
| 176 | armnn::DataType::QAsymmS8, |
| 177 | m_Data.m_Parameters.m_HiddenStateScale, |
| 178 | m_Data.m_Parameters.m_HiddenStateZeroPoint); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 179 | armnn::TensorInfo outputInt16Info({numBatches , outputSize}, |
| 180 | armnn::DataType::QSymmS16, |
| 181 | outputInfo.GetQuantizationScale(), |
| 182 | outputInfo.GetQuantizationOffset()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 183 | |
| 184 | // Decoders/Encoders for internal states |
| 185 | std::unique_ptr<Decoder<float>> inputGateDecoder = |
| 186 | MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 187 | std::unique_ptr<Decoder<float>> cellGateDecoder = |
| 188 | MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 189 | std::unique_ptr<Decoder<float>> forgetGateDecoder = |
| 190 | MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 191 | std::unique_ptr<Decoder<float>> outputGateDecoder = |
| 192 | MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 193 | std::unique_ptr<Decoder<float>> hiddenStateDecoder = |
| 194 | MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data()); |
| 195 | |
| 196 | std::unique_ptr<Encoder<float>> inputGateEncoder = |
| 197 | MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 198 | std::unique_ptr<Encoder<float>> cellGateEncoder = |
| 199 | MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 200 | std::unique_ptr<Encoder<float>> forgetGateEncoder = |
| 201 | MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 202 | std::unique_ptr<Encoder<float>> outputGateEncoder = |
| 203 | MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 204 | std::unique_ptr<Encoder<float>> hiddenStateEncoder = |
| 205 | MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data()); |
| 206 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 207 | // Int16 used to accumulate output to prevent overflowing (after Projection MatMul) |
| 208 | std::unique_ptr<Decoder<float>> outputInt16Decoder = |
| 209 | MakeDecoder<float>(outputInt16Info, outputInt16Data.data()); |
| 210 | std::unique_ptr<Encoder<float>> outputInt16Encoder = |
| 211 | MakeEncoder<float>(outputInt16Info, outputInt16Data.data()); |
| 212 | |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 213 | // Create decoders for optional params if they are enabled |
| 214 | if (!cifgEnabled) |
| 215 | { |
| 216 | inputToInputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 217 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); |
| 218 | recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(), |
| 219 | m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 220 | } |
| 221 | |
| 222 | if (peepholeEnabled) |
| 223 | { |
| 224 | if (!cifgEnabled) |
| 225 | { |
| 226 | cellToInputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 227 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 228 | } |
| 229 | cellToForgetWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 230 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 231 | cellToOutputWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 232 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 233 | } |
| 234 | |
| 235 | if (projectionEnabled) |
| 236 | { |
| 237 | projectionWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 238 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 239 | if (m_ProjectionBiasTensor) |
| 240 | { |
| 241 | projectionBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 242 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 243 | } |
| 244 | } |
| 245 | |
| 246 | if (layerNormEnabled) |
| 247 | { |
| 248 | if (!cifgEnabled) |
| 249 | { |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 250 | inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(), |
| 251 | m_InputLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 252 | |
| 253 | // Bias only used if layer norm enabled |
| 254 | armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 255 | m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 256 | inputGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 257 | inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 258 | } |
| 259 | |
| 260 | forgetLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 261 | m_ForgetLayerNormWeightsTensor->GetTensorInfo(), |
| 262 | m_ForgetLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 263 | cellLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 264 | m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 265 | outputLayerNormWeightsDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 266 | m_OutputLayerNormWeightsTensor->GetTensorInfo(), |
| 267 | m_OutputLayerNormWeightsTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 268 | |
| 269 | // Bias only used if layer norm enabled |
| 270 | armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 271 | m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 272 | forgetGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 273 | forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 274 | |
| 275 | armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 276 | m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 277 | cellGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 278 | cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 279 | |
| 280 | armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32, |
| 281 | m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0); |
| 282 | outputGateBiasDecoder = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 283 | outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 284 | } |
| 285 | |
| 286 | // Initialize internal state tensors with zeroes. |
| 287 | if (!cifgEnabled) |
| 288 | { |
| 289 | ZeroVector(*inputGateEncoder, stateTensorSize); |
| 290 | } |
| 291 | ZeroVector(*forgetGateEncoder, stateTensorSize); |
| 292 | ZeroVector(*cellGateEncoder, stateTensorSize); |
| 293 | ZeroVector(*outputGateEncoder, stateTensorSize); |
| 294 | ZeroVector(*hiddenStateEncoder, stateTensorSize); |
| 295 | |
| 296 | // Input weights * Input |
| 297 | if (!cifgEnabled) |
| 298 | { |
| 299 | MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder, |
| 300 | numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder); |
| 301 | } |
| 302 | |
| 303 | MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder, |
| 304 | numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder); |
| 305 | |
| 306 | MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder, |
| 307 | numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder); |
| 308 | |
| 309 | MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder, |
| 310 | numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder); |
| 311 | |
| 312 | // Recurrent weights * OutputStateIn |
| 313 | if (!cifgEnabled) |
| 314 | { |
| 315 | MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder, |
| 316 | numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder); |
| 317 | } |
| 318 | |
| 319 | MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder, |
| 320 | numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder); |
| 321 | |
| 322 | MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder, |
| 323 | numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder); |
| 324 | |
| 325 | MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder, |
| 326 | numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder); |
| 327 | |
| 328 | // Input gate. |
| 329 | if (!cifgEnabled) |
| 330 | { |
| 331 | if (peepholeEnabled) |
| 332 | { |
| 333 | VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder, |
| 334 | numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder); |
| 335 | } |
| 336 | |
| 337 | if (layerNormEnabled) |
| 338 | { |
| 339 | inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 340 | m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 341 | 1024); |
| 342 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 343 | |
| 344 | MeanStddevNormalization(*inputGateDecoder, |
| 345 | *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 346 | |
| 347 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 348 | |
| 349 | VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder, |
| 350 | numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); |
| 351 | |
| 352 | inputGateInfo.SetQuantizationScale(1.f / 4096); |
| 353 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 354 | |
| 355 | VectorBatchVectorAdd(*inputGateBiasDecoder, |
| 356 | numUnits, *inputGateDecoder, numBatches, *inputGateEncoder); |
| 357 | |
| 358 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 359 | } |
| 360 | |
| 361 | inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 362 | inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data()); |
| 363 | |
| 364 | // Input gate sigmoid |
| 365 | Activation(*inputGateDecoder, *inputGateEncoder, |
| 366 | TensorInfo({numUnits, numBatches}, internalType), |
| 367 | ActivationFunction::Sigmoid, 0, 0); |
| 368 | |
| 369 | inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data()); |
| 370 | } |
| 371 | |
| 372 | // Forget gate |
| 373 | if (peepholeEnabled) |
| 374 | { |
| 375 | VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits, |
| 376 | *cellStateInDecoder, numBatches, *forgetGateEncoder); |
| 377 | } |
| 378 | |
| 379 | if (layerNormEnabled) |
| 380 | { |
| 381 | // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024 |
| 382 | forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 383 | m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 384 | 1024); |
| 385 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 386 | |
| 387 | |
| 388 | |
| 389 | MeanStddevNormalization(*forgetGateDecoder, |
| 390 | *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 391 | |
| 392 | |
| 393 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 394 | |
| 395 | VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder, |
| 396 | numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); |
| 397 | |
| 398 | |
| 399 | // Dequantize layer norm output to (1 / 4096) |
| 400 | forgetGateInfo.SetQuantizationScale(1.f / 4096); |
| 401 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 402 | |
| 403 | VectorBatchVectorAdd(*forgetGateBiasDecoder, |
| 404 | numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder); |
| 405 | |
| 406 | |
| 407 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 408 | } |
| 409 | |
| 410 | forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 411 | forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data()); |
| 412 | |
| 413 | // Forget gate sigmoid |
| 414 | Activation(*forgetGateDecoder, *forgetGateEncoder, |
| 415 | TensorInfo({numUnits, numBatches}, internalType), |
| 416 | ActivationFunction::Sigmoid, 0, 0); |
| 417 | |
| 418 | forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data()); |
| 419 | |
| 420 | // Cell (Modulation) gate |
| 421 | if (layerNormEnabled) |
| 422 | { |
| 423 | cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 424 | m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 425 | 1024); |
| 426 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 427 | |
| 428 | MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 429 | |
| 430 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 431 | |
| 432 | VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder, |
| 433 | numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); |
| 434 | |
| 435 | cellGateInfo.SetQuantizationScale(1.f / 4096); |
| 436 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 437 | |
| 438 | VectorBatchVectorAdd(*cellGateBiasDecoder, |
| 439 | numUnits, *cellGateDecoder, numBatches, *cellGateEncoder); |
| 440 | |
| 441 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 442 | } |
| 443 | |
| 444 | cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 445 | cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data()); |
| 446 | |
| 447 | // Cell (Modulation) gate tanH |
| 448 | Activation(*cellGateDecoder, *cellGateEncoder, |
| 449 | TensorInfo({numUnits, numBatches}, internalType), |
| 450 | ActivationFunction::TanH, 1.0f, 1.0f); |
| 451 | |
| 452 | cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data()); |
| 453 | |
| 454 | VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder); |
| 455 | |
| 456 | if (cifgEnabled) |
| 457 | { |
| 458 | Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder); |
| 459 | VectorVectorCwiseProductAccumulate( |
| 460 | *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder); |
| 461 | } |
| 462 | else |
| 463 | { |
| 464 | VectorVectorCwiseProductAccumulate( |
| 465 | *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder); |
| 466 | } |
| 467 | |
| 468 | // Final cell state out calculated here |
| 469 | if (m_Data.m_Parameters.m_CellClip > 0.0) |
| 470 | { |
| 471 | ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder); |
| 472 | } |
| 473 | |
| 474 | // Output gate. |
| 475 | if (peepholeEnabled) |
| 476 | { |
| 477 | VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder, |
| 478 | numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder); |
| 479 | } |
| 480 | |
| 481 | if (layerNormEnabled) |
| 482 | { |
| 483 | outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() * |
| 484 | m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() * |
| 485 | 1024); |
| 486 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 487 | |
| 488 | MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon); |
| 489 | |
| 490 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 491 | |
| 492 | VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder, |
| 493 | numBatches, *outputGateEncoder); |
| 494 | |
| 495 | outputGateInfo.SetQuantizationScale(1.f / 4096); |
| 496 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 497 | |
| 498 | VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder); |
| 499 | |
| 500 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 501 | } |
| 502 | |
| 503 | outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale()); |
| 504 | outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data()); |
| 505 | |
| 506 | // Output gate sigmoid |
| 507 | Activation(*outputGateDecoder, *outputGateEncoder, |
| 508 | TensorInfo({numUnits, numBatches}, internalType), |
| 509 | ActivationFunction::Sigmoid, 0, 0); |
| 510 | |
| 511 | outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data()); |
| 512 | |
| 513 | // Hidden state tanH |
| 514 | Activation(*cellStateOutDecoder, *cellGateEncoder, |
| 515 | TensorInfo({numUnits, numBatches}, internalType), |
| 516 | ActivationFunction::TanH, 1.0f, 1.0f); |
| 517 | |
| 518 | // Final hidden state output |
| 519 | VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder); |
| 520 | |
| 521 | // Projection |
| 522 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 523 | { |
| 524 | if (m_ProjectionBiasTensor) |
| 525 | { |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 526 | VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 527 | } |
| 528 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 529 | MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder, |
| 530 | numBatches, *outputInt16Encoder); |
| 531 | |
| 532 | CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 533 | |
| 534 | if (m_Data.m_Parameters.m_ProjectionClip > 0.0) |
| 535 | { |
| 536 | ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder); |
| 537 | } |
| 538 | } |
| 539 | else |
| 540 | { |
| 541 | // Output has same quantization scale as hidden state if projection is disabled |
| 542 | CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder); |
| 543 | } |
| 544 | |
| 545 | // output == outputStateOut |
| 546 | CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder); |
| 547 | } |
| 548 | |
| 549 | } //namespace armnn |