blob: f7e83f528f1bed4a7bd55ef38455717346b735cb [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "Classifier.hpp"
20#include "InputFiles.hpp"
21#include "MobileNetModel.hpp"
22#include "UseCaseCommonUtils.hpp"
23#include "hal.h"
24
25using ImgClassClassifier = arm::app::Classifier;
26
27namespace arm {
28namespace app {
29
30 /**
31 * @brief Helper function to load the current image into the input
32 * tensor.
33 * @param[in] imIdx Image index (from the pool of images available
34 * to the application).
35 * @param[out] inputTensor Pointer to the input tensor to be populated.
36 * @return true if tensor is loaded, false otherwise.
37 **/
38 static bool _LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor);
39
40 /**
41 * @brief Helper function to increment current image index.
     * @param[in,out] ctx Reference to the application context object.
43 **/
44 static void _IncrementAppCtxImageIdx(ApplicationContext& ctx);
45
46 /**
47 * @brief Helper function to set the image index.
     * @param[in,out] ctx Reference to the application context object.
49 * @param[in] idx Value to be set.
50 * @return true if index is set, false otherwise.
51 **/
52 static bool _SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx);
53
54 /**
55 * @brief Presents inference results using the data presentation
56 * object.
57 * @param[in] platform Reference to the hal platform object.
58 * @param[in] results Vector of classification results to be displayed.
     * @note This function takes no inference-time argument; profiling
     *       results are printed separately by the caller.
61 * @return true if successful, false otherwise.
62 **/
63 static bool _PresentInferenceResult(hal_platform& platform,
64 const std::vector<ClassificationResult>& results);
65
66 /**
67 * @brief Helper function to convert a UINT8 image to INT8 format.
68 * @param[in,out] data Pointer to the data start.
     * @param[in] kMaxImageSize Number of bytes (UINT8 elements) to convert,
     *                          starting at the data pointer.
70 **/
71 static void ConvertImgToInt8(void* data, size_t kMaxImageSize);
72
73 /* Image inference classification handler. */
74 bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
75 {
76 auto& platform = ctx.Get<hal_platform&>("platform");
Isabella Gottardi8df12f32021-04-07 17:15:31 +010077 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +000078
79 constexpr uint32_t dataPsnImgDownscaleFactor = 2;
80 constexpr uint32_t dataPsnImgStartX = 10;
81 constexpr uint32_t dataPsnImgStartY = 35;
82
83 constexpr uint32_t dataPsnTxtInfStartX = 150;
84 constexpr uint32_t dataPsnTxtInfStartY = 40;
85
86 platform.data_psn->clear(COLOR_BLACK);
87
88 auto& model = ctx.Get<Model&>("model");
89
90 /* If the request has a valid size, set the image index. */
91 if (imgIndex < NUMBER_OF_FILES) {
92 if (!_SetAppCtxImageIdx(ctx, imgIndex)) {
93 return false;
94 }
95 }
96 if (!model.IsInited()) {
97 printf_err("Model is not initialised! Terminating processing.\n");
98 return false;
99 }
100
101 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
102
103 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
104 TfLiteTensor* inputTensor = model.GetInputTensor(0);
105
106 if (!inputTensor->dims) {
107 printf_err("Invalid input tensor dims\n");
108 return false;
109 } else if (inputTensor->dims->size < 3) {
110 printf_err("Input tensor dimension should be >= 3\n");
111 return false;
112 }
113
114 TfLiteIntArray* inputShape = model.GetInputShape(0);
115
116 const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx];
117 const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx];
118 const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
119
120 std::vector<ClassificationResult> results;
121
122 do {
123 /* Strings for presentation/logging. */
124 std::string str_inf{"Running inference... "};
125
126 /* Copy over the data. */
127 _LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
128
129 /* Display this image on the LCD. */
130 platform.data_psn->present_data_image(
131 (uint8_t*) inputTensor->data.data,
132 nCols, nRows, nChannels,
133 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
134
135 /* If the data is signed. */
136 if (model.IsDataSigned()) {
137 ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
138 }
139
140 /* Display message on the LCD - inference running. */
141 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
142 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
143
144 /* Run inference over this image. */
145 info("Running inference on image %u => %s\n", ctx.Get<uint32_t>("imgIndex"),
146 get_filename(ctx.Get<uint32_t>("imgIndex")));
147
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100148 RunInference(model, profiler);
alexander3c798932021-03-26 21:42:19 +0000149
150 /* Erase. */
151 str_inf = std::string(str_inf.size(), ' ');
152 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
153 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
154
155 auto& classifier = ctx.Get<ImgClassClassifier&>("classifier");
156 classifier.GetClassificationResults(outputTensor, results,
157 ctx.Get<std::vector <std::string>&>("labels"),
158 5);
159
160 /* Add results to context for access outside handler. */
161 ctx.Set<std::vector<ClassificationResult>>("results", results);
162
163#if VERIFY_TEST_OUTPUT
164 arm::app::DumpTensor(outputTensor);
165#endif /* VERIFY_TEST_OUTPUT */
166
167 if (!_PresentInferenceResult(platform, results)) {
168 return false;
169 }
170
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100171 profiler.PrintProfilingResult();
172
alexander3c798932021-03-26 21:42:19 +0000173 _IncrementAppCtxImageIdx(ctx);
174
175 } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
176
177 return true;
178 }
179
180 static bool _LoadImageIntoTensor(const uint32_t imIdx, TfLiteTensor* inputTensor)
181 {
182 const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
183 inputTensor->bytes : IMAGE_DATA_SIZE;
184 const uint8_t* imgSrc = get_img_array(imIdx);
185 if (nullptr == imgSrc) {
186 printf_err("Failed to get image index %u (max: %u)\n", imIdx,
187 NUMBER_OF_FILES - 1);
188 return false;
189 }
190
191 memcpy(inputTensor->data.data, imgSrc, copySz);
192 debug("Image %u loaded\n", imIdx);
193 return true;
194 }
195
196 static void _IncrementAppCtxImageIdx(ApplicationContext& ctx)
197 {
198 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
199
200 if (curImIdx + 1 >= NUMBER_OF_FILES) {
201 ctx.Set<uint32_t>("imgIndex", 0);
202 return;
203 }
204 ++curImIdx;
205 ctx.Set<uint32_t>("imgIndex", curImIdx);
206 }
207
208 static bool _SetAppCtxImageIdx(ApplicationContext& ctx, const uint32_t idx)
209 {
210 if (idx >= NUMBER_OF_FILES) {
211 printf_err("Invalid idx %u (expected less than %u)\n",
212 idx, NUMBER_OF_FILES);
213 return false;
214 }
215 ctx.Set<uint32_t>("imgIndex", idx);
216 return true;
217 }
218
219 static bool _PresentInferenceResult(hal_platform& platform,
220 const std::vector<ClassificationResult>& results)
221 {
222 constexpr uint32_t dataPsnTxtStartX1 = 150;
223 constexpr uint32_t dataPsnTxtStartY1 = 30;
224
225 constexpr uint32_t dataPsnTxtStartX2 = 10;
226 constexpr uint32_t dataPsnTxtStartY2 = 150;
227
228 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
229
230 platform.data_psn->set_text_color(COLOR_GREEN);
231
232 /* Display each result. */
233 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
234 uint32_t rowIdx2 = dataPsnTxtStartY2;
235
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100236 info("Final results:\n");
237 info("Total number of inferences: 1\n");
alexander3c798932021-03-26 21:42:19 +0000238 for (uint32_t i = 0; i < results.size(); ++i) {
239 std::string resultStr =
240 std::to_string(i + 1) + ") " +
241 std::to_string(results[i].m_labelIdx) +
242 " (" + std::to_string(results[i].m_normalisedVal) + ")";
243
244 platform.data_psn->present_data_text(
245 resultStr.c_str(), resultStr.size(),
246 dataPsnTxtStartX1, rowIdx1, 0);
247 rowIdx1 += dataPsnTxtYIncr;
248
249 resultStr = std::to_string(i + 1) + ") " + results[i].m_label;
250 platform.data_psn->present_data_text(
251 resultStr.c_str(), resultStr.size(),
252 dataPsnTxtStartX2, rowIdx2, 0);
253 rowIdx2 += dataPsnTxtYIncr;
254
255 info("%u) %u (%f) -> %s\n", i, results[i].m_labelIdx,
256 results[i].m_normalisedVal, results[i].m_label.c_str());
257 }
258
259 return true;
260 }
261
262 static void ConvertImgToInt8(void* data, const size_t kMaxImageSize)
263 {
264 auto* tmp_req_data = (uint8_t*) data;
265 auto* tmp_signed_req_data = (int8_t*) data;
266
267 for (size_t i = 0; i < kMaxImageSize; i++) {
268 tmp_signed_req_data[i] = (int8_t) (
269 (int32_t) (tmp_req_data[i]) - 128);
270 }
271 }
272
273} /* namespace app */
274} /* namespace arm */