blob: fa77512be120eec8e1e159297746374647286848 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "Classifier.hpp"
20#include "InputFiles.hpp"
21#include "MobileNetModel.hpp"
22#include "UseCaseCommonUtils.hpp"
23#include "hal.h"
24
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +010025#include <inttypes.h>
26
alexander3c798932021-03-26 21:42:19 +000027using ImgClassClassifier = arm::app::Classifier;
28
29namespace arm {
30namespace app {
31
32 /**
33 * @brief Helper function to load the current image into the input
34 * tensor.
35 * @param[in] imIdx Image index (from the pool of images available
36 * to the application).
37 * @param[out] inputTensor Pointer to the input tensor to be populated.
38 * @return true if tensor is loaded, false otherwise.
39 **/
alexanderc350cdc2021-04-29 20:36:09 +010040 static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor);
alexander3c798932021-03-26 21:42:19 +000041
42 /**
43 * @brief Helper function to increment current image index.
44 * @param[in,out] ctx Pointer to the application context object.
45 **/
alexanderc350cdc2021-04-29 20:36:09 +010046 static void IncrementAppCtxImageIdx(ApplicationContext& ctx);
alexander3c798932021-03-26 21:42:19 +000047
48 /**
49 * @brief Helper function to set the image index.
50 * @param[in,out] ctx Pointer to the application context object.
51 * @param[in] idx Value to be set.
52 * @return true if index is set, false otherwise.
53 **/
alexanderc350cdc2021-04-29 20:36:09 +010054 static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx);
alexander3c798932021-03-26 21:42:19 +000055
56 /**
57 * @brief Presents inference results using the data presentation
58 * object.
59 * @param[in] platform Reference to the hal platform object.
60 * @param[in] results Vector of classification results to be displayed.
61 * @param[in] infTimeMs Inference time in milliseconds, if available
62 * otherwise, this can be passed in as 0.
63 * @return true if successful, false otherwise.
64 **/
alexanderc350cdc2021-04-29 20:36:09 +010065 static bool PresentInferenceResult(hal_platform& platform,
66 const std::vector<ClassificationResult>& results);
alexander3c798932021-03-26 21:42:19 +000067
68 /**
69 * @brief Helper function to convert a UINT8 image to INT8 format.
70 * @param[in,out] data Pointer to the data start.
71 * @param[in] kMaxImageSize Total number of pixels in the image.
72 **/
73 static void ConvertImgToInt8(void* data, size_t kMaxImageSize);
74
75 /* Image inference classification handler. */
76 bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
77 {
78 auto& platform = ctx.Get<hal_platform&>("platform");
Isabella Gottardi8df12f32021-04-07 17:15:31 +010079 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +000080
81 constexpr uint32_t dataPsnImgDownscaleFactor = 2;
82 constexpr uint32_t dataPsnImgStartX = 10;
83 constexpr uint32_t dataPsnImgStartY = 35;
84
85 constexpr uint32_t dataPsnTxtInfStartX = 150;
86 constexpr uint32_t dataPsnTxtInfStartY = 40;
87
88 platform.data_psn->clear(COLOR_BLACK);
89
90 auto& model = ctx.Get<Model&>("model");
91
92 /* If the request has a valid size, set the image index. */
93 if (imgIndex < NUMBER_OF_FILES) {
alexanderc350cdc2021-04-29 20:36:09 +010094 if (!SetAppCtxImageIdx(ctx, imgIndex)) {
alexander3c798932021-03-26 21:42:19 +000095 return false;
96 }
97 }
98 if (!model.IsInited()) {
99 printf_err("Model is not initialised! Terminating processing.\n");
100 return false;
101 }
102
103 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
104
105 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
106 TfLiteTensor* inputTensor = model.GetInputTensor(0);
107
108 if (!inputTensor->dims) {
109 printf_err("Invalid input tensor dims\n");
110 return false;
111 } else if (inputTensor->dims->size < 3) {
112 printf_err("Input tensor dimension should be >= 3\n");
113 return false;
114 }
115
116 TfLiteIntArray* inputShape = model.GetInputShape(0);
117
118 const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx];
119 const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx];
120 const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
121
122 std::vector<ClassificationResult> results;
123
124 do {
125 /* Strings for presentation/logging. */
126 std::string str_inf{"Running inference... "};
127
128 /* Copy over the data. */
alexanderc350cdc2021-04-29 20:36:09 +0100129 LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
alexander3c798932021-03-26 21:42:19 +0000130
131 /* Display this image on the LCD. */
132 platform.data_psn->present_data_image(
133 (uint8_t*) inputTensor->data.data,
134 nCols, nRows, nChannels,
135 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
136
137 /* If the data is signed. */
138 if (model.IsDataSigned()) {
139 ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
140 }
141
142 /* Display message on the LCD - inference running. */
143 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
144 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
145
146 /* Run inference over this image. */
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100147 info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
alexander3c798932021-03-26 21:42:19 +0000148 get_filename(ctx.Get<uint32_t>("imgIndex")));
149
alexander27b62d92021-05-04 20:46:08 +0100150 if (!RunInference(model, profiler)) {
151 return false;
152 }
alexander3c798932021-03-26 21:42:19 +0000153
154 /* Erase. */
155 str_inf = std::string(str_inf.size(), ' ');
156 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
157 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
158
159 auto& classifier = ctx.Get<ImgClassClassifier&>("classifier");
160 classifier.GetClassificationResults(outputTensor, results,
161 ctx.Get<std::vector <std::string>&>("labels"),
162 5);
163
164 /* Add results to context for access outside handler. */
165 ctx.Set<std::vector<ClassificationResult>>("results", results);
166
167#if VERIFY_TEST_OUTPUT
168 arm::app::DumpTensor(outputTensor);
169#endif /* VERIFY_TEST_OUTPUT */
170
alexanderc350cdc2021-04-29 20:36:09 +0100171 if (!PresentInferenceResult(platform, results)) {
alexander3c798932021-03-26 21:42:19 +0000172 return false;
173 }
174
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100175 profiler.PrintProfilingResult();
176
alexanderc350cdc2021-04-29 20:36:09 +0100177 IncrementAppCtxImageIdx(ctx);
alexander3c798932021-03-26 21:42:19 +0000178
179 } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
180
181 return true;
182 }
183
alexanderc350cdc2021-04-29 20:36:09 +0100184 static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor)
alexander3c798932021-03-26 21:42:19 +0000185 {
186 const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
187 inputTensor->bytes : IMAGE_DATA_SIZE;
188 const uint8_t* imgSrc = get_img_array(imIdx);
189 if (nullptr == imgSrc) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100190 printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", imIdx,
alexander3c798932021-03-26 21:42:19 +0000191 NUMBER_OF_FILES - 1);
192 return false;
193 }
194
195 memcpy(inputTensor->data.data, imgSrc, copySz);
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100196 debug("Image %" PRIu32 " loaded\n", imIdx);
alexander3c798932021-03-26 21:42:19 +0000197 return true;
198 }
199
alexanderc350cdc2021-04-29 20:36:09 +0100200 static void IncrementAppCtxImageIdx(ApplicationContext& ctx)
alexander3c798932021-03-26 21:42:19 +0000201 {
202 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
203
204 if (curImIdx + 1 >= NUMBER_OF_FILES) {
205 ctx.Set<uint32_t>("imgIndex", 0);
206 return;
207 }
208 ++curImIdx;
209 ctx.Set<uint32_t>("imgIndex", curImIdx);
210 }
211
alexanderc350cdc2021-04-29 20:36:09 +0100212 static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx)
alexander3c798932021-03-26 21:42:19 +0000213 {
214 if (idx >= NUMBER_OF_FILES) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100215 printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
alexander3c798932021-03-26 21:42:19 +0000216 idx, NUMBER_OF_FILES);
217 return false;
218 }
219 ctx.Set<uint32_t>("imgIndex", idx);
220 return true;
221 }
222
alexanderc350cdc2021-04-29 20:36:09 +0100223 static bool PresentInferenceResult(hal_platform& platform,
224 const std::vector<ClassificationResult>& results)
alexander3c798932021-03-26 21:42:19 +0000225 {
226 constexpr uint32_t dataPsnTxtStartX1 = 150;
227 constexpr uint32_t dataPsnTxtStartY1 = 30;
228
229 constexpr uint32_t dataPsnTxtStartX2 = 10;
230 constexpr uint32_t dataPsnTxtStartY2 = 150;
231
232 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
233
234 platform.data_psn->set_text_color(COLOR_GREEN);
235
236 /* Display each result. */
237 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
238 uint32_t rowIdx2 = dataPsnTxtStartY2;
239
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100240 info("Final results:\n");
241 info("Total number of inferences: 1\n");
alexander3c798932021-03-26 21:42:19 +0000242 for (uint32_t i = 0; i < results.size(); ++i) {
243 std::string resultStr =
244 std::to_string(i + 1) + ") " +
245 std::to_string(results[i].m_labelIdx) +
246 " (" + std::to_string(results[i].m_normalisedVal) + ")";
247
248 platform.data_psn->present_data_text(
249 resultStr.c_str(), resultStr.size(),
250 dataPsnTxtStartX1, rowIdx1, 0);
251 rowIdx1 += dataPsnTxtYIncr;
252
253 resultStr = std::to_string(i + 1) + ") " + results[i].m_label;
254 platform.data_psn->present_data_text(
255 resultStr.c_str(), resultStr.size(),
256 dataPsnTxtStartX2, rowIdx2, 0);
257 rowIdx2 += dataPsnTxtYIncr;
258
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100259 info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i,
260 results[i].m_labelIdx, results[i].m_normalisedVal,
261 results[i].m_label.c_str());
alexander3c798932021-03-26 21:42:19 +0000262 }
263
264 return true;
265 }
266
267 static void ConvertImgToInt8(void* data, const size_t kMaxImageSize)
268 {
269 auto* tmp_req_data = (uint8_t*) data;
270 auto* tmp_signed_req_data = (int8_t*) data;
271
272 for (size_t i = 0; i < kMaxImageSize; i++) {
273 tmp_signed_req_data[i] = (int8_t) (
274 (int32_t) (tmp_req_data[i]) - 128);
275 }
276 }
277
278} /* namespace app */
279} /* namespace arm */