blob: 9f4f9b472035ecf9bd502589b8db281ca09d38ba [file] [log] [blame]
Davide Grohmannb35f0c62022-06-15 11:23:25 +02001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19/****************************************************************************
20 * Includes
21 ****************************************************************************/
22
23#include "FreeRTOS.h"
24#include "queue.h"
25#include "semphr.h"
26#include "task.h"
27
28#include <inttypes.h>
29#include <stdio.h>
30
31#include "ethosu_core_interface.h"
32#include "indexed_networks.hpp"
33#include "message_client.hpp"
34#include "message_handler.hpp"
35#include "message_queue.hpp"
36#include "networks.hpp"
37#include "test_assertions.hpp"
38#include "test_helpers.hpp"
39
40#include <mailbox.hpp>
41#include <mhu_dummy.hpp>
42
/* Disable semihosting */
// NOTE(review): presumably referencing __use_no_semihosting makes the Arm C
// library/linker refuse to link any semihosting-dependent function — confirm
// against the toolchain documentation.
__asm(".global __use_no_semihosting\n\t");

using namespace EthosU;
using namespace MessageHandler;

/****************************************************************************
 * Defines
 ****************************************************************************/

// TensorArena static initialisation
// Size comes from the TENSOR_ARENA_SIZE build-time define.
constexpr size_t arenaSize = TENSOR_ARENA_SIZE;

// Tensor arena, 16-byte aligned and placed in the ".bss.tensor_arena" linker
// section. NOTE(review): not referenced elsewhere in this file; presumably
// consumed by another translation unit via its external linkage.
__attribute__((section(".bss.tensor_arena"), aligned(16))) uint8_t tensorArena[arenaSize];

// Message queue from remote host
// Placed in the "ethosu_core_in_queue" section; the client task writes to it
// and the message task reads from it (both via toQueue()).
// NOTE(review): the 1000 template argument is presumably the queue's byte
// capacity — confirm against MessageQueue::Queue.
__attribute__((section("ethosu_core_in_queue"))) MessageQueue::Queue<1000> inputMessageQueue;

// Message queue to remote host
// Placed in the "ethosu_core_out_queue" section; carries responses back to
// the client task.
__attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outputMessageQueue;

namespace {
// Dummy MHU mailbox shared by the message task and the client task; main()
// calls verifyHardware() on it before creating any task.
Mailbox::MHUDummy mailbox;
} // namespace
67
68/****************************************************************************
69 * Application
70 ****************************************************************************/
71namespace {
72
73struct TaskParams {
74 TaskParams() :
75 messageNotify(xSemaphoreCreateBinary()),
76 inferenceInputQueue(std::make_shared<Queue<ethosu_core_inference_req>>()),
77 inferenceOutputQueue(xQueueCreate(5, sizeof(ethosu_core_inference_rsp))),
78 networks(std::make_shared<WithIndexedNetworks>()) {}
79
80 SemaphoreHandle_t messageNotify;
81 // Used to pass inference requests to the inference runner task
82 std::shared_ptr<Queue<ethosu_core_inference_req>> inferenceInputQueue;
83 // Queue for message responses to the remote host
84 QueueHandle_t inferenceOutputQueue;
85 // Networks provider
86 std::shared_ptr<Networks> networks;
87};
88
89void messageTask(void *pvParameters) {
90 printf("Starting message task\n");
91 TaskParams *params = reinterpret_cast<TaskParams *>(pvParameters);
92
93 IncomingMessageHandler process(*inputMessageQueue.toQueue(),
94 *outputMessageQueue.toQueue(),
95 mailbox,
96 params->inferenceInputQueue,
97 params->inferenceOutputQueue,
98 params->messageNotify,
99 params->networks);
100 process.run();
101}
102
103void testCancelInference(MessageClient client) {
104 const uint64_t fake_inference_user_arg = 42;
105 const uint32_t network_index = 0;
106 ethosu_core_inference_req inference_req =
107 inferenceIndexedRequest(fake_inference_user_arg, network_index, nullptr, 0, nullptr, 0);
108
109 const uint64_t fake_cancel_inference_user_arg = 55;
110 ethosu_core_cancel_inference_req cancel_req = {fake_cancel_inference_user_arg, fake_inference_user_arg};
111
112 ethosu_core_inference_rsp inference_rsp;
113 ethosu_core_cancel_inference_rsp cancel_rsp;
114
115 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, inference_req));
116 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req));
117
118 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, inference_rsp));
119 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp));
120
121 TEST_ASSERT(inference_req.user_arg == inference_rsp.user_arg);
122 TEST_ASSERT(inference_rsp.status == ETHOSU_CORE_STATUS_ABORTED);
123
124 TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg);
125 TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_OK);
126}
127
128void testCancelNonExistentInference(MessageClient client) {
129 const uint64_t fake_inference_user_arg = 42;
130 const uint64_t fake_cancel_inference_user_arg = 55;
131 ethosu_core_cancel_inference_req cancel_req = {fake_cancel_inference_user_arg, fake_inference_user_arg};
132 ethosu_core_cancel_inference_rsp cancel_rsp;
133
134 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req));
135 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp));
136
137 TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg);
138 TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_ERROR);
139}
140
141void testCannotCancelRunningInference(MessageClient client,
142 std::shared_ptr<Queue<ethosu_core_inference_req>> inferenceInputQueue) {
143 const uint64_t fake_inference_user_arg = 42;
144 const uint32_t network_index = 0;
145 ethosu_core_inference_req inference_req =
146 inferenceIndexedRequest(fake_inference_user_arg, network_index, nullptr, 0, nullptr, 0);
147
148 const uint64_t fake_cancel_inference_user_arg = 55;
149 ethosu_core_cancel_inference_req cancel_req = {fake_cancel_inference_user_arg, fake_inference_user_arg};
150 ethosu_core_cancel_inference_rsp cancel_rsp;
151
152 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, inference_req));
153
154 // fake start of the inference by removing the inference from the queue
155 ethosu_core_inference_req start_req;
156 inferenceInputQueue->pop(start_req);
157 TEST_ASSERT(inference_req.user_arg == start_req.user_arg);
158
159 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req));
160 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp));
161
162 TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg);
163 TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_ERROR);
164}
165
166void testRejectInference(MessageClient client) {
167 int runs = 6;
168 const uint64_t fake_inference_user_arg = 42;
169 const uint32_t network_index = 0;
170 const uint64_t fake_cancel_inference_user_arg = 55;
171 ethosu_core_inference_req req;
172 ethosu_core_inference_rsp rsp;
173
174 for (int i = 0; i < runs; i++) {
175
176 req = inferenceIndexedRequest(fake_inference_user_arg + i, network_index, nullptr, 0, nullptr, 0);
177 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_INFERENCE_REQ, req));
178 vTaskDelay(150);
179 }
180
181 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp));
182 TEST_ASSERT(uint64_t(fake_inference_user_arg + runs - 1) == rsp.user_arg);
183 TEST_ASSERT(rsp.status == ETHOSU_CORE_STATUS_REJECTED);
184
185 // let's cleanup the queue
186 ethosu_core_cancel_inference_req cancel_req = {0, 0};
187 ethosu_core_cancel_inference_rsp cancel_rsp;
188 ethosu_core_inference_rsp inference_rsp;
189
190 for (int i = 0; i < runs - 1; i++) {
191 cancel_req.user_arg = fake_cancel_inference_user_arg + i;
192 cancel_req.inference_handle = fake_inference_user_arg + i;
193 TEST_ASSERT(client.sendInputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ, cancel_req));
194
195 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_INFERENCE_RSP, inference_rsp));
196 TEST_ASSERT(inference_rsp.user_arg = cancel_req.inference_handle);
197
198 TEST_ASSERT(client.waitAndReadOutputMessage(ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, cancel_rsp));
199 TEST_ASSERT(cancel_req.user_arg == cancel_rsp.user_arg);
200 TEST_ASSERT(cancel_rsp.status == ETHOSU_CORE_STATUS_OK);
201 }
202}
203
204void clientTask(void *pvParameters) {
205 printf("Starting client task\n");
206 TaskParams *params = reinterpret_cast<TaskParams *>(pvParameters);
207
208 MessageClient client(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox);
209
210 vTaskDelay(50);
211
212 testCancelInference(client);
213 testCancelNonExistentInference(client);
214 testCannotCancelRunningInference(client, params->inferenceInputQueue);
215 testRejectInference(client);
216
217 exit(0);
218}
219
220/*
221 * Keep task parameters as global data as FreeRTOS resets the stack when the
222 * scheduler is started.
223 */
224TaskParams taskParams;
225
226} // namespace
227
228// FreeRTOS application. NOTE: Additional tasks may require increased heap size.
229int main() {
230 BaseType_t ret;
231
232 if (!mailbox.verifyHardware()) {
233 printf("Failed to verify mailbox hardware\n");
234 return 1;
235 }
236
237 // Task for handling incoming /outgoing messages from the remote host
238 ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr);
239 if (ret != pdPASS) {
240 printf("Failed to create 'messageTask'\n");
241 return ret;
242 }
243
244 // Task for handling incoming /outgoing messages from the remote host
245 ret = xTaskCreate(clientTask, "clientTask", 1024, &taskParams, 2, nullptr);
246 if (ret != pdPASS) {
247 printf("Failed to create 'messageTask'\n");
248 return ret;
249 }
250
251 // Start Scheduler
252 vTaskStartScheduler();
253
254 return 1;
255}