blob: 309f3cd3fe8f285caa0684b09ec72885be15fdc3 [file] [log] [blame]
//
// Copyright © 2021, 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "LstmTestHelper.hpp"
7
Mike Kelly8ae17b32021-02-17 13:45:50 +00008#include <doctest/doctest.h>
9
10namespace armnnDelegate
11{
12
13void LstmTest(std::vector<armnn::BackendId>& backends)
14{
15 int32_t batchSize = 2;
16 int32_t inputSize = 2;
17 int32_t outputSize = 4;
18 // cellSize and outputSize have the same size when there is no projection.
19 int32_t numUnits = outputSize;
20
21 std::vector<int32_t> inputShape {batchSize , inputSize};
22 std::vector<int32_t> cellStateInTensorInfo {batchSize , numUnits};
23 std::vector<int32_t> outputStateInTensorInfo {batchSize , outputSize};
24
25 std::vector<int32_t> scratchBufferTensorInfo {batchSize, numUnits * 4};
26 std::vector<int32_t> cellStateOutTensorInfo {batchSize, numUnits};
27 std::vector<int32_t> outputStateOutTensorInfo {batchSize, outputSize};
28 std::vector<int32_t> outputTensorInfo {batchSize, outputSize};
29
30 std::vector<int32_t> tensorInfo4 {numUnits};
31 std::vector<int32_t> tensorInfo8 {numUnits, 2};
32 std::vector<int32_t> tensorInfo16 {numUnits, 4};
33
34 //tensorInfo8,
35 bool hasInputToInputWeights = true;
36 std::vector<float> inputToInputWeights {-0.45018822f, -0.02338299f, -0.0870589f,
37 -0.34550029f, 0.04266912f, -0.15680569f,
38 -0.34856534f, 0.43890524f};
39
40 std::vector<float> inputToForgetWeights {0.09701663f, 0.20334584f, -0.50592935f,
41 -0.31343272f, -0.40032279f, 0.44781327f,
42 0.01387155f, -0.35593212f};
43
44 std::vector<float> inputToCellWeights {-0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
45 -0.20583314f, 0.44344562f, 0.22077113f,
46 -0.29909778f};
47
48 std::vector<float> inputToOutputWeights {-0.25065863f, -0.28290087f, 0.04613829f,
49 0.40525138f, 0.44272184f, 0.03897077f,
50 -0.1556896f, 0.19487578f};
51
52 //tensorInfo16,
53 bool hasRecurrentToInputWeights = true;
54 std::vector<float> recurrentToInputWeights {-0.0063535f, -0.2042388f, 0.31454784f,
55 -0.35746509f, 0.28902304f, 0.08183324f,
56 -0.16555229f, 0.02286911f, -0.13566875f,
57 0.03034258f, 0.48091322f, -0.12528998f,
58 0.24077177f, -0.51332325f, -0.33502164f,
59 0.10629296f};
60
61 std::vector<float> recurrentToForgetWeights {-0.48684245f, -0.06655136f, 0.42224967f,
62 0.2112639f, 0.27654213f, 0.20864892f,
63 -0.07646349f, 0.45877004f, 0.00141793f,
64 -0.14609534f, 0.36447752f, 0.09196436f,
65 0.28053468f, 0.01560611f, -0.20127171f,
66 -0.01140004f};
67
68 std::vector<float> recurrentToCellWeights {-0.3407414f, 0.24443203f, -0.2078532f,
69 0.26320225f, 0.05695659f, -0.00123841f,
70 -0.4744786f, -0.35869038f, -0.06418842f,
71 -0.13502428f, -0.501764f, 0.22830659f,
72 -0.46367589f, 0.26016325f, -0.03894562f,
73 -0.16368064f};
74
75 std::vector<float> recurrentToOutputWeights {0.43385774f, -0.17194885f, 0.2718237f,
76 0.09215671f, 0.24107647f, -0.39835793f,
77 0.18212086f, 0.01301402f, 0.48572797f,
78 -0.50656658f, 0.20047462f, -0.20607421f,
79 -0.51818722f, -0.15390486f, 0.0468148f,
80 0.39922136f};
81 // tensorInfo4
82 bool hasCellToInputWeights = false;
83 std::vector<float> cellToInputWeights {};
84 bool hasCellToForgetWeights = false;
85 std::vector<float> cellToForgetWeights {};
86 bool hasCellToOutputWeights = false;
87 std::vector<float> cellToOutputWeights {};
88
89 bool hasInputGateBias = true;
90 std::vector<float> inputGateBias {0., 0., 0., 0.};
91 std::vector<float> forgetGateBias {1., 1., 1., 1.};
92 std::vector<float> cellBias {0., 0., 0., 0.};
93 std::vector<float> outputGateBias {0., 0., 0., 0.};
94
95 bool hasProjectionWeights = false;
96 std::vector<float> projectionWeights;
97 bool hasProjectionBias = false;
98 std::vector<float> projectionBias;
99
100 bool hasInputLayerNormWeights = false;
101 std::vector<float> inputLayerNormWeights;
102 bool hasForgetLayerNormWeights = false;
103 std::vector<float> forgetLayerNormWeights;
104 bool hasCellLayerNormWeights = false;
105 std::vector<float> cellLayerNormWeights;
106 bool hasOutputLayerNormWeights = false;
107 std::vector<float> outputLayerNormWeights;
108
109 std::vector<float> inputValues {2., 3., 3., 4.};
110 std::vector<float> expectedOutputValues {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
111 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
112
113 tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
114 float clippingThresCell = 0.f;
115 float clippingThresProj = 0.f;
116
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000117 LstmTestImpl<float>(::tflite::TensorType_FLOAT32,
Mike Kelly8ae17b32021-02-17 13:45:50 +0000118 batchSize,
119 inputSize,
120 outputSize,
121 numUnits,
122 hasInputToInputWeights,
123 inputToInputWeights,
124 inputToForgetWeights,
125 inputToCellWeights,
126 inputToOutputWeights,
127 hasRecurrentToInputWeights,
128 recurrentToInputWeights,
129 recurrentToForgetWeights,
130 recurrentToCellWeights,
131 recurrentToOutputWeights,
132 hasCellToInputWeights,
133 cellToInputWeights,
134 hasCellToForgetWeights,
135 cellToForgetWeights,
136 hasCellToOutputWeights,
137 cellToOutputWeights,
138 hasInputGateBias,
139 inputGateBias,
140 forgetGateBias,
141 cellBias,
142 outputGateBias,
143 hasProjectionWeights,
144 projectionWeights,
145 hasProjectionBias,
146 projectionBias,
147 hasInputLayerNormWeights,
148 inputLayerNormWeights,
149 hasForgetLayerNormWeights,
150 forgetLayerNormWeights,
151 hasCellLayerNormWeights,
152 cellLayerNormWeights,
153 hasOutputLayerNormWeights,
154 outputLayerNormWeights,
155 inputValues,
156 expectedOutputValues,
157 activationFunction,
158 clippingThresCell,
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000159 clippingThresProj,
160 backends);
Mike Kelly8ae17b32021-02-17 13:45:50 +0000161}
162
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000163TEST_SUITE("LstmTest_Tests")
Mike Kelly8ae17b32021-02-17 13:45:50 +0000164{
165
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000166TEST_CASE ("LstmTest_Test")
Mike Kelly8ae17b32021-02-17 13:45:50 +0000167{
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000168 std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef, armnn::Compute::CpuAcc};
Mike Kelly8ae17b32021-02-17 13:45:50 +0000169 LstmTest(backends);
170}
171
Mike Kelly8ae17b32021-02-17 13:45:50 +0000172}
Mike Kelly8ae17b32021-02-17 13:45:50 +0000173} // namespace armnnDelegate