blob: 1fa9f0c8bf362aaaecbe44790e027dc0ea9fcc4d [file] [log] [blame]
Mike Kelly8ae17b32021-02-17 13:45:50 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "LstmTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9
10#include <flatbuffers/flatbuffers.h>
11#include <tensorflow/lite/schema/schema_generated.h>
12#include <doctest/doctest.h>
13
14namespace armnnDelegate
15{
16
17void LstmTest(std::vector<armnn::BackendId>& backends)
18{
19 int32_t batchSize = 2;
20 int32_t inputSize = 2;
21 int32_t outputSize = 4;
22 // cellSize and outputSize have the same size when there is no projection.
23 int32_t numUnits = outputSize;
24
25 std::vector<int32_t> inputShape {batchSize , inputSize};
26 std::vector<int32_t> cellStateInTensorInfo {batchSize , numUnits};
27 std::vector<int32_t> outputStateInTensorInfo {batchSize , outputSize};
28
29 std::vector<int32_t> scratchBufferTensorInfo {batchSize, numUnits * 4};
30 std::vector<int32_t> cellStateOutTensorInfo {batchSize, numUnits};
31 std::vector<int32_t> outputStateOutTensorInfo {batchSize, outputSize};
32 std::vector<int32_t> outputTensorInfo {batchSize, outputSize};
33
34 std::vector<int32_t> tensorInfo4 {numUnits};
35 std::vector<int32_t> tensorInfo8 {numUnits, 2};
36 std::vector<int32_t> tensorInfo16 {numUnits, 4};
37
38 //tensorInfo8,
39 bool hasInputToInputWeights = true;
40 std::vector<float> inputToInputWeights {-0.45018822f, -0.02338299f, -0.0870589f,
41 -0.34550029f, 0.04266912f, -0.15680569f,
42 -0.34856534f, 0.43890524f};
43
44 std::vector<float> inputToForgetWeights {0.09701663f, 0.20334584f, -0.50592935f,
45 -0.31343272f, -0.40032279f, 0.44781327f,
46 0.01387155f, -0.35593212f};
47
48 std::vector<float> inputToCellWeights {-0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
49 -0.20583314f, 0.44344562f, 0.22077113f,
50 -0.29909778f};
51
52 std::vector<float> inputToOutputWeights {-0.25065863f, -0.28290087f, 0.04613829f,
53 0.40525138f, 0.44272184f, 0.03897077f,
54 -0.1556896f, 0.19487578f};
55
56 //tensorInfo16,
57 bool hasRecurrentToInputWeights = true;
58 std::vector<float> recurrentToInputWeights {-0.0063535f, -0.2042388f, 0.31454784f,
59 -0.35746509f, 0.28902304f, 0.08183324f,
60 -0.16555229f, 0.02286911f, -0.13566875f,
61 0.03034258f, 0.48091322f, -0.12528998f,
62 0.24077177f, -0.51332325f, -0.33502164f,
63 0.10629296f};
64
65 std::vector<float> recurrentToForgetWeights {-0.48684245f, -0.06655136f, 0.42224967f,
66 0.2112639f, 0.27654213f, 0.20864892f,
67 -0.07646349f, 0.45877004f, 0.00141793f,
68 -0.14609534f, 0.36447752f, 0.09196436f,
69 0.28053468f, 0.01560611f, -0.20127171f,
70 -0.01140004f};
71
72 std::vector<float> recurrentToCellWeights {-0.3407414f, 0.24443203f, -0.2078532f,
73 0.26320225f, 0.05695659f, -0.00123841f,
74 -0.4744786f, -0.35869038f, -0.06418842f,
75 -0.13502428f, -0.501764f, 0.22830659f,
76 -0.46367589f, 0.26016325f, -0.03894562f,
77 -0.16368064f};
78
79 std::vector<float> recurrentToOutputWeights {0.43385774f, -0.17194885f, 0.2718237f,
80 0.09215671f, 0.24107647f, -0.39835793f,
81 0.18212086f, 0.01301402f, 0.48572797f,
82 -0.50656658f, 0.20047462f, -0.20607421f,
83 -0.51818722f, -0.15390486f, 0.0468148f,
84 0.39922136f};
85 // tensorInfo4
86 bool hasCellToInputWeights = false;
87 std::vector<float> cellToInputWeights {};
88 bool hasCellToForgetWeights = false;
89 std::vector<float> cellToForgetWeights {};
90 bool hasCellToOutputWeights = false;
91 std::vector<float> cellToOutputWeights {};
92
93 bool hasInputGateBias = true;
94 std::vector<float> inputGateBias {0., 0., 0., 0.};
95 std::vector<float> forgetGateBias {1., 1., 1., 1.};
96 std::vector<float> cellBias {0., 0., 0., 0.};
97 std::vector<float> outputGateBias {0., 0., 0., 0.};
98
99 bool hasProjectionWeights = false;
100 std::vector<float> projectionWeights;
101 bool hasProjectionBias = false;
102 std::vector<float> projectionBias;
103
104 bool hasInputLayerNormWeights = false;
105 std::vector<float> inputLayerNormWeights;
106 bool hasForgetLayerNormWeights = false;
107 std::vector<float> forgetLayerNormWeights;
108 bool hasCellLayerNormWeights = false;
109 std::vector<float> cellLayerNormWeights;
110 bool hasOutputLayerNormWeights = false;
111 std::vector<float> outputLayerNormWeights;
112
113 std::vector<float> inputValues {2., 3., 3., 4.};
114 std::vector<float> expectedOutputValues {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
115 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
116
117 tflite::ActivationFunctionType activationFunction = tflite::ActivationFunctionType_TANH;
118 float clippingThresCell = 0.f;
119 float clippingThresProj = 0.f;
120
121 LstmTestImpl<float>(backends,
122 ::tflite::TensorType_FLOAT32,
123 batchSize,
124 inputSize,
125 outputSize,
126 numUnits,
127 hasInputToInputWeights,
128 inputToInputWeights,
129 inputToForgetWeights,
130 inputToCellWeights,
131 inputToOutputWeights,
132 hasRecurrentToInputWeights,
133 recurrentToInputWeights,
134 recurrentToForgetWeights,
135 recurrentToCellWeights,
136 recurrentToOutputWeights,
137 hasCellToInputWeights,
138 cellToInputWeights,
139 hasCellToForgetWeights,
140 cellToForgetWeights,
141 hasCellToOutputWeights,
142 cellToOutputWeights,
143 hasInputGateBias,
144 inputGateBias,
145 forgetGateBias,
146 cellBias,
147 outputGateBias,
148 hasProjectionWeights,
149 projectionWeights,
150 hasProjectionBias,
151 projectionBias,
152 hasInputLayerNormWeights,
153 inputLayerNormWeights,
154 hasForgetLayerNormWeights,
155 forgetLayerNormWeights,
156 hasCellLayerNormWeights,
157 cellLayerNormWeights,
158 hasOutputLayerNormWeights,
159 outputLayerNormWeights,
160 inputValues,
161 expectedOutputValues,
162 activationFunction,
163 clippingThresCell,
164 clippingThresProj);
165}
166
167TEST_SUITE("LstmTest_CpuRefTests")
168{
169
170TEST_CASE ("LstmTest_CpuRef_Test")
171{
172 std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
173 LstmTest(backends);
174}
175
176} //End of TEST_SUITE("Convolution2dTest_CpuRef")
177
178TEST_SUITE("LstmTest_CpuAccTests")
179{
180
181TEST_CASE ("LstmTest_CpuAcc_Test")
182{
183 std::vector <armnn::BackendId> backends = {armnn::Compute::CpuAcc};
184 LstmTest(backends);
185}
186
187} //End of TEST_SUITE("Convolution2dTest_CpuAcc")
188
189} // namespace armnnDelegate