blob: 4717357b3ba6aae1a2ddf9b95039a6c87b6c7159 [file] [log] [blame]
surmeh0149b9e102018-05-17 14:11:25 +01001//
Mike Kellye2d611e2021-10-14 12:35:58 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
surmeh0149b9e102018-05-17 14:11:25 +01004//
Mike Kellye2d611e2021-10-14 12:35:58 +01005
surmeh0149b9e102018-05-17 14:11:25 +01006#include "DriverTestHelpers.hpp"
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +01007
surmeh0149b9e102018-05-17 14:11:25 +01008#include <log/log.h>
9
Mike Kellye2d611e2021-10-14 12:35:58 +010010DOCTEST_TEST_SUITE("FullyConnectedTests")
Sadik Armagan9150bff2021-05-26 15:40:53 +010011{
telsoa01ce3e84a2018-08-31 09:31:35 +010012using namespace android::hardware;
surmeh0149b9e102018-05-17 14:11:25 +010013using namespace driverTestHelpers;
telsoa01ce3e84a2018-08-31 09:31:35 +010014using namespace armnn_driver;
surmeh0149b9e102018-05-17 14:11:25 +010015
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010016using HalPolicy = hal_1_0::HalPolicy;
17
// We add our own test here because the driver fails the FC tests supplied by Google (they use non-const weights)
Mike Kellye2d611e2021-10-14 12:35:58 +010019DOCTEST_TEST_CASE("FullyConnected")
surmeh0149b9e102018-05-17 14:11:25 +010020{
21 // this should ideally replicate fully_connected_float.model.cpp
22 // but that uses slightly weird dimensions which I don't think we need to support for now
23
24 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010025 HalPolicy::Model model = {};
surmeh0149b9e102018-05-17 14:11:25 +010026
27 // add operands
28 int32_t actValue = 0;
29 float weightValue[] = {2, 4, 1};
30 float biasValue[] = {4};
31
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010032 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3});
33 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3}, weightValue);
34 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1}, biasValue);
35 AddIntOperand<HalPolicy>(model, actValue);
36 AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1});
surmeh0149b9e102018-05-17 14:11:25 +010037
38 // make the fully connected operation
39 model.operations.resize(1);
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010040 model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
surmeh0149b9e102018-05-17 14:11:25 +010041 model.operations[0].inputs = hidl_vec<uint32_t>{0, 1, 2, 3};
42 model.operations[0].outputs = hidl_vec<uint32_t>{4};
43
44 // make the prepared model
Sadik Armagane6e54a82019-05-08 10:18:05 +010045 android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
surmeh0149b9e102018-05-17 14:11:25 +010046
47 // construct the request
Sadik Armagan188675f2021-02-12 17:16:42 +000048 V1_0::DataLocation inloc = {};
surmeh0149b9e102018-05-17 14:11:25 +010049 inloc.poolIndex = 0;
50 inloc.offset = 0;
51 inloc.length = 3 * sizeof(float);
52 RequestArgument input = {};
53 input.location = inloc;
54 input.dimensions = hidl_vec<uint32_t>{};
55
Sadik Armagan188675f2021-02-12 17:16:42 +000056 V1_0::DataLocation outloc = {};
surmeh0149b9e102018-05-17 14:11:25 +010057 outloc.poolIndex = 1;
58 outloc.offset = 0;
59 outloc.length = 1 * sizeof(float);
60 RequestArgument output = {};
61 output.location = outloc;
62 output.dimensions = hidl_vec<uint32_t>{};
63
Kevin Mayec1e5b82020-02-26 17:00:39 +000064 V1_0::Request request = {};
surmeh0149b9e102018-05-17 14:11:25 +010065 request.inputs = hidl_vec<RequestArgument>{input};
66 request.outputs = hidl_vec<RequestArgument>{output};
67
68 // set the input data (matching source test)
69 float indata[] = {2, 32, 16};
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +010070 AddPoolAndSetData<float>(3, request, indata);
surmeh0149b9e102018-05-17 14:11:25 +010071
72 // add memory for the output
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +010073 android::sp<IMemory> outMemory = AddPoolAndGetData<float>(1, request);
surmeh0149b9e102018-05-17 14:11:25 +010074 float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
75
76 // run the execution
Sadik Armagand4636872020-04-27 10:15:41 +010077 if (preparedModel.get() != nullptr)
78 {
79 Execute(preparedModel, request);
80 }
surmeh0149b9e102018-05-17 14:11:25 +010081
82 // check the result
Mike Kellye2d611e2021-10-14 12:35:58 +010083 DOCTEST_CHECK(outdata[0] == 152);
surmeh0149b9e102018-05-17 14:11:25 +010084}
85
Mike Kellye2d611e2021-10-14 12:35:58 +010086DOCTEST_TEST_CASE("TestFullyConnected4dInput")
surmeh0149b9e102018-05-17 14:11:25 +010087{
88 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
89
Kevin Mayec1e5b82020-02-26 17:00:39 +000090 V1_0::ErrorStatus error;
surmeh0149b9e102018-05-17 14:11:25 +010091 std::vector<bool> sup;
92
Kevin Mayec1e5b82020-02-26 17:00:39 +000093 ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
surmeh0149b9e102018-05-17 14:11:25 +010094 {
95 error = status;
96 sup = supported;
97 };
98
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010099 HalPolicy::Model model = {};
surmeh0149b9e102018-05-17 14:11:25 +0100100
101 // operands
102 int32_t actValue = 0;
103 float weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
104 0, 1, 0, 0, 0, 0, 0, 0,
105 0, 0, 1, 0, 0, 0, 0, 0,
106 0, 0, 0, 1, 0, 0, 0, 0,
107 0, 0, 0, 0, 1, 0, 0, 0,
108 0, 0, 0, 0, 0, 1, 0, 0,
109 0, 0, 0, 0, 0, 0, 1, 0,
110 0, 0, 0, 0, 0, 0, 0, 1}; //identity
111 float biasValue[] = {0, 0, 0, 0, 0, 0, 0, 0};
112
113 // fully connected operation
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100114 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, 1, 8});
115 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8}, weightValue);
116 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8}, biasValue);
117 AddIntOperand<HalPolicy>(model, actValue);
118 AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
surmeh0149b9e102018-05-17 14:11:25 +0100119
120 model.operations.resize(1);
121
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100122 model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
surmeh0149b9e102018-05-17 14:11:25 +0100123 model.operations[0].inputs = hidl_vec<uint32_t>{0,1,2,3};
124 model.operations[0].outputs = hidl_vec<uint32_t>{4};
125
126 // make the prepared model
Sadik Armagane6e54a82019-05-08 10:18:05 +0100127 android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
surmeh0149b9e102018-05-17 14:11:25 +0100128
surmeh0149b9e102018-05-17 14:11:25 +0100129 // construct the request
Sadik Armagan188675f2021-02-12 17:16:42 +0000130 V1_0::DataLocation inloc = {};
131 inloc.poolIndex = 0;
132 inloc.offset = 0;
133 inloc.length = 8 * sizeof(float);
134 RequestArgument input = {};
135 input.location = inloc;
136 input.dimensions = hidl_vec<uint32_t>{};
surmeh0149b9e102018-05-17 14:11:25 +0100137
Sadik Armagan188675f2021-02-12 17:16:42 +0000138 V1_0::DataLocation outloc = {};
139 outloc.poolIndex = 1;
140 outloc.offset = 0;
141 outloc.length = 8 * sizeof(float);
142 RequestArgument output = {};
143 output.location = outloc;
144 output.dimensions = hidl_vec<uint32_t>{};
surmeh0149b9e102018-05-17 14:11:25 +0100145
Kevin Mayec1e5b82020-02-26 17:00:39 +0000146 V1_0::Request request = {};
surmeh0149b9e102018-05-17 14:11:25 +0100147 request.inputs = hidl_vec<RequestArgument>{input};
148 request.outputs = hidl_vec<RequestArgument>{output};
149
150 // set the input data
151 float indata[] = {1,2,3,4,5,6,7,8};
152 AddPoolAndSetData(8, request, indata);
153
154 // add memory for the output
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100155 android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
surmeh0149b9e102018-05-17 14:11:25 +0100156 float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
157
158 // run the execution
Sadik Armagand4636872020-04-27 10:15:41 +0100159 if (preparedModel != nullptr)
160 {
161 Execute(preparedModel, request);
162 }
surmeh0149b9e102018-05-17 14:11:25 +0100163
164 // check the result
Mike Kellye2d611e2021-10-14 12:35:58 +0100165 DOCTEST_CHECK(outdata[0] == 1);
166 DOCTEST_CHECK(outdata[1] == 2);
167 DOCTEST_CHECK(outdata[2] == 3);
168 DOCTEST_CHECK(outdata[3] == 4);
169 DOCTEST_CHECK(outdata[4] == 5);
170 DOCTEST_CHECK(outdata[5] == 6);
171 DOCTEST_CHECK(outdata[6] == 7);
172 DOCTEST_CHECK(outdata[7] == 8);
surmeh0149b9e102018-05-17 14:11:25 +0100173}
174
Mike Kellye2d611e2021-10-14 12:35:58 +0100175DOCTEST_TEST_CASE("TestFullyConnected4dInputReshape")
surmeh0149b9e102018-05-17 14:11:25 +0100176{
177 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
178
Kevin Mayec1e5b82020-02-26 17:00:39 +0000179 V1_0::ErrorStatus error;
surmeh0149b9e102018-05-17 14:11:25 +0100180 std::vector<bool> sup;
181
Kevin Mayec1e5b82020-02-26 17:00:39 +0000182 ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
surmeh0149b9e102018-05-17 14:11:25 +0100183 {
184 error = status;
185 sup = supported;
186 };
187
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100188 HalPolicy::Model model = {};
surmeh0149b9e102018-05-17 14:11:25 +0100189
190 // operands
191 int32_t actValue = 0;
192 float weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
193 0, 1, 0, 0, 0, 0, 0, 0,
194 0, 0, 1, 0, 0, 0, 0, 0,
195 0, 0, 0, 1, 0, 0, 0, 0,
196 0, 0, 0, 0, 1, 0, 0, 0,
197 0, 0, 0, 0, 0, 1, 0, 0,
198 0, 0, 0, 0, 0, 0, 1, 0,
199 0, 0, 0, 0, 0, 0, 0, 1}; //identity
200 float biasValue[] = {0, 0, 0, 0, 0, 0, 0, 0};
201
202 // fully connected operation
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100203 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 2, 2, 2});
204 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8}, weightValue);
205 AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8}, biasValue);
206 AddIntOperand<HalPolicy>(model, actValue);
207 AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
surmeh0149b9e102018-05-17 14:11:25 +0100208
209 model.operations.resize(1);
210
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100211 model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
surmeh0149b9e102018-05-17 14:11:25 +0100212 model.operations[0].inputs = hidl_vec<uint32_t>{0,1,2,3};
213 model.operations[0].outputs = hidl_vec<uint32_t>{4};
214
215 // make the prepared model
Sadik Armagane6e54a82019-05-08 10:18:05 +0100216 android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
surmeh0149b9e102018-05-17 14:11:25 +0100217
surmeh0149b9e102018-05-17 14:11:25 +0100218 // construct the request
Sadik Armagan188675f2021-02-12 17:16:42 +0000219 V1_0::DataLocation inloc = {};
220 inloc.poolIndex = 0;
221 inloc.offset = 0;
222 inloc.length = 8 * sizeof(float);
223 RequestArgument input = {};
224 input.location = inloc;
225 input.dimensions = hidl_vec<uint32_t>{};
surmeh0149b9e102018-05-17 14:11:25 +0100226
Sadik Armagan188675f2021-02-12 17:16:42 +0000227 V1_0::DataLocation outloc = {};
228 outloc.poolIndex = 1;
229 outloc.offset = 0;
230 outloc.length = 8 * sizeof(float);
231 RequestArgument output = {};
232 output.location = outloc;
233 output.dimensions = hidl_vec<uint32_t>{};
surmeh0149b9e102018-05-17 14:11:25 +0100234
Kevin Mayec1e5b82020-02-26 17:00:39 +0000235 V1_0::Request request = {};
surmeh0149b9e102018-05-17 14:11:25 +0100236 request.inputs = hidl_vec<RequestArgument>{input};
237 request.outputs = hidl_vec<RequestArgument>{output};
238
239 // set the input data
240 float indata[] = {1,2,3,4,5,6,7,8};
241 AddPoolAndSetData(8, request, indata);
242
243 // add memory for the output
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100244 android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
surmeh0149b9e102018-05-17 14:11:25 +0100245 float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
246
247 // run the execution
Sadik Armagand4636872020-04-27 10:15:41 +0100248 if (preparedModel != nullptr)
249 {
250 Execute(preparedModel, request);
251 }
surmeh0149b9e102018-05-17 14:11:25 +0100252
253 // check the result
Mike Kellye2d611e2021-10-14 12:35:58 +0100254 DOCTEST_CHECK(outdata[0] == 1);
255 DOCTEST_CHECK(outdata[1] == 2);
256 DOCTEST_CHECK(outdata[2] == 3);
257 DOCTEST_CHECK(outdata[3] == 4);
258 DOCTEST_CHECK(outdata[4] == 5);
259 DOCTEST_CHECK(outdata[5] == 6);
260 DOCTEST_CHECK(outdata[6] == 7);
261 DOCTEST_CHECK(outdata[7] == 8);
surmeh0149b9e102018-05-17 14:11:25 +0100262}
263
Mike Kellye2d611e2021-10-14 12:35:58 +0100264DOCTEST_TEST_CASE("TestFullyConnectedWeightsAsInput")
Sadik Armagan2e4a24a2021-03-18 13:59:40 +0000265{
266 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
267
268 V1_0::ErrorStatus error;
269 std::vector<bool> sup;
270
271 ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
272 {
273 error = status;
274 sup = supported;
275 };
276
277 HalPolicy::Model model = {};
278
279 // operands
280 int32_t actValue = 0;
281 float weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
282 0, 1, 0, 0, 0, 0, 0, 0,
283 0, 0, 1, 0, 0, 0, 0, 0,
284 0, 0, 0, 1, 0, 0, 0, 0,
285 0, 0, 0, 0, 1, 0, 0, 0,
286 0, 0, 0, 0, 0, 1, 0, 0,
287 0, 0, 0, 0, 0, 0, 1, 0,
288 0, 0, 0, 0, 0, 0, 0, 1}; //identity
289 float biasValue[] = {0, 0, 0, 0, 0, 0, 0, 0};
290
291 // fully connected operation
292 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, 1, 8});
293 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8});
294 AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{8});
295 AddIntOperand<HalPolicy>(model, actValue);
296 AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
297
298 model.operations.resize(1);
299
300 model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
301 model.operations[0].inputs = hidl_vec<uint32_t>{0,1,2,3};
302 model.operations[0].outputs = hidl_vec<uint32_t>{4};
303
304 // make the prepared model
305 android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
306
307 // construct the request for input
308 V1_0::DataLocation inloc = {};
309 inloc.poolIndex = 0;
310 inloc.offset = 0;
311 inloc.length = 8 * sizeof(float);
312 RequestArgument input = {};
313 input.location = inloc;
314 input.dimensions = hidl_vec<uint32_t>{1, 1, 1, 8};
315
316 // construct the request for weights as input
317 V1_0::DataLocation wloc = {};
318 wloc.poolIndex = 1;
319 wloc.offset = 0;
320 wloc.length = 64 * sizeof(float);
321 RequestArgument weights = {};
322 weights.location = wloc;
323 weights.dimensions = hidl_vec<uint32_t>{8, 8};
324
325 // construct the request for bias as input
326 V1_0::DataLocation bloc = {};
327 bloc.poolIndex = 2;
328 bloc.offset = 0;
329 bloc.length = 8 * sizeof(float);
330 RequestArgument bias = {};
331 bias.location = bloc;
332 bias.dimensions = hidl_vec<uint32_t>{8};
333
334 V1_0::DataLocation outloc = {};
335 outloc.poolIndex = 3;
336 outloc.offset = 0;
337 outloc.length = 8 * sizeof(float);
338 RequestArgument output = {};
339 output.location = outloc;
340 output.dimensions = hidl_vec<uint32_t>{1, 8};
341
342 V1_0::Request request = {};
343 request.inputs = hidl_vec<RequestArgument>{input, weights, bias};
344 request.outputs = hidl_vec<RequestArgument>{output};
345
346 // set the input data
347 float indata[] = {1,2,3,4,5,6,7,8};
348 AddPoolAndSetData(8, request, indata);
349
350 // set the weights data
351 AddPoolAndSetData(64, request, weightValue);
352 // set the bias data
353 AddPoolAndSetData(8, request, biasValue);
354
355 // add memory for the output
356 android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
357 float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
358
359 // run the execution
360 if (preparedModel != nullptr)
361 {
362 Execute(preparedModel, request);
363 }
364
365 // check the result
Mike Kellye2d611e2021-10-14 12:35:58 +0100366 DOCTEST_CHECK(outdata[0] == 1);
367 DOCTEST_CHECK(outdata[1] == 2);
368 DOCTEST_CHECK(outdata[2] == 3);
369 DOCTEST_CHECK(outdata[3] == 4);
370 DOCTEST_CHECK(outdata[4] == 5);
371 DOCTEST_CHECK(outdata[5] == 6);
372 DOCTEST_CHECK(outdata[6] == 7);
373 DOCTEST_CHECK(outdata[7] == 8);
Sadik Armagan2e4a24a2021-03-18 13:59:40 +0000374}
375
Sadik Armagan9150bff2021-05-26 15:40:53 +0100376}