blob: 3c227be00caab8cc3e043610acadc5cd5e5bfb6a [file] [log] [blame]
/*
 * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18
19#ifndef MESSAGE_HANDLER_H
20#define MESSAGE_HANDLER_H
21
#include "FreeRTOS.h"
#include "queue.h"
#include "semphr.h"

#include "message_queue.hpp"
#include "networks.hpp"
#include <ethosu_core_interface.h>
#if defined(ETHOSU)
#include <ethosu_driver.h>
#endif
#include <inference_parser.hpp>
#include <inference_process.hpp>
#include <mailbox.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <inttypes.h>
#include <list>
#include <memory>
#include <vector>
42
43namespace MessageHandler {
44
Davide Grohmann160001c2022-03-24 15:38:27 +010045template <typename T, size_t capacity = 10>
46class Queue {
47public:
48 using Predicate = std::function<bool(const T &data)>;
49
50 Queue() {
51 mutex = xSemaphoreCreateMutex();
52 size = xSemaphoreCreateCounting(capacity, 0u);
53
54 if (mutex == nullptr || size == nullptr) {
55 printf("Error: failed to allocate memory for inference queue\n");
56 }
57 }
58
59 ~Queue() {
60 vSemaphoreDelete(mutex);
61 vSemaphoreDelete(size);
62 }
63
64 bool push(const T &data) {
65 xSemaphoreTake(mutex, portMAX_DELAY);
66 if (list.size() >= capacity) {
67 xSemaphoreGive(mutex);
68 return false;
69 }
70
71 list.push_back(data);
72 xSemaphoreGive(mutex);
73
74 // increase number of available inferences to pop
75 xSemaphoreGive(size);
76 return true;
77 }
78
79 void pop(T &data) {
80 // decrease the number of available inferences to pop
81 xSemaphoreTake(size, portMAX_DELAY);
82
83 xSemaphoreTake(mutex, portMAX_DELAY);
84 data = list.front();
85 list.pop_front();
86 xSemaphoreGive(mutex);
87 }
88
89 bool erase(Predicate pred) {
90 // let's optimistically assume we are removing an inference, so decrease pop
91 if (pdFALSE == xSemaphoreTake(size, 0)) {
92 // if there are no inferences return immediately
93 return false;
94 }
95
96 xSemaphoreTake(mutex, portMAX_DELAY);
97 auto found = std::find_if(list.begin(), list.end(), pred);
98 bool erased = found != list.end();
99 if (erased) {
100 list.erase(found);
101 }
102 xSemaphoreGive(mutex);
103
104 if (!erased) {
105 // no inference erased, so let's put the size count back
106 xSemaphoreGive(size);
107 }
108
109 return erased;
110 }
111
112private:
113 std::list<T> list;
114
115 SemaphoreHandle_t mutex;
116 SemaphoreHandle_t size;
117};
118
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200119class IncomingMessageHandler {
120public:
Davide Grohmann134c39e2022-04-25 12:21:12 +0200121 IncomingMessageHandler(EthosU::ethosu_core_queue &inputMessageQueue,
122 EthosU::ethosu_core_queue &outputMessageQueue,
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200123 Mailbox::Mailbox &mailbox,
Davide Grohmann160001c2022-03-24 15:38:27 +0100124 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
Davide Grohmann134c39e2022-04-25 12:21:12 +0200125 QueueHandle_t inferenceOutputQueue,
Davide Grohmann144b2d22022-05-31 15:24:02 +0200126 SemaphoreHandle_t messageNotify,
127 std::shared_ptr<Networks> networks);
Davide Grohmann160001c2022-03-24 15:38:27 +0100128
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200129 void run();
130
131private:
132 bool handleMessage();
Davide Grohmann134c39e2022-04-25 12:21:12 +0200133 bool handleInferenceOutput();
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200134 static void handleIrq(void *userArg);
135
Davide Grohmann134c39e2022-04-25 12:21:12 +0200136 void sendPong();
137 void sendErrorAndResetQueue(EthosU::ethosu_core_msg_err_type type, const char *message);
138 void sendVersionRsp();
139 void sendCapabilitiesRsp(uint64_t userArg);
140 void sendNetworkInfoRsp(uint64_t userArg, EthosU::ethosu_core_network_buffer &network);
141 void sendInferenceRsp(EthosU::ethosu_core_inference_rsp &inference);
142 void sendFailedInferenceRsp(uint64_t userArg, uint32_t status);
Davide Grohmann00de9ee2022-03-23 14:59:56 +0100143 void sendCancelInferenceRsp(uint64_t userArg, uint32_t status);
Davide Grohmann134c39e2022-04-25 12:21:12 +0200144 void readCapabilties(EthosU::ethosu_core_msg_capabilities_rsp &rsp);
145
146 MessageQueue::QueueImpl inputMessageQueue;
147 MessageQueue::QueueImpl outputMessageQueue;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200148 Mailbox::Mailbox &mailbox;
Davide Grohmann65520052022-04-07 15:01:34 +0200149 InferenceProcess::InferenceParser parser;
Davide Grohmann160001c2022-03-24 15:38:27 +0100150 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
Davide Grohmann134c39e2022-04-25 12:21:12 +0200151 QueueHandle_t inferenceOutputQueue;
152 SemaphoreHandle_t messageNotify;
153 EthosU::ethosu_core_msg_capabilities_rsp capabilities;
Davide Grohmann144b2d22022-05-31 15:24:02 +0200154 std::shared_ptr<Networks> networks;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200155};
156
157class InferenceHandler {
158public:
Davide Grohmann134c39e2022-04-25 12:21:12 +0200159 InferenceHandler(uint8_t *tensorArena,
160 size_t arenaSize,
Davide Grohmann160001c2022-03-24 15:38:27 +0100161 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
Davide Grohmann134c39e2022-04-25 12:21:12 +0200162 QueueHandle_t inferenceOutputQueue,
Davide Grohmann144b2d22022-05-31 15:24:02 +0200163 SemaphoreHandle_t messageNotify,
164 std::shared_ptr<Networks> networks);
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200165
166 void run();
167
168private:
169 void runInference(EthosU::ethosu_core_inference_req &req, EthosU::ethosu_core_inference_rsp &rsp);
Kristofer Jonsson585ce692022-03-08 13:28:05 +0100170 bool getInferenceJob(const EthosU::ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job);
171
Davide Grohmannadc908c2022-02-16 13:13:27 +0100172#if defined(ETHOSU)
Kristofer Jonsson5410db12022-01-27 17:39:06 +0100173 friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg);
174 friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg);
Davide Grohmannadc908c2022-02-16 13:13:27 +0100175#endif
Davide Grohmann160001c2022-03-24 15:38:27 +0100176 std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
Davide Grohmann134c39e2022-04-25 12:21:12 +0200177 QueueHandle_t inferenceOutputQueue;
178 SemaphoreHandle_t messageNotify;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200179 InferenceProcess::InferenceProcess inference;
Kristofer Jonsson5410db12022-01-27 17:39:06 +0100180 EthosU::ethosu_core_inference_req *currentReq;
181 EthosU::ethosu_core_inference_rsp *currentRsp;
Davide Grohmann144b2d22022-05-31 15:24:02 +0200182 std::shared_ptr<Networks> networks;
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200183};
184
Yulia Garbovichf61ea352021-11-11 14:16:57 +0200185} // namespace MessageHandler
186
187#endif