blob: a412fec6f260ef919b9e814bd6603fb37154abf0 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "Classifier.hpp"
20#include "InputFiles.hpp"
21#include "MobileNetModel.hpp"
22#include "UseCaseCommonUtils.hpp"
23#include "hal.h"
24
25using ImgClassClassifier = arm::app::Classifier;
26
27namespace arm {
28namespace app {
29
30 /**
31 * @brief Helper function to load the current image into the input
32 * tensor.
33 * @param[in] imIdx Image index (from the pool of images available
34 * to the application).
35 * @param[out] inputTensor Pointer to the input tensor to be populated.
36 * @return true if tensor is loaded, false otherwise.
37 **/
38 static bool _LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor);
39
40 /**
41 * @brief Helper function to increment current image index.
42 * @param[in,out] ctx Pointer to the application context object.
43 **/
44 static void _IncrementAppCtxImageIdx(ApplicationContext& ctx);
45
46 /**
47 * @brief Helper function to set the image index.
48 * @param[in,out] ctx Pointer to the application context object.
49 * @param[in] idx Value to be set.
50 * @return true if index is set, false otherwise.
51 **/
52 static bool _SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx);
53
54 /**
55 * @brief Presents inference results using the data presentation
56 * object.
57 * @param[in] platform Reference to the hal platform object.
58 * @param[in] results Vector of classification results to be displayed.
59 * @param[in] infTimeMs Inference time in milliseconds, if available
60 * otherwise, this can be passed in as 0.
61 * @return true if successful, false otherwise.
62 **/
63 static bool _PresentInferenceResult(hal_platform& platform,
64 const std::vector<ClassificationResult>& results);
65
66 /**
67 * @brief Helper function to convert a UINT8 image to INT8 format.
68 * @param[in,out] data Pointer to the data start.
69 * @param[in] kMaxImageSize Total number of pixels in the image.
70 **/
71 static void ConvertImgToInt8(void* data, size_t kMaxImageSize);
72
73 /* Image inference classification handler. */
74 bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
75 {
76 auto& platform = ctx.Get<hal_platform&>("platform");
77
78 constexpr uint32_t dataPsnImgDownscaleFactor = 2;
79 constexpr uint32_t dataPsnImgStartX = 10;
80 constexpr uint32_t dataPsnImgStartY = 35;
81
82 constexpr uint32_t dataPsnTxtInfStartX = 150;
83 constexpr uint32_t dataPsnTxtInfStartY = 40;
84
85 platform.data_psn->clear(COLOR_BLACK);
86
87 auto& model = ctx.Get<Model&>("model");
88
89 /* If the request has a valid size, set the image index. */
90 if (imgIndex < NUMBER_OF_FILES) {
91 if (!_SetAppCtxImageIdx(ctx, imgIndex)) {
92 return false;
93 }
94 }
95 if (!model.IsInited()) {
96 printf_err("Model is not initialised! Terminating processing.\n");
97 return false;
98 }
99
100 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
101
102 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
103 TfLiteTensor* inputTensor = model.GetInputTensor(0);
104
105 if (!inputTensor->dims) {
106 printf_err("Invalid input tensor dims\n");
107 return false;
108 } else if (inputTensor->dims->size < 3) {
109 printf_err("Input tensor dimension should be >= 3\n");
110 return false;
111 }
112
113 TfLiteIntArray* inputShape = model.GetInputShape(0);
114
115 const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx];
116 const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx];
117 const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
118
119 std::vector<ClassificationResult> results;
120
121 do {
122 /* Strings for presentation/logging. */
123 std::string str_inf{"Running inference... "};
124
125 /* Copy over the data. */
126 _LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
127
128 /* Display this image on the LCD. */
129 platform.data_psn->present_data_image(
130 (uint8_t*) inputTensor->data.data,
131 nCols, nRows, nChannels,
132 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
133
134 /* If the data is signed. */
135 if (model.IsDataSigned()) {
136 ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
137 }
138
139 /* Display message on the LCD - inference running. */
140 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
141 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
142
143 /* Run inference over this image. */
144 info("Running inference on image %u => %s\n", ctx.Get<uint32_t>("imgIndex"),
145 get_filename(ctx.Get<uint32_t>("imgIndex")));
146
147 RunInference(platform, model);
148
149 /* Erase. */
150 str_inf = std::string(str_inf.size(), ' ');
151 platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(),
152 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
153
154 auto& classifier = ctx.Get<ImgClassClassifier&>("classifier");
155 classifier.GetClassificationResults(outputTensor, results,
156 ctx.Get<std::vector <std::string>&>("labels"),
157 5);
158
159 /* Add results to context for access outside handler. */
160 ctx.Set<std::vector<ClassificationResult>>("results", results);
161
162#if VERIFY_TEST_OUTPUT
163 arm::app::DumpTensor(outputTensor);
164#endif /* VERIFY_TEST_OUTPUT */
165
166 if (!_PresentInferenceResult(platform, results)) {
167 return false;
168 }
169
170 _IncrementAppCtxImageIdx(ctx);
171
172 } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
173
174 return true;
175 }
176
177 static bool _LoadImageIntoTensor(const uint32_t imIdx, TfLiteTensor* inputTensor)
178 {
179 const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
180 inputTensor->bytes : IMAGE_DATA_SIZE;
181 const uint8_t* imgSrc = get_img_array(imIdx);
182 if (nullptr == imgSrc) {
183 printf_err("Failed to get image index %u (max: %u)\n", imIdx,
184 NUMBER_OF_FILES - 1);
185 return false;
186 }
187
188 memcpy(inputTensor->data.data, imgSrc, copySz);
189 debug("Image %u loaded\n", imIdx);
190 return true;
191 }
192
193 static void _IncrementAppCtxImageIdx(ApplicationContext& ctx)
194 {
195 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
196
197 if (curImIdx + 1 >= NUMBER_OF_FILES) {
198 ctx.Set<uint32_t>("imgIndex", 0);
199 return;
200 }
201 ++curImIdx;
202 ctx.Set<uint32_t>("imgIndex", curImIdx);
203 }
204
205 static bool _SetAppCtxImageIdx(ApplicationContext& ctx, const uint32_t idx)
206 {
207 if (idx >= NUMBER_OF_FILES) {
208 printf_err("Invalid idx %u (expected less than %u)\n",
209 idx, NUMBER_OF_FILES);
210 return false;
211 }
212 ctx.Set<uint32_t>("imgIndex", idx);
213 return true;
214 }
215
216 static bool _PresentInferenceResult(hal_platform& platform,
217 const std::vector<ClassificationResult>& results)
218 {
219 constexpr uint32_t dataPsnTxtStartX1 = 150;
220 constexpr uint32_t dataPsnTxtStartY1 = 30;
221
222 constexpr uint32_t dataPsnTxtStartX2 = 10;
223 constexpr uint32_t dataPsnTxtStartY2 = 150;
224
225 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
226
227 platform.data_psn->set_text_color(COLOR_GREEN);
228
229 /* Display each result. */
230 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
231 uint32_t rowIdx2 = dataPsnTxtStartY2;
232
233 for (uint32_t i = 0; i < results.size(); ++i) {
234 std::string resultStr =
235 std::to_string(i + 1) + ") " +
236 std::to_string(results[i].m_labelIdx) +
237 " (" + std::to_string(results[i].m_normalisedVal) + ")";
238
239 platform.data_psn->present_data_text(
240 resultStr.c_str(), resultStr.size(),
241 dataPsnTxtStartX1, rowIdx1, 0);
242 rowIdx1 += dataPsnTxtYIncr;
243
244 resultStr = std::to_string(i + 1) + ") " + results[i].m_label;
245 platform.data_psn->present_data_text(
246 resultStr.c_str(), resultStr.size(),
247 dataPsnTxtStartX2, rowIdx2, 0);
248 rowIdx2 += dataPsnTxtYIncr;
249
250 info("%u) %u (%f) -> %s\n", i, results[i].m_labelIdx,
251 results[i].m_normalisedVal, results[i].m_label.c_str());
252 }
253
254 return true;
255 }
256
257 static void ConvertImgToInt8(void* data, const size_t kMaxImageSize)
258 {
259 auto* tmp_req_data = (uint8_t*) data;
260 auto* tmp_signed_req_data = (int8_t*) data;
261
262 for (size_t i = 0; i < kMaxImageSize; i++) {
263 tmp_signed_req_data[i] = (int8_t) (
264 (int32_t) (tmp_req_data[i]) - 128);
265 }
266 }
267
268} /* namespace app */
269} /* namespace arm */