blob: 36514f470b42527c3def735a4ddd0db75b3eeee2 [file] [log] [blame]
Davide Grohmann6d2e5b72022-08-24 17:01:40 +02001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#include <ethosu.hpp>
20#include <uapi/ethosu.h>
21
22#include <cstring>
23#include <iostream>
24#include <list>
25#include <memory>
26#include <sstream>
27#include <stdio.h>
28#include <string>
29#include <unistd.h>
30
31#include "input.h"
32#include "model.h"
33#include "output.h"
34#include "test_assertions.hpp"
35
36using namespace EthosU;
37
38namespace {
39
// Maximum time to block in Inference::wait(): 60 seconds, presumably
// expressed in nanoseconds (the driver API takes an int64_t timeout --
// TODO confirm units against ethosu.hpp). Never mutated, so constexpr.
constexpr int64_t defaultTimeout = 60000000000;
41
42void testCancelInference(const Device &device) {
43 try {
44 auto networkBuffer = std::make_shared<Buffer>(device, sizeof(networkModelData));
45 networkBuffer->resize(sizeof(networkModelData));
46 std::memcpy(networkBuffer->data(), networkModelData, sizeof(networkModelData));
47 auto network = std::make_shared<Network>(device, networkBuffer);
48
49 std::vector<std::shared_ptr<Buffer>> inputBuffers;
50 std::vector<std::shared_ptr<Buffer>> outputBuffers;
51
52 auto inputBuffer = std::make_shared<Buffer>(device, sizeof(inputData));
53 inputBuffer->resize(sizeof(inputData));
54 std::memcpy(inputBuffer->data(), inputData, sizeof(inputData));
55
56 inputBuffers.push_back(inputBuffer);
57 outputBuffers.push_back(std::make_shared<Buffer>(device, sizeof(expectedOutputData)));
58 std::vector<uint8_t> enabledCounters(Inference::getMaxPmuEventCounters());
59
60 auto inference = std::make_shared<Inference>(network,
61 inputBuffers.begin(),
62 inputBuffers.end(),
63 outputBuffers.begin(),
64 outputBuffers.end(),
65 enabledCounters,
66 false);
67
68 InferenceStatus status = inference->status();
69 TEST_ASSERT(status == InferenceStatus::RUNNING);
70
71 bool success = inference->cancel();
72 TEST_ASSERT(success);
73
74 status = inference->status();
75 TEST_ASSERT(status == InferenceStatus::ABORTED);
76
77 bool timedout = inference->wait(defaultTimeout);
78 TEST_ASSERT(!timedout);
79
80 } catch (std::exception &e) { throw TestFailureException("Inference run test: ", e.what()); }
81}
82
83void testRejectInference(const Device &device) {
84 try {
85 auto networkBuffer = std::make_shared<Buffer>(device, sizeof(networkModelData));
86 networkBuffer->resize(sizeof(networkModelData));
87 std::memcpy(networkBuffer->data(), networkModelData, sizeof(networkModelData));
88 auto network = std::make_shared<Network>(device, networkBuffer);
89
90 std::vector<std::shared_ptr<Buffer>> inputBuffers;
91 std::vector<std::shared_ptr<Buffer>> outputBuffers;
92
93 auto inputBuffer = std::make_shared<Buffer>(device, sizeof(inputData));
94 inputBuffer->resize(sizeof(inputData));
95 std::memcpy(inputBuffer->data(), inputData, sizeof(inputData));
96
97 inputBuffers.push_back(inputBuffer);
98 outputBuffers.push_back(std::make_shared<Buffer>(device, sizeof(expectedOutputData)));
99 std::vector<uint8_t> enabledCounters(Inference::getMaxPmuEventCounters());
100
101 std::shared_ptr<Inference> inferences[5];
102
103 for (int i = 0; i < 5; i++) {
104 inferences[i] = std::make_shared<Inference>(network,
105 inputBuffers.begin(),
106 inputBuffers.end(),
107 outputBuffers.begin(),
108 outputBuffers.end(),
109 enabledCounters,
110 false);
111
112 InferenceStatus status = inferences[i]->status();
113 TEST_ASSERT(status == InferenceStatus::RUNNING);
114 }
115
116 auto inference = std::make_shared<Inference>(network,
117 inputBuffers.begin(),
118 inputBuffers.end(),
119 outputBuffers.begin(),
120 outputBuffers.end(),
121 enabledCounters,
122 false);
123
124 bool timedout = inference->wait(defaultTimeout);
125 TEST_ASSERT(!timedout);
126
127 InferenceStatus status = inference->status();
128 TEST_ASSERT(status == InferenceStatus::REJECTED);
129
130 for (int i = 0; i < 5; i++) {
131
132 bool success = inferences[i]->cancel();
133 TEST_ASSERT(success);
134
135 InferenceStatus status = inferences[i]->status();
136 TEST_ASSERT(status == InferenceStatus::ABORTED);
137
138 bool timedout = inference->wait(defaultTimeout);
139 TEST_ASSERT(!timedout);
140 }
141 } catch (std::exception &e) { throw TestFailureException("Inference run test: ", e.what()); }
142}
143
144} // namespace
145
146int main() {
147 Device device;
148
149 try {
150 testCancelInference(device);
151 testRejectInference(device);
152 } catch (TestFailureException &e) {
153 std::cerr << "Test failure: " << e.what() << std::endl;
154 return 1;
155 }
156
157 return 0;
158}