blob: 4bd611deaf4d3c217d07a763fe54fea83e86c856 [file] [log] [blame]
Kristofer Jonsson3f5510f2023-02-08 14:23:00 +01001/*
2 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19/*****************************************************************************
20 * Includes
21 *****************************************************************************/
22
23#include "message_handler.hpp"
24
25#include <cinttypes>
26#include <cstdlib>
27
28#include <ethosu_log.h>
29#include <inference_parser.hpp>
30
31#ifdef ETHOSU
32#include <ethosu_driver.h>
33#endif
34
35/*****************************************************************************
36 * Networks
37 *****************************************************************************/
38
39namespace {
40#if defined(__has_include)
41
42#if defined(MODEL_0)
43namespace Model0 {
44#include STRINGIFY(MODEL_0)
45}
46#endif
47
48#if defined(MODEL_1)
49namespace Model1 {
50#include STRINGIFY(MODEL_1)
51}
52#endif
53
54#if defined(MODEL_2)
55namespace Model2 {
56#include STRINGIFY(MODEL_2)
57}
58#endif
59
60#if defined(MODEL_3)
61namespace Model3 {
62#include STRINGIFY(MODEL_3)
63}
64#endif
65
66#endif
67
68bool getIndexedNetwork(const uint32_t index, void *&data, size_t &size) {
69 switch (index) {
70#if defined(MODEL_0)
71 case 0:
72 data = reinterpret_cast<void *>(Model0::networkModelData);
73 size = sizeof(Model0::networkModelData);
74 break;
75#endif
76
77#if defined(MODEL_1)
78 case 1:
79 data = reinterpret_cast<void *>(Model1::networkModelData);
80 size = sizeof(Model1::networkModelData);
81 break;
82#endif
83
84#if defined(MODEL_2)
85 case 2:
86 data = reinterpret_cast<void *>(Model2::networkModelData);
87 size = sizeof(Model2::networkModelData);
88 break;
89#endif
90
91#if defined(MODEL_3)
92 case 3:
93 data = reinterpret_cast<void *>(Model3::networkModelData);
94 size = sizeof(Model3::networkModelData);
95 break;
96#endif
97
98 default:
99 LOG_WARN("Network model index out of range. index=%" PRIu32, index);
100 return true;
101 }
102
103 return false;
104}
105
106} // namespace
107
108/*****************************************************************************
109 * MessageHandler
110 *****************************************************************************/
111
112MessageHandler::MessageHandler(RProc &_rproc, const char *const _name) :
113 Rpmsg(_rproc, _name), capabilities(getCapabilities()) {
114 BaseType_t ret = xTaskCreate(responseTask, "responseTask", 1024, this, 3, &taskHandle);
115 if (ret != pdPASS) {
116 LOG_ERR("Failed to create response task");
117 abort();
118 }
119}
120
121MessageHandler::~MessageHandler() {
122 vTaskDelete(taskHandle);
123}
124
125int MessageHandler::handleMessage(void *data, size_t len, uint32_t src) {
126 auto rpmsg = static_cast<EthosU::ethosu_core_rpmsg *>(data);
127
128 LOG_DEBUG("Msg: src=%" PRIX32 ", len=%zu, magic=%" PRIX32 ", type=%" PRIu32,
129 src,
130 len,
131 rpmsg->header.magic,
132 rpmsg->header.type);
133
134 if (rpmsg->header.magic != ETHOSU_CORE_MSG_MAGIC) {
135 LOG_WARN("Msg: Invalid Magic");
136 sendError(src, EthosU::ETHOSU_CORE_MSG_ERR_INVALID_MAGIC, "Invalid magic");
137 return 0;
138 }
139
140 switch (rpmsg->header.type) {
141 case EthosU::ETHOSU_CORE_MSG_PING: {
142 LOG_INFO("Msg: Ping");
143 sendPong(src, rpmsg->header.msg_id);
144 break;
145 }
146 case EthosU::ETHOSU_CORE_MSG_VERSION_REQ: {
147 LOG_INFO("Msg: Version request");
148 sendVersionRsp(src, rpmsg->header.msg_id);
149 break;
150 }
151 case EthosU::ETHOSU_CORE_MSG_CAPABILITIES_REQ: {
152 if (len != sizeof(rpmsg->header)) {
153 sendError(
154 src, EthosU::ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "Incorrect capabilities request payload length.");
155 break;
156 }
157
158 LOG_INFO("Msg: Capabilities request");
159
160 sendCapabilitiesRsp(src, rpmsg->header.msg_id);
161 break;
162 }
163 case EthosU::ETHOSU_CORE_MSG_INFERENCE_REQ: {
164 if (len != sizeof(rpmsg->header) + sizeof(rpmsg->inf_req)) {
165 sendError(src, EthosU::ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "Incorrect inference request payload length.");
166 break;
167 }
168
169 forwardInferenceReq(src, rpmsg->header.msg_id, rpmsg->inf_req);
170 break;
171 }
172 case EthosU::ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ: {
173 if (len != sizeof(rpmsg->header) + sizeof(rpmsg->cancel_req)) {
174 sendError(
175 src, EthosU::ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "Incorrect cancel inference request payload length.");
176 break;
177 }
178
179 auto &request = rpmsg->cancel_req;
180 bool found = false;
181 inferenceQueue.erase([request, &found](auto &message) {
182 if (message->rpmsg.header.msg_id == request.inference_handle) {
183 found = true;
184 delete message;
185 return true;
186 }
187
188 return false;
189 });
190
191 if (found) {
192 sendInferenceRsp(src, request.inference_handle, EthosU::ETHOSU_CORE_STATUS_ABORTED);
193 }
194
195 sendCancelInferenceRsp(
196 src, rpmsg->header.msg_id, found ? EthosU::ETHOSU_CORE_STATUS_OK : EthosU::ETHOSU_CORE_STATUS_ERROR);
197 break;
198 }
199 case EthosU::ETHOSU_CORE_MSG_NETWORK_INFO_REQ: {
200 if (len != sizeof(rpmsg->header) + sizeof(rpmsg->net_info_req)) {
201 sendError(
202 src, EthosU::ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "Incorrect network info request payload length.");
203 break;
204 }
205
206 LOG_INFO("Msg: NetworkInfoReq. network={ type=%" PRIu32 ", index=%" PRIu32 ", buffer={ ptr=0x%" PRIX32
207 ", size=%" PRIu32 " } }",
208 rpmsg->net_info_req.network.type,
209 rpmsg->net_info_req.network.index,
210 rpmsg->net_info_req.network.buffer.ptr,
211 rpmsg->net_info_req.network.buffer.size);
212
213 sendNetworkInfoRsp(src, rpmsg->header.msg_id, rpmsg->net_info_req.network);
214 break;
215 }
216 default: {
217 LOG_WARN("Msg: Unsupported message. type=%" PRIu32, rpmsg->header.type);
218
219 char errMsg[128];
220 snprintf(
221 &errMsg[0], sizeof(errMsg), "Msg: Unknown message. type=%" PRIu32 ", length=%zu", rpmsg->header.type, len);
222
223 sendError(src, EthosU::ETHOSU_CORE_MSG_ERR_UNSUPPORTED_TYPE, errMsg);
224 }
225 }
226
227 return 0;
228}
229
230void MessageHandler::sendError(const uint32_t src, const EthosU::ethosu_core_err_type type, const char *msg) {
231 auto message = new Message(src, EthosU::ETHOSU_CORE_MSG_ERR, 0, sizeof(EthosU::ethosu_core_msg_err));
232
233 message->rpmsg.error.type = type;
234
235 for (size_t i = 0; i < sizeof(message->rpmsg.error.msg) && msg[i]; i++) {
236 message->rpmsg.error.msg[i] = msg[i];
237 }
238
239 responseQueue.send(message);
240}
241
242void MessageHandler::sendPong(const uint32_t src, const uint64_t msgId) {
243 auto message = new Message(src, EthosU::ETHOSU_CORE_MSG_PONG, msgId);
244
245 responseQueue.send(message);
246}
247
248void MessageHandler::sendVersionRsp(const uint32_t src, const uint64_t msgId) {
249 auto message =
250 new Message(src, EthosU::ETHOSU_CORE_MSG_VERSION_RSP, msgId, sizeof(EthosU::ethosu_core_msg_version_rsp));
251
252 message->rpmsg.version_rsp = {
253 ETHOSU_CORE_MSG_VERSION_MAJOR,
254 ETHOSU_CORE_MSG_VERSION_MINOR,
255 ETHOSU_CORE_MSG_VERSION_PATCH,
256 0,
257 };
258
259 responseQueue.send(message);
260}
261
262void MessageHandler::sendCapabilitiesRsp(const uint32_t src, const uint64_t msgId) {
263 auto message = new Message(
264 src, EthosU::ETHOSU_CORE_MSG_CAPABILITIES_RSP, msgId, sizeof(EthosU::ethosu_core_msg_capabilities_rsp));
265
266 message->rpmsg.cap_rsp = capabilities;
267
268 responseQueue.send(message);
269}
270
271EthosU::ethosu_core_msg_capabilities_rsp MessageHandler::getCapabilities() const {
272 EthosU::ethosu_core_msg_capabilities_rsp cap = {};
273
274#ifdef ETHOSU
275 ethosu_driver_version version;
276 ethosu_get_driver_version(&version);
277
278 ethosu_hw_info info;
279 ethosu_driver *drv = ethosu_reserve_driver();
280 ethosu_get_hw_info(drv, &info);
281 ethosu_release_driver(drv);
282
283 cap.version_status = info.version.version_status;
284 cap.version_minor = info.version.version_minor;
285 cap.version_major = info.version.version_major;
286 cap.product_major = info.version.product_major;
287 cap.arch_patch_rev = info.version.arch_patch_rev;
288 cap.arch_minor_rev = info.version.arch_minor_rev;
289 cap.arch_major_rev = info.version.arch_major_rev;
290 cap.driver_patch_rev = version.patch;
291 cap.driver_minor_rev = version.minor;
292 cap.driver_major_rev = version.major;
293 cap.macs_per_cc = info.cfg.macs_per_cc;
294 cap.cmd_stream_version = info.cfg.cmd_stream_version;
295 cap.custom_dma = info.cfg.custom_dma;
296#endif
297
298 return cap;
299}
300
301void MessageHandler::sendNetworkInfoRsp(const uint32_t src,
302 const uint64_t msgId,
303 EthosU::ethosu_core_network_buffer &network) {
304 auto message = new Message(
305 src, EthosU::ETHOSU_CORE_MSG_NETWORK_INFO_RSP, msgId, sizeof(EthosU::ethosu_core_msg_network_info_rsp));
306 auto &rsp = message->rpmsg.net_info_rsp;
307
308 rsp.ifm_count = 0;
309 rsp.ofm_count = 0;
310
311 bool failed = networkToVirtual(network);
312
313 if (!failed) {
314 InferenceProcess::InferenceParser parser;
315
316 failed = parser.parseModel(reinterpret_cast<void *>(network.buffer.ptr),
317 network.buffer.size,
318 rsp.desc,
319 InferenceProcess::makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX),
320 InferenceProcess::makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX));
321 }
322
323 rsp.status = failed ? EthosU::ETHOSU_CORE_STATUS_ERROR : EthosU::ETHOSU_CORE_STATUS_OK;
324
325 responseQueue.send(message);
326}
327
328void MessageHandler::forwardInferenceReq(const uint32_t src,
329 const uint64_t msgId,
330 const EthosU::ethosu_core_msg_inference_req &inference) {
331 auto message = new Message(src, EthosU::ETHOSU_CORE_MSG_INFERENCE_REQ, msgId);
332 auto &req = message->rpmsg.inf_req;
333
334 req = inference;
335
336 for (uint32_t i = 0; i < req.ifm_count; i++) {
337 bufferToVirtual(req.ifm[i]);
338 }
339
340 for (uint32_t i = 0; i < req.ofm_count; i++) {
341 bufferToVirtual(req.ofm[i]);
342 }
343
344 networkToVirtual(req.network);
345
346 inferenceQueue.send(message);
347}
348
349void MessageHandler::sendInferenceRsp(const uint32_t src,
350 const uint64_t msgId,
351 const EthosU::ethosu_core_status status) {
352 auto message =
353 new Message(src, EthosU::ETHOSU_CORE_MSG_INFERENCE_RSP, msgId, sizeof(EthosU::ethosu_core_msg_inference_rsp));
354
355 message->rpmsg.inf_rsp.status = status;
356
357 responseQueue.send(message);
358}
359
360void MessageHandler::sendCancelInferenceRsp(const uint32_t src,
361 const uint64_t msgId,
362 const EthosU::ethosu_core_status status) {
363 auto message = new Message(
364 src, EthosU::ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP, msgId, sizeof(EthosU::ethosu_core_msg_cancel_inference_rsp));
365
366 message->rpmsg.cancel_rsp.status = status;
367
368 responseQueue.send(message);
369}
370
371bool MessageHandler::getNetwork(const EthosU::ethosu_core_network_buffer &buffer, void *&data, size_t &size) {
372 switch (buffer.type) {
373 case EthosU::ETHOSU_CORE_NETWORK_BUFFER:
374 data = physicalToVirtual(buffer.buffer.ptr);
375 size = buffer.buffer.size;
376 return false;
377 case EthosU::ETHOSU_CORE_NETWORK_INDEX:
378 return getIndexedNetwork(buffer.index, data, size);
379 default:
380 LOG_WARN("Unsupported network model type. type=%" PRIu32, buffer.type);
381 return true;
382 }
383}
384
385bool MessageHandler::bufferToVirtual(EthosU::ethosu_core_buffer &buffer) {
386 void *ptr = physicalToVirtual(buffer.ptr);
387 if (ptr == nullptr) {
388 return true;
389 }
390
391 buffer.ptr = reinterpret_cast<uint32_t>(ptr);
392
393 return false;
394}
395
396bool MessageHandler::networkToVirtual(EthosU::ethosu_core_network_buffer &buffer) {
397 switch (buffer.type) {
398 case EthosU::ETHOSU_CORE_NETWORK_BUFFER:
399 return bufferToVirtual(buffer.buffer);
400 case EthosU::ETHOSU_CORE_NETWORK_INDEX: {
401 void *ptr;
402 size_t size;
403 if (getIndexedNetwork(buffer.index, ptr, size)) {
404 return true;
405 }
406
407 buffer.type = EthosU::ETHOSU_CORE_NETWORK_BUFFER;
408 buffer.buffer.ptr = reinterpret_cast<uint32_t>(ptr);
409 buffer.buffer.size = size;
410
411 return false;
412 }
413 default:
414 LOG_WARN("Unsupported network model type. type=%" PRIu32, buffer.type);
415 return true;
416 }
417}
418
419void MessageHandler::responseTask(void *param) {
420 auto _this = static_cast<MessageHandler *>(param);
421
422 LOG_DEBUG("Starting message response task");
423
424 while (true) {
425 Message *message;
426 auto ret = _this->responseQueue.receive(message);
427 if (ret) {
428 abort();
429 }
430
431 LOG_DEBUG("Sending message. type=%" PRIu32, message->rpmsg.header.type);
432
433 _this->send(&message->rpmsg, sizeof(message->rpmsg.header) + message->length, message->src);
434
435 delete message;
436 }
437}