blob: 2820275ea5a202bb0d2f6982b37a372471a4a5fe [file] [log] [blame]
Kristofer Jonsson641c0912020-08-31 11:34:14 +02001/*
2 * Copyright (c) 2020 Arm Limited. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#include <message_process.hpp>
20
21#include <cstddef>
22#include <cstdio>
Kristofer Jonsson25480142020-09-03 12:35:21 +020023#include <cstring>
Kristofer Jonsson641c0912020-08-31 11:34:14 +020024
25namespace MessageProcess {
26
27QueueImpl::QueueImpl(ethosu_core_queue &queue) : queue(queue) {}
28
29bool QueueImpl::empty() const {
30 return queue.header.read == queue.header.write;
31}
32
33size_t QueueImpl::available() const {
34 size_t avail = queue.header.write - queue.header.read;
35
36 if (queue.header.read > queue.header.write) {
37 avail += queue.header.size;
38 }
39
40 return avail;
41}
42
43size_t QueueImpl::capacity() const {
44 return queue.header.size - available();
45}
46
47bool QueueImpl::read(uint8_t *dst, uint32_t length) {
48 const uint8_t *end = dst + length;
49 uint32_t rpos = queue.header.read;
50
51 if (length > available()) {
52 return false;
53 }
54
55 while (dst < end) {
56 *dst++ = queue.data[rpos];
57 rpos = (rpos + 1) % queue.header.size;
58 }
59
60 queue.header.read = rpos;
61
62#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
63 SCB_CleanDCache();
64#endif
65
66 return true;
67}
68
69bool QueueImpl::write(const Vec *vec, size_t length) {
70 size_t total = 0;
71
72 for (size_t i = 0; i < length; i++) {
73 total += vec[i].length;
74 }
75
76 if (total > capacity()) {
77 return false;
78 }
79
80 uint32_t wpos = queue.header.write;
81
82 for (size_t i = 0; i < length; i++) {
83 const uint8_t *src = reinterpret_cast<const uint8_t *>(vec[i].base);
84 const uint8_t *end = src + vec[i].length;
85
86 while (src < end) {
87 queue.data[wpos] = *src++;
88 wpos = (wpos + 1) % queue.header.size;
89 }
90 }
91
92 // Update the write position last
93 queue.header.write = wpos;
94
95#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
96 SCB_CleanDCache();
97#endif
98
99 // TODO replace with mailbox driver APIs
100 volatile uint32_t *set = reinterpret_cast<volatile uint32_t *>(0x41A00014);
101 *set = 0x1;
102
103 return true;
104}
105
106bool QueueImpl::write(const uint32_t type, const void *src, uint32_t length) {
107 ethosu_core_msg msg = {type, length};
108 Vec vec[2] = {{&msg, sizeof(msg)}, {src, length}};
109
110 return write(vec, 2);
111}
112
113MessageProcess::MessageProcess(ethosu_core_queue &in,
114 ethosu_core_queue &out,
115 InferenceProcess::InferenceProcess &inferenceProcess) :
116 queueIn(in),
117 queueOut(out), inferenceProcess(inferenceProcess) {}
118
119void MessageProcess::run() {
120 while (true) {
121 // Handle all messages in queue
122 while (handleMessage())
123 ;
124
125 // Wait for event
126 __WFE();
127 }
128}
129
130void MessageProcess::handleIrq() {
131 __SEV();
132}
133
134bool MessageProcess::handleMessage() {
135 ethosu_core_msg msg;
136 union {
137 ethosu_core_inference_req inferenceReq;
138 uint8_t data[1000];
139 } data;
140
141#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
142 SCB_InvalidateDCache();
143#endif
144
145 // Read msg header
146 if (!queueIn.read(msg)) {
147 return false;
148 }
149
150 printf("Message. type=%u, length=%u\n", msg.type, msg.length);
151
152 // Read payload
153 if (!queueIn.read(data.data, msg.length)) {
154 printf("Failed to read payload.\n");
155 return false;
156 }
157
158 switch (msg.type) {
159 case ETHOSU_CORE_MSG_PING:
160 printf("Ping\n");
161 sendPong();
162 break;
163 case ETHOSU_CORE_MSG_INFERENCE_REQ: {
164 std::memcpy(&data.inferenceReq, data.data, sizeof(data.data));
165
166 ethosu_core_inference_req &req = data.inferenceReq;
167
168 printf("InferenceReq. network={0x%x, %u}, ifm={0x%x, %u}, ofm={0x%x, %u}\n",
169 req.network.ptr,
170 req.network.size,
171 req.ifm.ptr,
172 req.ifm.size,
173 req.ofm.ptr,
174 req.ofm.size,
175 req.user_arg);
176
177 InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size);
178 InferenceProcess::DataPtr ifm(reinterpret_cast<void *>(req.ifm.ptr), req.ifm.size);
179 InferenceProcess::DataPtr ofm(reinterpret_cast<void *>(req.ofm.ptr), req.ofm.size);
180 InferenceProcess::DataPtr expectedOutput;
181 InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1);
182
183 bool failed = inferenceProcess.runJob(job);
184
185 sendInferenceRsp(data.inferenceReq.user_arg, job.output.size, failed);
186 break;
187 }
188 default:
189 break;
190 }
191
192 return true;
193}
194
195void MessageProcess::sendPong() {
196 if (!queueOut.write(ETHOSU_CORE_MSG_PONG)) {
197 printf("Failed to write pong.\n");
198 }
199}
200
201void MessageProcess::sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed) {
202 ethosu_core_inference_rsp rsp;
203
204 rsp.user_arg = userArg;
205 rsp.ofm_size = ofmSize;
206 rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
207
208 printf(
209 "Sending inference response. userArg=0x%llx, ofm_size=%u, status=%u\n", rsp.user_arg, rsp.ofm_size, rsp.status);
210
211 if (!queueOut.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) {
212 printf("Failed to write inference.\n");
213 }
214}
215} // namespace MessageProcess