blob: dd050592126e1c3d7440f930ed5d984d188d83b5 [file] [log] [blame]
Yulia Garbovichf61ea352021-11-11 14:16:57 +02001/*
Kristofer Jonssonac535f02022-03-10 11:08:39 +01002 * Copyright (c) 2020-2022 Arm Limited.
Yulia Garbovichf61ea352021-11-11 14:16:57 +02003 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#ifndef MESSAGE_HANDLER_H
20#define MESSAGE_HANDLER_H
21
#include "FreeRTOS.h"
#include "queue.h"
#include "semphr.h"

#include "message_queue.hpp"
#include <ethosu_core_interface.h>
#if defined(ETHOSU)
#include <ethosu_driver.h>
#endif
#include <inference_parser.hpp>
#include <inference_process.hpp>
#include <mailbox.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <list>
#include <memory>
#include <utility>
#include <vector>
40
41namespace MessageHandler {
42
Davide Grohmann160001c2022-03-24 15:38:27 +010043template <typename T, size_t capacity = 10>
44class Queue {
45public:
46 using Predicate = std::function<bool(const T &data)>;
47
48 Queue() {
49 mutex = xSemaphoreCreateMutex();
50 size = xSemaphoreCreateCounting(capacity, 0u);
51
52 if (mutex == nullptr || size == nullptr) {
53 printf("Error: failed to allocate memory for inference queue\n");
54 }
55 }
56
57 ~Queue() {
58 vSemaphoreDelete(mutex);
59 vSemaphoreDelete(size);
60 }
61
62 bool push(const T &data) {
63 xSemaphoreTake(mutex, portMAX_DELAY);
64 if (list.size() >= capacity) {
65 xSemaphoreGive(mutex);
66 return false;
67 }
68
69 list.push_back(data);
70 xSemaphoreGive(mutex);
71
72 // increase number of available inferences to pop
73 xSemaphoreGive(size);
74 return true;
75 }
76
77 void pop(T &data) {
78 // decrease the number of available inferences to pop
79 xSemaphoreTake(size, portMAX_DELAY);
80
81 xSemaphoreTake(mutex, portMAX_DELAY);
82 data = list.front();
83 list.pop_front();
84 xSemaphoreGive(mutex);
85 }
86
87 bool erase(Predicate pred) {
88 // let's optimistically assume we are removing an inference, so decrease pop
89 if (pdFALSE == xSemaphoreTake(size, 0)) {
90 // if there are no inferences return immediately
91 return false;
92 }
93
94 xSemaphoreTake(mutex, portMAX_DELAY);
95 auto found = std::find_if(list.begin(), list.end(), pred);
96 bool erased = found != list.end();
97 if (erased) {
98 list.erase(found);
99 }
100 xSemaphoreGive(mutex);
101
102 if (!erased) {
103 // no inference erased, so let's put the size count back
104 xSemaphoreGive(size);
105 }
106
107 return erased;
108 }
109
110private:
111 std::list<T> list;
112
113 SemaphoreHandle_t mutex;
114 SemaphoreHandle_t size;
115};
116
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200117class IncomingMessageHandler {
118public:
Davide Grohmann134c39e2022-04-25 12:21:12 +0200119 IncomingMessageHandler(EthosU::ethosu_core_queue &inputMessageQueue,
120 EthosU::ethosu_core_queue &outputMessageQueue,
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200121 Mailbox::Mailbox &mailbox,
Davide Grohmann160001c2022-03-24 15:38:27 +0100122 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
Davide Grohmann134c39e2022-04-25 12:21:12 +0200123 QueueHandle_t inferenceOutputQueue,
124 SemaphoreHandle_t messageNotify);
Davide Grohmann160001c2022-03-24 15:38:27 +0100125
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200126 void run();
127
128private:
129 bool handleMessage();
Davide Grohmann134c39e2022-04-25 12:21:12 +0200130 bool handleInferenceOutput();
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200131 static void handleIrq(void *userArg);
132
Davide Grohmann134c39e2022-04-25 12:21:12 +0200133 void sendPong();
134 void sendErrorAndResetQueue(EthosU::ethosu_core_msg_err_type type, const char *message);
135 void sendVersionRsp();
136 void sendCapabilitiesRsp(uint64_t userArg);
137 void sendNetworkInfoRsp(uint64_t userArg, EthosU::ethosu_core_network_buffer &network);
138 void sendInferenceRsp(EthosU::ethosu_core_inference_rsp &inference);
139 void sendFailedInferenceRsp(uint64_t userArg, uint32_t status);
Davide Grohmann00de9ee2022-03-23 14:59:56 +0100140 void sendCancelInferenceRsp(uint64_t userArg, uint32_t status);
Davide Grohmann134c39e2022-04-25 12:21:12 +0200141 void readCapabilties(EthosU::ethosu_core_msg_capabilities_rsp &rsp);
142
143 MessageQueue::QueueImpl inputMessageQueue;
144 MessageQueue::QueueImpl outputMessageQueue;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200145 Mailbox::Mailbox &mailbox;
Davide Grohmann65520052022-04-07 15:01:34 +0200146 InferenceProcess::InferenceParser parser;
Davide Grohmann160001c2022-03-24 15:38:27 +0100147 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
Davide Grohmann134c39e2022-04-25 12:21:12 +0200148 QueueHandle_t inferenceOutputQueue;
149 SemaphoreHandle_t messageNotify;
150 EthosU::ethosu_core_msg_capabilities_rsp capabilities;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200151};
152
153class InferenceHandler {
154public:
Davide Grohmann134c39e2022-04-25 12:21:12 +0200155 InferenceHandler(uint8_t *tensorArena,
156 size_t arenaSize,
Davide Grohmann160001c2022-03-24 15:38:27 +0100157 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
Davide Grohmann134c39e2022-04-25 12:21:12 +0200158 QueueHandle_t inferenceOutputQueue,
159 SemaphoreHandle_t messageNotify);
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200160
161 void run();
162
163private:
164 void runInference(EthosU::ethosu_core_inference_req &req, EthosU::ethosu_core_inference_rsp &rsp);
Kristofer Jonsson585ce692022-03-08 13:28:05 +0100165 bool getInferenceJob(const EthosU::ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job);
166
Davide Grohmannadc908c2022-02-16 13:13:27 +0100167#if defined(ETHOSU)
Kristofer Jonsson5410db12022-01-27 17:39:06 +0100168 friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg);
169 friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg);
Davide Grohmannadc908c2022-02-16 13:13:27 +0100170#endif
Davide Grohmann160001c2022-03-24 15:38:27 +0100171 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
Davide Grohmann134c39e2022-04-25 12:21:12 +0200172 QueueHandle_t inferenceOutputQueue;
173 SemaphoreHandle_t messageNotify;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200174 InferenceProcess::InferenceProcess inference;
Kristofer Jonsson5410db12022-01-27 17:39:06 +0100175 EthosU::ethosu_core_inference_req *currentReq;
176 EthosU::ethosu_core_inference_rsp *currentRsp;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200177};
178
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200179} // namespace MessageHandler
180
181#endif