blob: 55a1a395d8a3cc93008b0c1a6882064649c057d1 [file] [log] [blame]
surmeh0149b9e102018-05-17 14:11:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
surmeh0149b9e102018-05-17 14:11:25 +01004//
5#include "DriverTestHelpers.hpp"
6#include <boost/test/unit_test.hpp>
7#include <log/log.h>
8
9BOOST_AUTO_TEST_SUITE(ConcurrentDriverTests)
10
11using ArmnnDriver = armnn_driver::ArmnnDriver;
12using DriverOptions = armnn_driver::DriverOptions;
13using namespace android::nn;
telsoa01ce3e84a2018-08-31 09:31:35 +010014using namespace android::hardware;
surmeh0149b9e102018-05-17 14:11:25 +010015using namespace driverTestHelpers;
telsoa01ce3e84a2018-08-31 09:31:35 +010016using namespace armnn_driver;
surmeh0149b9e102018-05-17 14:11:25 +010017
18// Add our own test for concurrent execution
19// The main point of this test is to check that multiple requests can be
20// executed without waiting for the callback from previous execution.
21// The operations performed are not significant.
22BOOST_AUTO_TEST_CASE(ConcurrentExecute)
23{
24 ALOGI("ConcurrentExecute: entry");
25
26 auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
Matteo Martincigh8b287c22018-09-07 09:25:10 +010027 V1_0::Model model = {};
surmeh0149b9e102018-05-17 14:11:25 +010028
29 // add operands
30 int32_t actValue = 0;
31 float weightValue[] = {2, 4, 1};
32 float biasValue[] = {4};
33
34 AddInputOperand(model, hidl_vec<uint32_t>{1, 3});
35 AddTensorOperand(model, hidl_vec<uint32_t>{1, 3}, weightValue);
36 AddTensorOperand(model, hidl_vec<uint32_t>{1}, biasValue);
37 AddIntOperand(model, actValue);
38 AddOutputOperand(model, hidl_vec<uint32_t>{1, 1});
39
40 // make the fully connected operation
41 model.operations.resize(1);
Matteo Martincigh8b287c22018-09-07 09:25:10 +010042 model.operations[0].type = V1_0::OperationType::FULLY_CONNECTED;
surmeh0149b9e102018-05-17 14:11:25 +010043 model.operations[0].inputs = hidl_vec<uint32_t>{0, 1, 2, 3};
44 model.operations[0].outputs = hidl_vec<uint32_t>{4};
45
46 // make the prepared models
47 const size_t maxRequests = 5;
48 android::sp<IPreparedModel> preparedModels[maxRequests];
49 for (size_t i = 0; i < maxRequests; ++i)
50 {
51 preparedModels[i] = PrepareModel(model, *driver);
52 }
53
54 // construct the request data
55 DataLocation inloc = {};
56 inloc.poolIndex = 0;
57 inloc.offset = 0;
58 inloc.length = 3 * sizeof(float);
59 RequestArgument input = {};
60 input.location = inloc;
61 input.dimensions = hidl_vec<uint32_t>{};
62
63 DataLocation outloc = {};
64 outloc.poolIndex = 1;
65 outloc.offset = 0;
66 outloc.length = 1 * sizeof(float);
67 RequestArgument output = {};
68 output.location = outloc;
69 output.dimensions = hidl_vec<uint32_t>{};
70
71 // build the requests
72 Request requests[maxRequests];
73 android::sp<IMemory> outMemory[maxRequests];
74 float* outdata[maxRequests];
75 for (size_t i = 0; i < maxRequests; ++i)
76 {
77 requests[i].inputs = hidl_vec<RequestArgument>{input};
78 requests[i].outputs = hidl_vec<RequestArgument>{output};
79 // set the input data (matching source test)
80 float indata[] = {2, 32, 16};
81 AddPoolAndSetData(3, requests[i], indata);
82 // add memory for the output
83 outMemory[i] = AddPoolAndGetData(1, requests[i]);
84 outdata[i] = static_cast<float*>(static_cast<void*>(outMemory[i]->getPointer()));
85 }
86
87 // invoke the execution of the requests
88 ALOGI("ConcurrentExecute: executing requests");
89 android::sp<ExecutionCallback> cb[maxRequests];
90 for (size_t i = 0; i < maxRequests; ++i)
91 {
92 cb[i] = ExecuteNoWait(preparedModels[i], requests[i]);
93 }
94
95 // wait for the requests to complete
96 ALOGI("ConcurrentExecute: waiting for callbacks");
97 for (size_t i = 0; i < maxRequests; ++i)
98 {
99 cb[i]->wait();
100 }
101
102 // check the results
103 ALOGI("ConcurrentExecute: validating results");
104 for (size_t i = 0; i < maxRequests; ++i)
105 {
106 BOOST_TEST(outdata[i][0] == 152);
107 }
108 ALOGI("ConcurrentExecute: exit");
109}
110
111BOOST_AUTO_TEST_SUITE_END()