blob: dbfe92b6949c4e637fff0ec0b238880e6ce0d60e [file] [log] [blame]
Éanna Ó Catháin8f958872021-09-15 09:32:30 +01001/*
Liam Barrye9588502022-01-25 14:31:15 +00002 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
Éanna Ó Catháin8f958872021-09-15 09:32:30 +01003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18#include "VisualWakeWordModel.hpp"
19#include "Classifier.hpp"
20#include "InputFiles.hpp"
21#include "UseCaseCommonUtils.hpp"
22#include "hal.h"
23
Isabella Gottardi79d41542021-10-20 15:52:32 +010024#include <algorithm>
25
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010026namespace arm {
27namespace app {
28
    /**
     * @brief Helper function to load the current image into the input
     *        tensor.
     * @param[in]  imIdx       Image index (from the pool of images available
     *                         to the application).
     * @param[out] inputTensor Pointer to the input tensor to be populated.
     * @return true if tensor is loaded, false otherwise (e.g. the image
     *         index is out of range).
     **/
    static bool LoadImageIntoTensor(uint32_t imIdx,
                                    TfLiteTensor *inputTensor);
39
40 /* Image inference classification handler. */
41 bool ClassifyImageHandler(ApplicationContext &ctx, uint32_t imgIndex, bool runAll)
42 {
43 auto& platform = ctx.Get<hal_platform &>("platform");
44 auto& profiler = ctx.Get<Profiler&>("profiler");
45
46 constexpr uint32_t dataPsnImgDownscaleFactor = 1;
47 constexpr uint32_t dataPsnImgStartX = 10;
48 constexpr uint32_t dataPsnImgStartY = 35;
49
50 constexpr uint32_t dataPsnTxtInfStartX = 150;
51 constexpr uint32_t dataPsnTxtInfStartY = 70;
52
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010053 time_t infTimeMs = 0;
54
55 auto& model = ctx.Get<Model&>("model");
56
57 /* If the request has a valid size, set the image index. */
58 if (imgIndex < NUMBER_OF_FILES) {
59 if (!SetAppCtxIfmIdx(ctx, imgIndex,"imgIndex")) {
60 return false;
61 }
62 }
63 if (!model.IsInited()) {
64 printf_err("Model is not initialised! Terminating processing.\n");
65 return false;
66 }
67
68 auto curImIdx = ctx.Get<uint32_t>("imgIndex");
69
70 TfLiteTensor *outputTensor = model.GetOutputTensor(0);
71 TfLiteTensor *inputTensor = model.GetInputTensor(0);
72
73 if (!inputTensor->dims) {
74 printf_err("Invalid input tensor dims\n");
75 return false;
76 } else if (inputTensor->dims->size < 3) {
77 printf_err("Input tensor dimension should be >= 3\n");
78 return false;
79 }
80 TfLiteIntArray* inputShape = model.GetInputShape(0);
81 const uint32_t nCols = inputShape->data[2];
82 const uint32_t nRows = inputShape->data[1];
83 const uint32_t nChannels = (inputShape->size == 4) ? inputShape->data[3] : 1;
84
85 std::vector<ClassificationResult> results;
86
87 do {
Richard Burton9b8d67a2021-12-10 12:32:51 +000088 platform.data_psn->clear(COLOR_BLACK);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010089
90 /* Strings for presentation/logging. */
91 std::string str_inf{"Running inference... "};
92
93 /* Copy over the data. */
94 LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
95
96 /* Display this image on the LCD. */
97 platform.data_psn->present_data_image(
Isabella Gottardi79d41542021-10-20 15:52:32 +010098 static_cast<uint8_t *>(inputTensor->data.data),
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010099 nCols, nRows, nChannels,
100 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
101
Isabella Gottardi79d41542021-10-20 15:52:32 +0100102 /* Vww model preprocessing is image conversion from uint8 to [0,1] float values,
103 * then quantize them with input quantization info. */
104 QuantParams inQuantParams = GetTensorQuantParams(inputTensor);
105
106 auto* req_data = static_cast<uint8_t *>(inputTensor->data.data);
107 auto* signed_req_data = static_cast<int8_t *>(inputTensor->data.data);
108 for (size_t i = 0; i < inputTensor->bytes; i++) {
109 auto i_data_int8 = static_cast<int8_t>(((static_cast<float>(req_data[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset);
110 signed_req_data[i] = std::min<int8_t>(INT8_MAX, std::max<int8_t>(i_data_int8, INT8_MIN));
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100111 }
112
113 /* Display message on the LCD - inference running. */
114 platform.data_psn->present_data_text(
115 str_inf.c_str(), str_inf.size(),
116 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
117
118 /* Run inference over this image. */
119 info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
120 get_filename(ctx.Get<uint32_t>("imgIndex")));
121
122 if (!RunInference(model, profiler)) {
123 return false;
124 }
125
126 /* Erase. */
127 str_inf = std::string(str_inf.size(), ' ');
128 platform.data_psn->present_data_text(
129 str_inf.c_str(), str_inf.size(),
130 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
131
132 auto& classifier = ctx.Get<Classifier&>("classifier");
133 classifier.GetClassificationResults(outputTensor, results,
134 ctx.Get<std::vector <std::string>&>("labels"), 1);
135
136 /* Add results to context for access outside handler. */
137 ctx.Set<std::vector<ClassificationResult>>("results", results);
138
139#if VERIFY_TEST_OUTPUT
140 arm::app::DumpTensor(outputTensor);
141#endif /* VERIFY_TEST_OUTPUT */
142
Liam Barrye9588502022-01-25 14:31:15 +0000143 if (!image::PresentInferenceResult(platform, results)) {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100144 return false;
145 }
146
147 profiler.PrintProfilingResult();
148 IncrementAppCtxIfmIdx(ctx,"imgIndex");
149
150 } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
151
152 return true;
153 }
154
155 static bool LoadImageIntoTensor(const uint32_t imIdx,
156 TfLiteTensor *inputTensor)
157 {
158 const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
159 inputTensor->bytes : IMAGE_DATA_SIZE;
160 if (imIdx >= NUMBER_OF_FILES) {
161 printf_err("invalid image index %" PRIu32 " (max: %u)\n", imIdx,
162 NUMBER_OF_FILES - 1);
163 return false;
164 }
165
166 const uint32_t nChannels = (inputTensor->dims->size == 4) ? inputTensor->dims->data[3] : 1;
167
168 const uint8_t* srcPtr = get_img_array(imIdx);
Isabella Gottardi79d41542021-10-20 15:52:32 +0100169 auto* dstPtr = static_cast<uint8_t *>(inputTensor->data.data);
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100170 if (1 == nChannels) {
171 /**
172 * Visual Wake Word model accepts only one channel =>
173 * Convert image to grayscale here
174 **/
175 for (size_t i = 0; i < copySz; ++i, srcPtr += 3) {
176 *dstPtr++ = 0.2989*(*srcPtr) +
177 0.587*(*(srcPtr+1)) +
178 0.114*(*(srcPtr+2));
179 }
180 } else {
181 memcpy(inputTensor->data.data, srcPtr, copySz);
182 }
183
184 debug("Image %" PRIu32 " loaded\n", imIdx);
185 return true;
186 }
187
188} /* namespace app */
189} /* namespace arm */