Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 1 | // |
Colm Donelan | 7bcae3c | 2024-01-22 10:07:14 +0000 | [diff] [blame] | 2 | // Copyright © 2021, 2023-2024 Arm Ltd and Contributors. All rights reserved. |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "LstmTestHelper.hpp" |
| 7 | |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 8 | #include <doctest/doctest.h> |
| 9 | |
| 10 | namespace armnnDelegate |
| 11 | { |
| 12 | |
| 13 | void LstmTest(std::vector<armnn::BackendId>& backends) |
| 14 | { |
| 15 | int32_t batchSize = 2; |
| 16 | int32_t inputSize = 2; |
| 17 | int32_t outputSize = 4; |
| 18 | // cellSize and outputSize have the same size when there is no projection. |
| 19 | int32_t numUnits = outputSize; |
| 20 | |
| 21 | std::vector<int32_t> inputShape {batchSize , inputSize}; |
| 22 | std::vector<int32_t> cellStateInTensorInfo {batchSize , numUnits}; |
| 23 | std::vector<int32_t> outputStateInTensorInfo {batchSize , outputSize}; |
| 24 | |
| 25 | std::vector<int32_t> scratchBufferTensorInfo {batchSize, numUnits * 4}; |
| 26 | std::vector<int32_t> cellStateOutTensorInfo {batchSize, numUnits}; |
| 27 | std::vector<int32_t> outputStateOutTensorInfo {batchSize, outputSize}; |
| 28 | std::vector<int32_t> outputTensorInfo {batchSize, outputSize}; |
| 29 | |
| 30 | std::vector<int32_t> tensorInfo4 {numUnits}; |
| 31 | std::vector<int32_t> tensorInfo8 {numUnits, 2}; |
| 32 | std::vector<int32_t> tensorInfo16 {numUnits, 4}; |
| 33 | |
| 34 | //tensorInfo8, |
| 35 | bool hasInputToInputWeights = true; |
| 36 | std::vector<float> inputToInputWeights {-0.45018822f, -0.02338299f, -0.0870589f, |
| 37 | -0.34550029f, 0.04266912f, -0.15680569f, |
| 38 | -0.34856534f, 0.43890524f}; |
| 39 | |
| 40 | std::vector<float> inputToForgetWeights {0.09701663f, 0.20334584f, -0.50592935f, |
| 41 | -0.31343272f, -0.40032279f, 0.44781327f, |
| 42 | 0.01387155f, -0.35593212f}; |
| 43 | |
| 44 | std::vector<float> inputToCellWeights {-0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f, |
| 45 | -0.20583314f, 0.44344562f, 0.22077113f, |
| 46 | -0.29909778f}; |
| 47 | |
| 48 | std::vector<float> inputToOutputWeights {-0.25065863f, -0.28290087f, 0.04613829f, |
| 49 | 0.40525138f, 0.44272184f, 0.03897077f, |
| 50 | -0.1556896f, 0.19487578f}; |
| 51 | |
| 52 | //tensorInfo16, |
| 53 | bool hasRecurrentToInputWeights = true; |
| 54 | std::vector<float> recurrentToInputWeights {-0.0063535f, -0.2042388f, 0.31454784f, |
| 55 | -0.35746509f, 0.28902304f, 0.08183324f, |
| 56 | -0.16555229f, 0.02286911f, -0.13566875f, |
| 57 | 0.03034258f, 0.48091322f, -0.12528998f, |
| 58 | 0.24077177f, -0.51332325f, -0.33502164f, |
| 59 | 0.10629296f}; |
| 60 | |
| 61 | std::vector<float> recurrentToForgetWeights {-0.48684245f, -0.06655136f, 0.42224967f, |
| 62 | 0.2112639f, 0.27654213f, 0.20864892f, |
| 63 | -0.07646349f, 0.45877004f, 0.00141793f, |
| 64 | -0.14609534f, 0.36447752f, 0.09196436f, |
| 65 | 0.28053468f, 0.01560611f, -0.20127171f, |
| 66 | -0.01140004f}; |
| 67 | |
| 68 | std::vector<float> recurrentToCellWeights {-0.3407414f, 0.24443203f, -0.2078532f, |
| 69 | 0.26320225f, 0.05695659f, -0.00123841f, |
| 70 | -0.4744786f, -0.35869038f, -0.06418842f, |
| 71 | -0.13502428f, -0.501764f, 0.22830659f, |
| 72 | -0.46367589f, 0.26016325f, -0.03894562f, |
| 73 | -0.16368064f}; |
| 74 | |
| 75 | std::vector<float> recurrentToOutputWeights {0.43385774f, -0.17194885f, 0.2718237f, |
| 76 | 0.09215671f, 0.24107647f, -0.39835793f, |
| 77 | 0.18212086f, 0.01301402f, 0.48572797f, |
| 78 | -0.50656658f, 0.20047462f, -0.20607421f, |
| 79 | -0.51818722f, -0.15390486f, 0.0468148f, |
| 80 | 0.39922136f}; |
| 81 | // tensorInfo4 |
| 82 | bool hasCellToInputWeights = false; |
| 83 | std::vector<float> cellToInputWeights {}; |
| 84 | bool hasCellToForgetWeights = false; |
| 85 | std::vector<float> cellToForgetWeights {}; |
| 86 | bool hasCellToOutputWeights = false; |
| 87 | std::vector<float> cellToOutputWeights {}; |
| 88 | |
| 89 | bool hasInputGateBias = true; |
| 90 | std::vector<float> inputGateBias {0., 0., 0., 0.}; |
| 91 | std::vector<float> forgetGateBias {1., 1., 1., 1.}; |
| 92 | std::vector<float> cellBias {0., 0., 0., 0.}; |
| 93 | std::vector<float> outputGateBias {0., 0., 0., 0.}; |
| 94 | |
| 95 | bool hasProjectionWeights = false; |
| 96 | std::vector<float> projectionWeights; |
| 97 | bool hasProjectionBias = false; |
| 98 | std::vector<float> projectionBias; |
| 99 | |
| 100 | bool hasInputLayerNormWeights = false; |
| 101 | std::vector<float> inputLayerNormWeights; |
| 102 | bool hasForgetLayerNormWeights = false; |
| 103 | std::vector<float> forgetLayerNormWeights; |
| 104 | bool hasCellLayerNormWeights = false; |
| 105 | std::vector<float> cellLayerNormWeights; |
| 106 | bool hasOutputLayerNormWeights = false; |
| 107 | std::vector<float> outputLayerNormWeights; |
| 108 | |
| 109 | std::vector<float> inputValues {2., 3., 3., 4.}; |
| 110 | std::vector<float> expectedOutputValues {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f, |
| 111 | -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f}; |
| 112 | |
| 113 | tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH; |
| 114 | float clippingThresCell = 0.f; |
| 115 | float clippingThresProj = 0.f; |
| 116 | |
Colm Donelan | 7bcae3c | 2024-01-22 10:07:14 +0000 | [diff] [blame] | 117 | LstmTestImpl<float>(::tflite::TensorType_FLOAT32, |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 118 | batchSize, |
| 119 | inputSize, |
| 120 | outputSize, |
| 121 | numUnits, |
| 122 | hasInputToInputWeights, |
| 123 | inputToInputWeights, |
| 124 | inputToForgetWeights, |
| 125 | inputToCellWeights, |
| 126 | inputToOutputWeights, |
| 127 | hasRecurrentToInputWeights, |
| 128 | recurrentToInputWeights, |
| 129 | recurrentToForgetWeights, |
| 130 | recurrentToCellWeights, |
| 131 | recurrentToOutputWeights, |
| 132 | hasCellToInputWeights, |
| 133 | cellToInputWeights, |
| 134 | hasCellToForgetWeights, |
| 135 | cellToForgetWeights, |
| 136 | hasCellToOutputWeights, |
| 137 | cellToOutputWeights, |
| 138 | hasInputGateBias, |
| 139 | inputGateBias, |
| 140 | forgetGateBias, |
| 141 | cellBias, |
| 142 | outputGateBias, |
| 143 | hasProjectionWeights, |
| 144 | projectionWeights, |
| 145 | hasProjectionBias, |
| 146 | projectionBias, |
| 147 | hasInputLayerNormWeights, |
| 148 | inputLayerNormWeights, |
| 149 | hasForgetLayerNormWeights, |
| 150 | forgetLayerNormWeights, |
| 151 | hasCellLayerNormWeights, |
| 152 | cellLayerNormWeights, |
| 153 | hasOutputLayerNormWeights, |
| 154 | outputLayerNormWeights, |
| 155 | inputValues, |
| 156 | expectedOutputValues, |
| 157 | activationFunction, |
| 158 | clippingThresCell, |
Colm Donelan | 7bcae3c | 2024-01-22 10:07:14 +0000 | [diff] [blame] | 159 | clippingThresProj, |
| 160 | backends); |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 161 | } |
| 162 | |
Colm Donelan | 7bcae3c | 2024-01-22 10:07:14 +0000 | [diff] [blame] | 163 | TEST_SUITE("LstmTest_Tests") |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 164 | { |
| 165 | |
Colm Donelan | 7bcae3c | 2024-01-22 10:07:14 +0000 | [diff] [blame] | 166 | TEST_CASE ("LstmTest_Test") |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 167 | { |
Colm Donelan | 7bcae3c | 2024-01-22 10:07:14 +0000 | [diff] [blame] | 168 | std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef, armnn::Compute::CpuAcc}; |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 169 | LstmTest(backends); |
| 170 | } |
| 171 | |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 172 | } |
Mike Kelly | 8ae17b3 | 2021-02-17 13:45:50 +0000 | [diff] [blame] | 173 | } // namespace armnnDelegate |