blob: 56ba2b51caeccd9f819f21918402ca92372b5960 [file] [log] [blame]
Éanna Ó Catháin8f958872021-09-15 09:32:30 +01001/*
Liam Barrye9588502022-01-25 14:31:15 +00002 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
Éanna Ó Catháin8f958872021-09-15 09:32:30 +01003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18#include "VisualWakeWordModel.hpp"
19#include "Classifier.hpp"
20#include "InputFiles.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000021#include "ImageUtils.hpp"
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010022#include "UseCaseCommonUtils.hpp"
23#include "hal.h"
alexander31ae9f02022-02-10 16:15:54 +000024#include "log_macros.h"
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010025
Isabella Gottardi79d41542021-10-20 15:52:32 +010026#include <algorithm>
27
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010028namespace arm {
29namespace app {
30
    /**
     * @brief       Helper function to load the current image into the input
     *              tensor.
     * @details     The number of channels is read from the tensor's dims at
     *              VisualWakeWordModel::ms_inputChannelsIdx. For a
     *              single-channel tensor the source image is converted to
     *              grayscale (image::RgbToGrayscale); otherwise the raw image
     *              bytes are copied into the tensor.
     * @param[in]     imIdx         Image index (from the pool of images available
     *                              to the application). Must be < NUMBER_OF_FILES.
     * @param[in,out] inputTensor   Pointer to the input tensor; its dims are read
     *                              and its data buffer is populated.
     * @return      true if the tensor is loaded; false if the index is out of
     *              range or the tensor has no channel dimension (an error is
     *              printed in either case).
     **/
    static bool LoadImageIntoTensor(uint32_t imIdx,
                                    TfLiteTensor *inputTensor);
41
42 /* Image inference classification handler. */
43 bool ClassifyImageHandler(ApplicationContext &ctx, uint32_t imgIndex, bool runAll)
44 {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010045 auto& profiler = ctx.Get<Profiler&>("profiler");
46
47 constexpr uint32_t dataPsnImgDownscaleFactor = 1;
48 constexpr uint32_t dataPsnImgStartX = 10;
49 constexpr uint32_t dataPsnImgStartY = 35;
50
51 constexpr uint32_t dataPsnTxtInfStartX = 150;
52 constexpr uint32_t dataPsnTxtInfStartY = 70;
53
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010054 auto& model = ctx.Get<Model&>("model");
55
56 /* If the request has a valid size, set the image index. */
57 if (imgIndex < NUMBER_OF_FILES) {
58 if (!SetAppCtxIfmIdx(ctx, imgIndex,"imgIndex")) {
59 return false;
60 }
61 }
62 if (!model.IsInited()) {
63 printf_err("Model is not initialised! Terminating processing.\n");
64 return false;
65 }
66
67 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
68
69 TfLiteTensor *outputTensor = model.GetOutputTensor(0);
70 TfLiteTensor *inputTensor = model.GetInputTensor(0);
71
72 if (!inputTensor->dims) {
73 printf_err("Invalid input tensor dims\n");
74 return false;
75 } else if (inputTensor->dims->size < 3) {
76 printf_err("Input tensor dimension should be >= 3\n");
77 return false;
78 }
79 TfLiteIntArray* inputShape = model.GetInputShape(0);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000080 const uint32_t nCols = inputShape->data[arm::app::VisualWakeWordModel::ms_inputColsIdx];
81 const uint32_t nRows = inputShape->data[arm::app::VisualWakeWordModel::ms_inputRowsIdx];
82 if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= static_cast<uint32_t>(inputShape->size)) {
83 printf_err("Invalid channel index.\n");
84 return false;
85 }
86 const uint32_t nChannels = inputShape->data[arm::app::VisualWakeWordModel::ms_inputChannelsIdx];
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010087
88 std::vector<ClassificationResult> results;
89
90 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010091 hal_lcd_clear(COLOR_BLACK);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010092
93 /* Strings for presentation/logging. */
94 std::string str_inf{"Running inference... "};
95
96 /* Copy over the data. */
97 LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
98
99 /* Display this image on the LCD. */
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100100 hal_lcd_display_image(
Isabella Gottardi79d41542021-10-20 15:52:32 +0100101 static_cast<uint8_t *>(inputTensor->data.data),
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100102 nCols, nRows, nChannels,
103 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
104
Isabella Gottardi79d41542021-10-20 15:52:32 +0100105 /* Vww model preprocessing is image conversion from uint8 to [0,1] float values,
106 * then quantize them with input quantization info. */
107 QuantParams inQuantParams = GetTensorQuantParams(inputTensor);
108
109 auto* req_data = static_cast<uint8_t *>(inputTensor->data.data);
110 auto* signed_req_data = static_cast<int8_t *>(inputTensor->data.data);
111 for (size_t i = 0; i < inputTensor->bytes; i++) {
112 auto i_data_int8 = static_cast<int8_t>(((static_cast<float>(req_data[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset);
113 signed_req_data[i] = std::min<int8_t>(INT8_MAX, std::max<int8_t>(i_data_int8, INT8_MIN));
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100114 }
115
116 /* Display message on the LCD - inference running. */
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100117 hal_lcd_display_text(
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100118 str_inf.c_str(), str_inf.size(),
119 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
120
121 /* Run inference over this image. */
122 info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
123 get_filename(ctx.Get<uint32_t>("imgIndex")));
124
125 if (!RunInference(model, profiler)) {
126 return false;
127 }
128
129 /* Erase. */
130 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100131 hal_lcd_display_text(
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100132 str_inf.c_str(), str_inf.size(),
133 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
134
135 auto& classifier = ctx.Get<Classifier&>("classifier");
136 classifier.GetClassificationResults(outputTensor, results,
alexander31ae9f02022-02-10 16:15:54 +0000137 ctx.Get<std::vector <std::string>&>("labels"), 1,
138 false);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100139
140 /* Add results to context for access outside handler. */
141 ctx.Set<std::vector<ClassificationResult>>("results", results);
142
143#if VERIFY_TEST_OUTPUT
144 arm::app::DumpTensor(outputTensor);
145#endif /* VERIFY_TEST_OUTPUT */
146
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100147 if (!PresentInferenceResult(results)) {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100148 return false;
149 }
150
151 profiler.PrintProfilingResult();
152 IncrementAppCtxIfmIdx(ctx,"imgIndex");
153
154 } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
155
156 return true;
157 }
158
159 static bool LoadImageIntoTensor(const uint32_t imIdx,
160 TfLiteTensor *inputTensor)
161 {
162 const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
163 inputTensor->bytes : IMAGE_DATA_SIZE;
164 if (imIdx >= NUMBER_OF_FILES) {
165 printf_err("invalid image index %" PRIu32 " (max: %u)\n", imIdx,
166 NUMBER_OF_FILES - 1);
167 return false;
168 }
169
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000170 if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= static_cast<uint32_t>(inputTensor->dims->size)) {
171 printf_err("Invalid channel index.\n");
172 return false;
173 }
174 const uint32_t nChannels = inputTensor->dims->data[arm::app::VisualWakeWordModel::ms_inputChannelsIdx];
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100175
176 const uint8_t* srcPtr = get_img_array(imIdx);
Isabella Gottardi79d41542021-10-20 15:52:32 +0100177 auto* dstPtr = static_cast<uint8_t *>(inputTensor->data.data);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100178 if (1 == nChannels) {
179 /**
180 * Visual Wake Word model accepts only one channel =>
181 * Convert image to grayscale here
182 **/
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000183 image::RgbToGrayscale(srcPtr, dstPtr, copySz);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100184 } else {
185 memcpy(inputTensor->data.data, srcPtr, copySz);
186 }
187
188 debug("Image %" PRIu32 " loaded\n", imIdx);
189 return true;
190 }
191
192} /* namespace app */
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000193} /* namespace arm */