Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 1 | /* |
Kristofer Jonsson | ac535f0 | 2022-03-10 11:08:39 +0100 | [diff] [blame] | 2 | * Copyright (c) 2020-2022 Arm Limited. |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: Apache-2.0 |
| 5 | * |
| 6 | * Licensed under the Apache License, Version 2.0 (the License); you may |
| 7 | * not use this file except in compliance with the License. |
| 8 | * You may obtain a copy of the License at |
| 9 | * |
| 10 | * www.apache.org/licenses/LICENSE-2.0 |
| 11 | * |
| 12 | * Unless required by applicable law or agreed to in writing, software |
| 13 | * distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 14 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | * See the License for the specific language governing permissions and |
| 16 | * limitations under the License. |
| 17 | */ |
| 18 | |
| 19 | #ifndef MESSAGE_HANDLER_H |
| 20 | #define MESSAGE_HANDLER_H |
| 21 | |
#include "FreeRTOS.h"
#include "queue.h"
#include "semphr.h"

#include "message_queue.hpp"
#include "networks.hpp"
#include <ethosu_core_interface.h>
#if defined(ETHOSU)
#include <ethosu_driver.h>
#endif
#include <inference_parser.hpp>
#include <inference_process.hpp>
#include <mailbox.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <inttypes.h>
#include <list>
#include <memory>
#include <vector>
| 42 | |
| 43 | namespace MessageHandler { |
| 44 | |
Davide Grohmann | 160001c | 2022-03-24 15:38:27 +0100 | [diff] [blame] | 45 | template <typename T, size_t capacity = 10> |
| 46 | class Queue { |
| 47 | public: |
| 48 | using Predicate = std::function<bool(const T &data)>; |
| 49 | |
| 50 | Queue() { |
| 51 | mutex = xSemaphoreCreateMutex(); |
| 52 | size = xSemaphoreCreateCounting(capacity, 0u); |
| 53 | |
| 54 | if (mutex == nullptr || size == nullptr) { |
| 55 | printf("Error: failed to allocate memory for inference queue\n"); |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | ~Queue() { |
| 60 | vSemaphoreDelete(mutex); |
| 61 | vSemaphoreDelete(size); |
| 62 | } |
| 63 | |
| 64 | bool push(const T &data) { |
| 65 | xSemaphoreTake(mutex, portMAX_DELAY); |
| 66 | if (list.size() >= capacity) { |
| 67 | xSemaphoreGive(mutex); |
| 68 | return false; |
| 69 | } |
| 70 | |
| 71 | list.push_back(data); |
| 72 | xSemaphoreGive(mutex); |
| 73 | |
| 74 | // increase number of available inferences to pop |
| 75 | xSemaphoreGive(size); |
| 76 | return true; |
| 77 | } |
| 78 | |
| 79 | void pop(T &data) { |
| 80 | // decrease the number of available inferences to pop |
| 81 | xSemaphoreTake(size, portMAX_DELAY); |
| 82 | |
| 83 | xSemaphoreTake(mutex, portMAX_DELAY); |
| 84 | data = list.front(); |
| 85 | list.pop_front(); |
| 86 | xSemaphoreGive(mutex); |
| 87 | } |
| 88 | |
| 89 | bool erase(Predicate pred) { |
| 90 | // let's optimistically assume we are removing an inference, so decrease pop |
| 91 | if (pdFALSE == xSemaphoreTake(size, 0)) { |
| 92 | // if there are no inferences return immediately |
| 93 | return false; |
| 94 | } |
| 95 | |
| 96 | xSemaphoreTake(mutex, portMAX_DELAY); |
| 97 | auto found = std::find_if(list.begin(), list.end(), pred); |
| 98 | bool erased = found != list.end(); |
| 99 | if (erased) { |
| 100 | list.erase(found); |
| 101 | } |
| 102 | xSemaphoreGive(mutex); |
| 103 | |
| 104 | if (!erased) { |
| 105 | // no inference erased, so let's put the size count back |
| 106 | xSemaphoreGive(size); |
| 107 | } |
| 108 | |
| 109 | return erased; |
| 110 | } |
| 111 | |
| 112 | private: |
| 113 | std::list<T> list; |
| 114 | |
| 115 | SemaphoreHandle_t mutex; |
| 116 | SemaphoreHandle_t size; |
| 117 | }; |
| 118 | |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 119 | class IncomingMessageHandler { |
| 120 | public: |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 121 | IncomingMessageHandler(EthosU::ethosu_core_queue &inputMessageQueue, |
| 122 | EthosU::ethosu_core_queue &outputMessageQueue, |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 123 | Mailbox::Mailbox &mailbox, |
Davide Grohmann | 160001c | 2022-03-24 15:38:27 +0100 | [diff] [blame] | 124 | std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue, |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 125 | QueueHandle_t inferenceOutputQueue, |
Davide Grohmann | 144b2d2 | 2022-05-31 15:24:02 +0200 | [diff] [blame^] | 126 | SemaphoreHandle_t messageNotify, |
| 127 | std::shared_ptr<Networks> networks); |
Davide Grohmann | 160001c | 2022-03-24 15:38:27 +0100 | [diff] [blame] | 128 | |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 129 | void run(); |
| 130 | |
| 131 | private: |
| 132 | bool handleMessage(); |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 133 | bool handleInferenceOutput(); |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 134 | static void handleIrq(void *userArg); |
| 135 | |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 136 | void sendPong(); |
| 137 | void sendErrorAndResetQueue(EthosU::ethosu_core_msg_err_type type, const char *message); |
| 138 | void sendVersionRsp(); |
| 139 | void sendCapabilitiesRsp(uint64_t userArg); |
| 140 | void sendNetworkInfoRsp(uint64_t userArg, EthosU::ethosu_core_network_buffer &network); |
| 141 | void sendInferenceRsp(EthosU::ethosu_core_inference_rsp &inference); |
| 142 | void sendFailedInferenceRsp(uint64_t userArg, uint32_t status); |
Davide Grohmann | 00de9ee | 2022-03-23 14:59:56 +0100 | [diff] [blame] | 143 | void sendCancelInferenceRsp(uint64_t userArg, uint32_t status); |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 144 | void readCapabilties(EthosU::ethosu_core_msg_capabilities_rsp &rsp); |
| 145 | |
| 146 | MessageQueue::QueueImpl inputMessageQueue; |
| 147 | MessageQueue::QueueImpl outputMessageQueue; |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 148 | Mailbox::Mailbox &mailbox; |
Davide Grohmann | 6552005 | 2022-04-07 15:01:34 +0200 | [diff] [blame] | 149 | InferenceProcess::InferenceParser parser; |
Davide Grohmann | 160001c | 2022-03-24 15:38:27 +0100 | [diff] [blame] | 150 | std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue; |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 151 | QueueHandle_t inferenceOutputQueue; |
| 152 | SemaphoreHandle_t messageNotify; |
| 153 | EthosU::ethosu_core_msg_capabilities_rsp capabilities; |
Davide Grohmann | 144b2d2 | 2022-05-31 15:24:02 +0200 | [diff] [blame^] | 154 | std::shared_ptr<Networks> networks; |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 155 | }; |
| 156 | |
| 157 | class InferenceHandler { |
| 158 | public: |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 159 | InferenceHandler(uint8_t *tensorArena, |
| 160 | size_t arenaSize, |
Davide Grohmann | 160001c | 2022-03-24 15:38:27 +0100 | [diff] [blame] | 161 | std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue, |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 162 | QueueHandle_t inferenceOutputQueue, |
Davide Grohmann | 144b2d2 | 2022-05-31 15:24:02 +0200 | [diff] [blame^] | 163 | SemaphoreHandle_t messageNotify, |
| 164 | std::shared_ptr<Networks> networks); |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 165 | |
| 166 | void run(); |
| 167 | |
| 168 | private: |
| 169 | void runInference(EthosU::ethosu_core_inference_req &req, EthosU::ethosu_core_inference_rsp &rsp); |
Kristofer Jonsson | 585ce69 | 2022-03-08 13:28:05 +0100 | [diff] [blame] | 170 | bool getInferenceJob(const EthosU::ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job); |
| 171 | |
Davide Grohmann | adc908c | 2022-02-16 13:13:27 +0100 | [diff] [blame] | 172 | #if defined(ETHOSU) |
Kristofer Jonsson | 5410db1 | 2022-01-27 17:39:06 +0100 | [diff] [blame] | 173 | friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg); |
| 174 | friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg); |
Davide Grohmann | adc908c | 2022-02-16 13:13:27 +0100 | [diff] [blame] | 175 | #endif |
Davide Grohmann | 160001c | 2022-03-24 15:38:27 +0100 | [diff] [blame] | 176 | std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue; |
Davide Grohmann | 134c39e | 2022-04-25 12:21:12 +0200 | [diff] [blame] | 177 | QueueHandle_t inferenceOutputQueue; |
| 178 | SemaphoreHandle_t messageNotify; |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 179 | InferenceProcess::InferenceProcess inference; |
Kristofer Jonsson | 5410db1 | 2022-01-27 17:39:06 +0100 | [diff] [blame] | 180 | EthosU::ethosu_core_inference_req *currentReq; |
| 181 | EthosU::ethosu_core_inference_rsp *currentRsp; |
Davide Grohmann | 144b2d2 | 2022-05-31 15:24:02 +0200 | [diff] [blame^] | 182 | std::shared_ptr<Networks> networks; |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 183 | }; |
| 184 | |
Yulia Garbovich | f61ea35 | 2021-11-11 14:16:57 +0200 | [diff] [blame] | 185 | } // namespace MessageHandler |
| 186 | |
| 187 | #endif |