blob: 2f9d8ec1e175aa7f62053a8853f56db66d339212 [file] [log] [blame]
Kristofer Jonsson3f5510f2023-02-08 14:23:00 +01001/*
2 * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19/*****************************************************************************
20 * Includes
21 *****************************************************************************/
22
23#include "inference_runner.hpp"
24
25#include <cstdlib>
26
27#include <ethosu_log.h>
28
29#if defined(ETHOSU)
30#include <ethosu_driver.h>
31#include <pmu_ethosu.h>
32#endif
33
34/*****************************************************************************
35 * InferenceRunner
36 *****************************************************************************/
37
38InferenceRunner::InferenceRunner(uint8_t *tensorArena,
39 size_t arenaSize,
40 MessageHandler::InferenceQueue &_inferenceQueue,
41 MessageHandler::ResponseQueue &_responseQueue) :
42 inferenceQueue(_inferenceQueue),
43 responseQueue(_responseQueue), inference(tensorArena, arenaSize) {
44 BaseType_t ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, this, 4, &taskHandle);
45 if (ret != pdPASS) {
46 LOG_ERR("Failed to create inference task");
47 abort();
48 }
49}
50
51InferenceRunner::~InferenceRunner() {
52 vTaskDelete(taskHandle);
53}
54
55void InferenceRunner::inferenceTask(void *param) {
56 auto _this = static_cast<InferenceRunner *>(param);
57
58 LOG_DEBUG("Starting inference task");
59
60 while (true) {
61 Message *message;
62 auto ret = _this->inferenceQueue.receive(message);
63 if (ret) {
64 abort();
65 }
66
67 auto &rpmsg = message->rpmsg;
68
69 switch (rpmsg.header.type) {
70 case EthosU::ETHOSU_CORE_MSG_INFERENCE_REQ: {
71 _this->handleInferenceRequest(message->src, rpmsg.header.msg_id, rpmsg.inf_req);
72 break;
73 }
74 default: {
75 LOG_WARN("Unsupported message for inference runner. type=%lu", rpmsg.header.type);
76 }
77 }
78
79 delete message;
80 }
81}
82
83void InferenceRunner::handleInferenceRequest(const uint32_t src,
84 const uint64_t msgId,
85 const EthosU::ethosu_core_msg_inference_req &request) {
86 auto message =
87 new Message(src, EthosU::ETHOSU_CORE_MSG_INFERENCE_RSP, msgId, sizeof(EthosU::ethosu_core_msg_inference_rsp));
88 auto &response = message->rpmsg.inf_rsp;
89
90 // Setup PMU configuration
91 response.pmu_cycle_counter_enable = request.pmu_cycle_counter_enable;
92
93 for (int i = 0; i < ETHOSU_CORE_PMU_MAX; i++) {
94 response.pmu_event_config[i] = request.pmu_event_config[i];
95 }
96
97 // Run inference
98 auto job = makeInferenceJob(request, response);
99 auto failed = inference.runJob(job);
100
101 // Send response rpmsg
102 response.ofm_count = job.output.size();
103 response.status = failed ? EthosU::ETHOSU_CORE_STATUS_ERROR : EthosU::ETHOSU_CORE_STATUS_OK;
104
105 for (size_t i = 0; i < job.output.size(); ++i) {
106 response.ofm_size[i] = job.output[i].size;
107 }
108
109 responseQueue.send(message);
110}
111
112InferenceProcess::InferenceJob InferenceRunner::makeInferenceJob(const EthosU::ethosu_core_msg_inference_req &request,
113 EthosU::ethosu_core_msg_inference_rsp &response) {
114 InferenceProcess::InferenceJob job;
115
116 job.networkModel =
117 InferenceProcess::DataPtr(reinterpret_cast<void *>(request.network.buffer.ptr), request.network.buffer.size);
118
119 for (uint32_t i = 0; i < request.ifm_count; ++i) {
120 job.input.push_back(
121 InferenceProcess::DataPtr(reinterpret_cast<void *>(request.ifm[i].ptr), request.ifm[i].size));
122 }
123
124 for (uint32_t i = 0; i < request.ofm_count; ++i) {
125 job.output.push_back(
126 InferenceProcess::DataPtr(reinterpret_cast<void *>(request.ofm[i].ptr), request.ofm[i].size));
127 }
128
129 job.externalContext = &response;
130
131 return job;
132}
133
134#if defined(ETHOSU)
135extern "C" {
136
137void ethosu_inference_begin(ethosu_driver *drv, void *userArg) {
138 LOG_DEBUG("");
139
140 auto response = static_cast<EthosU::ethosu_core_msg_inference_rsp *>(userArg);
141
142 // Calculate maximum number of events
143 const int numEvents = std::min(static_cast<int>(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX);
144
145 // Enable PMU
146 ETHOSU_PMU_Enable(drv);
147
148 // Configure and enable events
149 for (int i = 0; i < numEvents; i++) {
150 ETHOSU_PMU_Set_EVTYPER(drv, i, static_cast<ethosu_pmu_event_type>(response->pmu_event_config[i]));
151 ETHOSU_PMU_CNTR_Enable(drv, 1 << i);
152 }
153
154 // Enable cycle counter
155 if (response->pmu_cycle_counter_enable) {
156 ETHOSU_PMU_PMCCNTR_CFG_Set_Stop_Event(drv, ETHOSU_PMU_NPU_IDLE);
157 ETHOSU_PMU_PMCCNTR_CFG_Set_Start_Event(drv, ETHOSU_PMU_NPU_ACTIVE);
158
159 ETHOSU_PMU_CNTR_Enable(drv, ETHOSU_PMU_CCNT_Msk);
160 ETHOSU_PMU_CYCCNT_Reset(drv);
161 }
162
163 // Reset all counters
164 ETHOSU_PMU_EVCNTR_ALL_Reset(drv);
165}
166
167void ethosu_inference_end(ethosu_driver *drv, void *userArg) {
168 auto response = static_cast<EthosU::ethosu_core_msg_inference_rsp *>(userArg);
169
170 // Get cycle counter
171 if (response->pmu_cycle_counter_enable) {
172 response->pmu_cycle_counter_count = ETHOSU_PMU_Get_CCNTR(drv);
173 }
174
175 // Calculate maximum number of events
176 const int numEvents = std::min(static_cast<int>(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX);
177
178 // Get event counters
179 int i;
180 for (i = 0; i < numEvents; i++) {
181 response->pmu_event_count[i] = ETHOSU_PMU_Get_EVCNTR(drv, i);
182 }
183
184 for (; i < ETHOSU_CORE_PMU_MAX; i++) {
185 response->pmu_event_config[i] = 0;
186 response->pmu_event_count[i] = 0;
187 }
188
189 // Disable PMU
190 ETHOSU_PMU_Disable(drv);
191}
192}
193
194#endif