blob: 740154675762654f7189cc34f35db70cb419f134 [file] [log] [blame]
Yulia Garbovichf61ea352021-11-11 14:16:57 +02001/*
2 * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#include "message_handler.hpp"
20
21#include "cmsis_compiler.h"
22
23#ifdef ETHOSU
24#include <ethosu_driver.h>
25#endif
26
27#include "FreeRTOS.h"
28#include "queue.h"
29#include "semphr.h"
30
31#include <cstring>
32#include <inttypes.h>
33
34using namespace EthosU;
35using namespace MessageQueue;
36
37namespace MessageHandler {
38
39/****************************************************************************
40 * IncomingMessageHandler
41 ****************************************************************************/
42
43IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue,
44 Mailbox::Mailbox &_mailbox,
45 QueueHandle_t _inferenceQueue,
46 QueueHandle_t _outputQueue) :
47 messageQueue(_messageQueue),
48 mailbox(_mailbox), inferenceQueue(_inferenceQueue), outputQueue(_outputQueue) {
49 mailbox.registerCallback(handleIrq, reinterpret_cast<void *>(this));
50 semaphore = xSemaphoreCreateBinary();
51}
52
53void IncomingMessageHandler::run() {
54 while (true) {
55 // Wait for event
56 xSemaphoreTake(semaphore, portMAX_DELAY);
57
58 // Handle all messages in queue
59 while (handleMessage()) {}
60 }
61}
62
63void IncomingMessageHandler::handleIrq(void *userArg) {
64 IncomingMessageHandler *_this = reinterpret_cast<IncomingMessageHandler *>(userArg);
65 xSemaphoreGive(_this->semaphore);
66}
67
68void IncomingMessageHandler::queueErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) {
69 OutputMessage msg(ETHOSU_CORE_MSG_ERR);
70 msg.data.error.type = type;
71
72 for (size_t i = 0; i < sizeof(msg.data.error.msg) && message[i]; i++) {
73 msg.data.error.msg[i] = message[i];
74 }
75
76 xQueueSend(outputQueue, &msg, portMAX_DELAY);
77 messageQueue.reset();
78}
79
80bool IncomingMessageHandler::handleMessage() {
81 struct ethosu_core_msg msg;
82
83 if (messageQueue.available() == 0) {
84 return false;
85 }
86
87 // Read msg header
88 // Only process a complete message header, else send error message
89 // and reset queue
90 if (!messageQueue.read(msg)) {
91 queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_SIZE, "Failed to read a complete header");
92 return false;
93 }
94
95 printf("Msg: header magic=%" PRIX32 ", type=%" PRIu32 ", length=%" PRIu32 "\n", msg.magic, msg.type, msg.length);
96
97 if (msg.magic != ETHOSU_CORE_MSG_MAGIC) {
98 printf("Msg: Invalid Magic\n");
99 queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_MAGIC, "Invalid magic");
100 return false;
101 }
102
103 switch (msg.type) {
104 case ETHOSU_CORE_MSG_PING: {
105 printf("Msg: Ping\n");
106
107 OutputMessage message(ETHOSU_CORE_MSG_PONG);
108 xQueueSend(outputQueue, &message, portMAX_DELAY);
109 break;
110 }
111 case ETHOSU_CORE_MSG_ERR: {
112 ethosu_core_msg_err error;
113
114 if (!messageQueue.read(error)) {
115 printf("ERROR: Msg: Failed to receive error message\n");
116 } else {
117 printf("Msg: Received an error response, type=%" PRIu32 ", msg=\"%s\"\n", error.type, error.msg);
118 }
119
120 messageQueue.reset();
121 return false;
122 }
123 case ETHOSU_CORE_MSG_VERSION_REQ: {
124 printf("Msg: Version request\n");
125
126 OutputMessage message(ETHOSU_CORE_MSG_VERSION_RSP);
127 xQueueSend(outputQueue, &message, portMAX_DELAY);
128 break;
129 }
130 case ETHOSU_CORE_MSG_CAPABILITIES_REQ: {
131 ethosu_core_capabilities_req capabilities;
132
133 if (!messageQueue.read(capabilities)) {
134 queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "CapabilitiesReq. Failed to read payload");
135 break;
136 }
137
138 printf("Msg: Capabilities request.user_arg=0x%" PRIx64 "\n", capabilities.user_arg);
139
140 OutputMessage message(ETHOSU_CORE_MSG_CAPABILITIES_RSP);
141 message.data.userArg = capabilities.user_arg;
142 xQueueSend(outputQueue, &message, portMAX_DELAY);
143 break;
144 }
145 case ETHOSU_CORE_MSG_INFERENCE_REQ: {
146 ethosu_core_inference_req inference;
147
148 if (!messageQueue.read(inference)) {
149 queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "InferenceReq. Failed to read payload");
150 break;
151 }
152
153 printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network={0x%" PRIx32 ", %" PRIu32 "}\n",
154 inference.user_arg,
155 inference.network.ptr,
156 inference.network.size);
157
158 printf(", ifm_count=%" PRIu32 ", ifm=[", inference.ifm_count);
159 for (uint32_t i = 0; i < inference.ifm_count; ++i) {
160 if (i > 0) {
161 printf(", ");
162 }
163
164 printf("{0x%" PRIx32 ", %" PRIu32 "}", inference.ifm[i].ptr, inference.ifm[i].size);
165 }
166 printf("]");
167
168 printf(", ofm_count=%" PRIu32 ", ofm=[", inference.ofm_count);
169 for (uint32_t i = 0; i < inference.ofm_count; ++i) {
170 if (i > 0) {
171 printf(", ");
172 }
173
174 printf("{0x%" PRIx32 ", %" PRIu32 "}", inference.ofm[i].ptr, inference.ofm[i].size);
175 }
176 printf("]\n");
177
178 xQueueSend(inferenceQueue, &inference, portMAX_DELAY);
179 break;
180 }
181 default: {
182 char errMsg[128];
183
184 snprintf(&errMsg[0],
185 sizeof(errMsg),
186 "Msg: Unknown type: %" PRIu32 " with payload length %" PRIu32 " bytes\n",
187 msg.type,
188 msg.length);
189
190 queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_UNSUPPORTED_TYPE, errMsg);
191
192 return false;
193 }
194 }
195
196 return true;
197}
198
199/****************************************************************************
200 * InferenceHandler
201 ****************************************************************************/
202
203InferenceHandler::InferenceHandler(uint8_t *tensorArena,
204 size_t arenaSize,
205 QueueHandle_t _inferenceQueue,
206 QueueHandle_t _outputQueue) :
207 inferenceQueue(_inferenceQueue),
208 outputQueue(_outputQueue), inference(tensorArena, arenaSize) {}
209
210void InferenceHandler::run() {
211 while (true) {
212 ethosu_core_inference_req req;
213
214 if (pdTRUE != xQueueReceive(inferenceQueue, &req, portMAX_DELAY)) {
215 continue;
216 }
217
218 OutputMessage msg(ETHOSU_CORE_MSG_INFERENCE_RSP);
219 runInference(req, msg.data.inference);
220
221 xQueueSend(outputQueue, &msg, portMAX_DELAY);
222 }
223}
224
225void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_inference_rsp &rsp) {
226 /*
227 * Setup inference job
228 */
229
230 InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size);
231
232 std::vector<InferenceProcess::DataPtr> ifm;
233 for (uint32_t i = 0; i < req.ifm_count; ++i) {
234 ifm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size));
235 }
236
237 std::vector<InferenceProcess::DataPtr> ofm;
238 for (uint32_t i = 0; i < req.ofm_count; ++i) {
239 ofm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size));
240 }
241
242 std::vector<InferenceProcess::DataPtr> expectedOutput;
243
244 std::vector<uint8_t> pmuEventConfig(ETHOSU_CORE_PMU_MAX);
245 for (uint32_t i = 0; i < ETHOSU_CORE_PMU_MAX; i++) {
246 pmuEventConfig[i] = req.pmu_event_config[i];
247 }
248
249 InferenceProcess::InferenceJob job(
250 "job", networkModel, ifm, ofm, expectedOutput, -1, pmuEventConfig, req.pmu_cycle_counter_enable);
251
252 /*
253 * Run inference
254 */
255
256 job.invalidate();
257 bool failed = inference.runJob(job);
258 job.clean();
259
260 /*
261 * Send inference response
262 */
263
264 rsp.user_arg = req.user_arg;
265 rsp.ofm_count = job.output.size();
266 rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
267
268 for (size_t i = 0; i < job.output.size(); ++i) {
269 rsp.ofm_size[i] = job.output[i].size;
270 }
271
272 for (size_t i = 0; i < job.pmuEventConfig.size(); i++) {
273 rsp.pmu_event_config[i] = job.pmuEventConfig[i];
274 }
275
276 for (size_t i = 0; i < job.pmuEventCount.size(); i++) {
277 rsp.pmu_event_count[i] = job.pmuEventCount[i];
278 }
279
280 rsp.pmu_cycle_counter_enable = job.pmuCycleCounterEnable;
281 rsp.pmu_cycle_counter_count = job.pmuCycleCounterCount;
282}
283
284/****************************************************************************
285 * OutgoingMessageHandler
286 ****************************************************************************/
287
288OutgoingMessageHandler::OutgoingMessageHandler(ethosu_core_queue &_messageQueue,
289 Mailbox::Mailbox &_mailbox,
290 QueueHandle_t _outputQueue) :
291 messageQueue(_messageQueue),
292 mailbox(_mailbox), outputQueue(_outputQueue) {
293 readCapabilties(capabilities);
294}
295
296void OutgoingMessageHandler::run() {
297 while (true) {
298 OutputMessage message;
299 if (pdTRUE != xQueueReceive(outputQueue, &message, portMAX_DELAY)) {
300 continue;
301 }
302
303 switch (message.type) {
304 case ETHOSU_CORE_MSG_INFERENCE_RSP:
305 sendInferenceRsp(message.data.inference);
306 break;
307 case ETHOSU_CORE_MSG_CAPABILITIES_RSP:
308 sendCapabilitiesRsp(message.data.userArg);
309 break;
310 case ETHOSU_CORE_MSG_VERSION_RSP:
311 sendVersionRsp();
312 break;
313 case ETHOSU_CORE_MSG_PONG:
314 sendPong();
315 break;
316 case ETHOSU_CORE_MSG_ERR:
317 sendErrorRsp(message.data.error);
318 break;
319 default:
320 printf("Dropping unknown outcome of type %d\n", message.type);
321 break;
322 }
323 }
324}
325
326void OutgoingMessageHandler::sendPong() {
327 if (!messageQueue.write(ETHOSU_CORE_MSG_PONG)) {
328 printf("ERROR: Msg: Failed to write pong response. No mailbox message sent\n");
329 } else {
330 mailbox.sendMessage();
331 }
332}
333
334void OutgoingMessageHandler::sendVersionRsp() {
335 ethosu_core_msg_version version = {
336 ETHOSU_CORE_MSG_VERSION_MAJOR,
337 ETHOSU_CORE_MSG_VERSION_MINOR,
338 ETHOSU_CORE_MSG_VERSION_PATCH,
339 0,
340 };
341
342 if (!messageQueue.write(ETHOSU_CORE_MSG_VERSION_RSP, version)) {
343 printf("ERROR: Failed to write version response. No mailbox message sent\n");
344 } else {
345 mailbox.sendMessage();
346 }
347}
348
349void OutgoingMessageHandler::sendCapabilitiesRsp(uint64_t userArg) {
350 capabilities.user_arg = userArg;
351
352 if (!messageQueue.write(ETHOSU_CORE_MSG_CAPABILITIES_RSP, capabilities)) {
353 printf("ERROR: Failed to write capabilities response. No mailbox message sent\n");
354 } else {
355 mailbox.sendMessage();
356 }
357}
358
359void OutgoingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &inference) {
360 if (!messageQueue.write(ETHOSU_CORE_MSG_INFERENCE_RSP, inference)) {
361 printf("ERROR: Msg: Failed to write inference response. No mailbox message sent\n");
362 } else {
363 mailbox.sendMessage();
364 }
365}
366
367void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) {
368 printf("ERROR: Msg: \"%s\"\n", error.msg);
369
370 if (!messageQueue.write(ETHOSU_CORE_MSG_ERR, error)) {
371 printf("ERROR: Msg: Failed to write error response. No mailbox message sent\n");
372 } else {
373 mailbox.sendMessage();
374 }
375}
376
377void OutgoingMessageHandler::readCapabilties(ethosu_core_msg_capabilities_rsp &rsp) {
378 rsp = {0};
379
380#ifdef ETHOSU
381 struct ethosu_driver_version version;
382 ethosu_get_driver_version(&version);
383
384 struct ethosu_hw_info info;
385 struct ethosu_driver *drv = ethosu_reserve_driver();
386 ethosu_get_hw_info(drv, &info);
387 ethosu_release_driver(drv);
388
389 rsp.user_arg = 0;
390 rsp.version_status = info.version.version_status;
391 rsp.version_minor = info.version.version_minor;
392 rsp.version_major = info.version.version_major;
393 rsp.product_major = info.version.product_major;
394 rsp.arch_patch_rev = info.version.arch_patch_rev;
395 rsp.arch_minor_rev = info.version.arch_minor_rev;
396 rsp.arch_major_rev = info.version.arch_major_rev;
397 rsp.driver_patch_rev = version.patch;
398 rsp.driver_minor_rev = version.minor;
399 rsp.driver_major_rev = version.major;
400 rsp.macs_per_cc = info.cfg.macs_per_cc;
401 rsp.cmd_stream_version = info.cfg.cmd_stream_version;
402 rsp.custom_dma = info.cfg.custom_dma;
403#endif
404}
405
406} // namespace MessageHandler