blob: e61c6488a405d1e7f642071a5a7bc58dc7bce84d [file] [log] [blame]
Kristofer Jonsson641c0912020-08-31 11:34:14 +02001/*
2 * Copyright (c) 2020 Arm Limited. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include <message_process.hpp>

#include <cstddef>
#include <cstdio>
#include <cstring>
23
24namespace MessageProcess {
25
26QueueImpl::QueueImpl(ethosu_core_queue &queue) : queue(queue) {}
27
28bool QueueImpl::empty() const {
29 return queue.header.read == queue.header.write;
30}
31
32size_t QueueImpl::available() const {
33 size_t avail = queue.header.write - queue.header.read;
34
35 if (queue.header.read > queue.header.write) {
36 avail += queue.header.size;
37 }
38
39 return avail;
40}
41
42size_t QueueImpl::capacity() const {
43 return queue.header.size - available();
44}
45
46bool QueueImpl::read(uint8_t *dst, uint32_t length) {
47 const uint8_t *end = dst + length;
48 uint32_t rpos = queue.header.read;
49
50 if (length > available()) {
51 return false;
52 }
53
54 while (dst < end) {
55 *dst++ = queue.data[rpos];
56 rpos = (rpos + 1) % queue.header.size;
57 }
58
59 queue.header.read = rpos;
60
61#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
62 SCB_CleanDCache();
63#endif
64
65 return true;
66}
67
68bool QueueImpl::write(const Vec *vec, size_t length) {
69 size_t total = 0;
70
71 for (size_t i = 0; i < length; i++) {
72 total += vec[i].length;
73 }
74
75 if (total > capacity()) {
76 return false;
77 }
78
79 uint32_t wpos = queue.header.write;
80
81 for (size_t i = 0; i < length; i++) {
82 const uint8_t *src = reinterpret_cast<const uint8_t *>(vec[i].base);
83 const uint8_t *end = src + vec[i].length;
84
85 while (src < end) {
86 queue.data[wpos] = *src++;
87 wpos = (wpos + 1) % queue.header.size;
88 }
89 }
90
91 // Update the write position last
92 queue.header.write = wpos;
93
94#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
95 SCB_CleanDCache();
96#endif
97
98 // TODO replace with mailbox driver APIs
99 volatile uint32_t *set = reinterpret_cast<volatile uint32_t *>(0x41A00014);
100 *set = 0x1;
101
102 return true;
103}
104
105bool QueueImpl::write(const uint32_t type, const void *src, uint32_t length) {
106 ethosu_core_msg msg = {type, length};
107 Vec vec[2] = {{&msg, sizeof(msg)}, {src, length}};
108
109 return write(vec, 2);
110}
111
112MessageProcess::MessageProcess(ethosu_core_queue &in,
113 ethosu_core_queue &out,
114 InferenceProcess::InferenceProcess &inferenceProcess) :
115 queueIn(in),
116 queueOut(out), inferenceProcess(inferenceProcess) {}
117
118void MessageProcess::run() {
119 while (true) {
120 // Handle all messages in queue
121 while (handleMessage())
122 ;
123
124 // Wait for event
125 __WFE();
126 }
127}
128
129void MessageProcess::handleIrq() {
130 __SEV();
131}
132
133bool MessageProcess::handleMessage() {
134 ethosu_core_msg msg;
135 union {
136 ethosu_core_inference_req inferenceReq;
137 uint8_t data[1000];
138 } data;
139
140#if defined(__DCACHE_PRESENT) && (__DCACHE_PRESENT == 1U)
141 SCB_InvalidateDCache();
142#endif
143
144 // Read msg header
145 if (!queueIn.read(msg)) {
146 return false;
147 }
148
149 printf("Message. type=%u, length=%u\n", msg.type, msg.length);
150
151 // Read payload
152 if (!queueIn.read(data.data, msg.length)) {
153 printf("Failed to read payload.\n");
154 return false;
155 }
156
157 switch (msg.type) {
158 case ETHOSU_CORE_MSG_PING:
159 printf("Ping\n");
160 sendPong();
161 break;
162 case ETHOSU_CORE_MSG_INFERENCE_REQ: {
163 std::memcpy(&data.inferenceReq, data.data, sizeof(data.data));
164
165 ethosu_core_inference_req &req = data.inferenceReq;
166
167 printf("InferenceReq. network={0x%x, %u}, ifm={0x%x, %u}, ofm={0x%x, %u}\n",
168 req.network.ptr,
169 req.network.size,
170 req.ifm.ptr,
171 req.ifm.size,
172 req.ofm.ptr,
173 req.ofm.size,
174 req.user_arg);
175
176 InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size);
177 InferenceProcess::DataPtr ifm(reinterpret_cast<void *>(req.ifm.ptr), req.ifm.size);
178 InferenceProcess::DataPtr ofm(reinterpret_cast<void *>(req.ofm.ptr), req.ofm.size);
179 InferenceProcess::DataPtr expectedOutput;
180 InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1);
181
182 bool failed = inferenceProcess.runJob(job);
183
184 sendInferenceRsp(data.inferenceReq.user_arg, job.output.size, failed);
185 break;
186 }
187 default:
188 break;
189 }
190
191 return true;
192}
193
194void MessageProcess::sendPong() {
195 if (!queueOut.write(ETHOSU_CORE_MSG_PONG)) {
196 printf("Failed to write pong.\n");
197 }
198}
199
200void MessageProcess::sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed) {
201 ethosu_core_inference_rsp rsp;
202
203 rsp.user_arg = userArg;
204 rsp.ofm_size = ofmSize;
205 rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
206
207 printf(
208 "Sending inference response. userArg=0x%llx, ofm_size=%u, status=%u\n", rsp.user_arg, rsp.ofm_size, rsp.status);
209
210 if (!queueOut.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) {
211 printf("Failed to write inference.\n");
212 }
213}
214} // namespace MessageProcess