blob: 22e6ba070e33983dc482ad7817ffc4ba66357757 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
#include "UseCaseHandler.hpp"

#include <algorithm>
#include <cinttypes>

#include "Classifier.hpp"
#include "InputFiles.hpp"
#include "MobileNetModel.hpp"
#include "UseCaseCommonUtils.hpp"
#include "hal.h"
24
25using ImgClassClassifier = arm::app::Classifier;
26
27namespace arm {
28namespace app {
29
30 /**
31 * @brief Helper function to load the current image into the input
32 * tensor.
33 * @param[in] imIdx Image index (from the pool of images available
34 * to the application).
35 * @param[out] inputTensor Pointer to the input tensor to be populated.
36 * @return true if tensor is loaded, false otherwise.
37 **/
alexanderc350cdc2021-04-29 20:36:09 +010038 static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor);
alexander3c798932021-03-26 21:42:19 +000039
40 /**
41 * @brief Helper function to increment current image index.
42 * @param[in,out] ctx Pointer to the application context object.
43 **/
alexanderc350cdc2021-04-29 20:36:09 +010044 static void IncrementAppCtxImageIdx(ApplicationContext& ctx);
alexander3c798932021-03-26 21:42:19 +000045
46 /**
47 * @brief Helper function to set the image index.
48 * @param[in,out] ctx Pointer to the application context object.
49 * @param[in] idx Value to be set.
50 * @return true if index is set, false otherwise.
51 **/
alexanderc350cdc2021-04-29 20:36:09 +010052 static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx);
alexander3c798932021-03-26 21:42:19 +000053
54 /**
55 * @brief Presents inference results using the data presentation
56 * object.
57 * @param[in] platform Reference to the hal platform object.
58 * @param[in] results Vector of classification results to be displayed.
59 * @param[in] infTimeMs Inference time in milliseconds, if available
60 * otherwise, this can be passed in as 0.
61 * @return true if successful, false otherwise.
62 **/
alexanderc350cdc2021-04-29 20:36:09 +010063 static bool PresentInferenceResult(hal_platform& platform,
64 const std::vector<ClassificationResult>& results);
alexander3c798932021-03-26 21:42:19 +000065
66 /**
67 * @brief Helper function to convert a UINT8 image to INT8 format.
68 * @param[in,out] data Pointer to the data start.
69 * @param[in] kMaxImageSize Total number of pixels in the image.
70 **/
71 static void ConvertImgToInt8(void* data, size_t kMaxImageSize);
72
73 /* Image inference classification handler. */
74 bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
75 {
76 auto& platform = ctx.Get<hal_platform&>("platform");
Isabella Gottardi8df12f32021-04-07 17:15:31 +010077 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +000078
79 constexpr uint32_t dataPsnImgDownscaleFactor = 2;
80 constexpr uint32_t dataPsnImgStartX = 10;
81 constexpr uint32_t dataPsnImgStartY = 35;
82
83 constexpr uint32_t dataPsnTxtInfStartX = 150;
84 constexpr uint32_t dataPsnTxtInfStartY = 40;
85
86 platform.data_psn->clear(COLOR_BLACK);
87
88 auto& model = ctx.Get<Model&>("model");
89
90 /* If the request has a valid size, set the image index. */
91 if (imgIndex < NUMBER_OF_FILES) {
alexanderc350cdc2021-04-29 20:36:09 +010092 if (!SetAppCtxImageIdx(ctx, imgIndex)) {
alexander3c798932021-03-26 21:42:19 +000093 return false;
94 }
95 }
96 if (!model.IsInited()) {
97 printf_err("Model is not initialised! Terminating processing.\n");
98 return false;
99 }
100
101 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
102
103 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
104 TfLiteTensor* inputTensor = model.GetInputTensor(0);
105
106 if (!inputTensor->dims) {
107 printf_err("Invalid input tensor dims\n");
108 return false;
109 } else if (inputTensor->dims->size < 3) {
110 printf_err("Input tensor dimension should be >= 3\n");
111 return false;
112 }
113
114 TfLiteIntArray* inputShape = model.GetInputShape(0);
115
116 const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx];
117 const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx];
118 const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
119
120 std::vector<ClassificationResult> results;
121
122 do {
123 /* Strings for presentation/logging. */
124 std::string str_inf{"Running inference... "};
125
126 /* Copy over the data. */
alexanderc350cdc2021-04-29 20:36:09 +0100127 LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
alexander3c798932021-03-26 21:42:19 +0000128
129 /* Display this image on the LCD. */
130 platform.data_psn->present_data_image(
131 (uint8_t*) inputTensor->data.data,
132 nCols, nRows, nChannels,
133 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
134
135 /* If the data is signed. */
136 if (model.IsDataSigned()) {
137 ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
138 }
139
140 /* Display message on the LCD - inference running. */
141 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
142 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
143
144 /* Run inference over this image. */
145 info("Running inference on image %u => %s\n", ctx.Get<uint32_t>("imgIndex"),
146 get_filename(ctx.Get<uint32_t>("imgIndex")));
147
alexander27b62d92021-05-04 20:46:08 +0100148 if (!RunInference(model, profiler)) {
149 return false;
150 }
alexander3c798932021-03-26 21:42:19 +0000151
152 /* Erase. */
153 str_inf = std::string(str_inf.size(), ' ');
154 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
155 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
156
157 auto& classifier = ctx.Get<ImgClassClassifier&>("classifier");
158 classifier.GetClassificationResults(outputTensor, results,
159 ctx.Get<std::vector <std::string>&>("labels"),
160 5);
161
162 /* Add results to context for access outside handler. */
163 ctx.Set<std::vector<ClassificationResult>>("results", results);
164
165#if VERIFY_TEST_OUTPUT
166 arm::app::DumpTensor(outputTensor);
167#endif /* VERIFY_TEST_OUTPUT */
168
alexanderc350cdc2021-04-29 20:36:09 +0100169 if (!PresentInferenceResult(platform, results)) {
alexander3c798932021-03-26 21:42:19 +0000170 return false;
171 }
172
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100173 profiler.PrintProfilingResult();
174
alexanderc350cdc2021-04-29 20:36:09 +0100175 IncrementAppCtxImageIdx(ctx);
alexander3c798932021-03-26 21:42:19 +0000176
177 } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
178
179 return true;
180 }
181
alexanderc350cdc2021-04-29 20:36:09 +0100182 static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor)
alexander3c798932021-03-26 21:42:19 +0000183 {
184 const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
185 inputTensor->bytes : IMAGE_DATA_SIZE;
186 const uint8_t* imgSrc = get_img_array(imIdx);
187 if (nullptr == imgSrc) {
188 printf_err("Failed to get image index %u (max: %u)\n", imIdx,
189 NUMBER_OF_FILES - 1);
190 return false;
191 }
192
193 memcpy(inputTensor->data.data, imgSrc, copySz);
194 debug("Image %u loaded\n", imIdx);
195 return true;
196 }
197
alexanderc350cdc2021-04-29 20:36:09 +0100198 static void IncrementAppCtxImageIdx(ApplicationContext& ctx)
alexander3c798932021-03-26 21:42:19 +0000199 {
200 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
201
202 if (curImIdx + 1 >= NUMBER_OF_FILES) {
203 ctx.Set<uint32_t>("imgIndex", 0);
204 return;
205 }
206 ++curImIdx;
207 ctx.Set<uint32_t>("imgIndex", curImIdx);
208 }
209
alexanderc350cdc2021-04-29 20:36:09 +0100210 static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx)
alexander3c798932021-03-26 21:42:19 +0000211 {
212 if (idx >= NUMBER_OF_FILES) {
213 printf_err("Invalid idx %u (expected less than %u)\n",
214 idx, NUMBER_OF_FILES);
215 return false;
216 }
217 ctx.Set<uint32_t>("imgIndex", idx);
218 return true;
219 }
220
alexanderc350cdc2021-04-29 20:36:09 +0100221 static bool PresentInferenceResult(hal_platform& platform,
222 const std::vector<ClassificationResult>& results)
alexander3c798932021-03-26 21:42:19 +0000223 {
224 constexpr uint32_t dataPsnTxtStartX1 = 150;
225 constexpr uint32_t dataPsnTxtStartY1 = 30;
226
227 constexpr uint32_t dataPsnTxtStartX2 = 10;
228 constexpr uint32_t dataPsnTxtStartY2 = 150;
229
230 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
231
232 platform.data_psn->set_text_color(COLOR_GREEN);
233
234 /* Display each result. */
235 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
236 uint32_t rowIdx2 = dataPsnTxtStartY2;
237
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100238 info("Final results:\n");
239 info("Total number of inferences: 1\n");
alexander3c798932021-03-26 21:42:19 +0000240 for (uint32_t i = 0; i < results.size(); ++i) {
241 std::string resultStr =
242 std::to_string(i + 1) + ") " +
243 std::to_string(results[i].m_labelIdx) +
244 " (" + std::to_string(results[i].m_normalisedVal) + ")";
245
246 platform.data_psn->present_data_text(
247 resultStr.c_str(), resultStr.size(),
248 dataPsnTxtStartX1, rowIdx1, 0);
249 rowIdx1 += dataPsnTxtYIncr;
250
251 resultStr = std::to_string(i + 1) + ") " + results[i].m_label;
252 platform.data_psn->present_data_text(
253 resultStr.c_str(), resultStr.size(),
254 dataPsnTxtStartX2, rowIdx2, 0);
255 rowIdx2 += dataPsnTxtYIncr;
256
257 info("%u) %u (%f) -> %s\n", i, results[i].m_labelIdx,
258 results[i].m_normalisedVal, results[i].m_label.c_str());
259 }
260
261 return true;
262 }
263
264 static void ConvertImgToInt8(void* data, const size_t kMaxImageSize)
265 {
266 auto* tmp_req_data = (uint8_t*) data;
267 auto* tmp_signed_req_data = (int8_t*) data;
268
269 for (size_t i = 0; i < kMaxImageSize; i++) {
270 tmp_signed_req_data[i] = (int8_t) (
271 (int32_t) (tmp_req_data[i]) - 128);
272 }
273 }
274
275} /* namespace app */
276} /* namespace arm */