blob: 236526438148486e8ffadaa7592e435ba9b3e382 [file] [log] [blame]
/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
alexander3c798932021-03-26 21:42:19 +000017#include "InputFiles.hpp" /* For input images. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
alexander3c798932021-03-26 21:42:19 +000019#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
20#include "Classifier.hpp" /* KWS classifier. */
21#include "AsrClassifier.hpp" /* ASR classifier. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000022#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */
alexander3c798932021-03-26 21:42:19 +000023#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
24#include "UseCaseCommonUtils.hpp" /* Utils functions. */
25#include "UseCaseHandler.hpp" /* Handlers for different user options. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010026#include "log_macros.h" /* Logging functions */
27#include "BufAttributes.hpp" /* Buffer attributes to be applied */
28
29namespace arm {
30namespace app {
31 static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
32
33 namespace asr {
34 extern uint8_t* GetModelPointer();
35 extern size_t GetModelLen();
36 }
37
38 namespace kws {
39 extern uint8_t* GetModelPointer();
40 extern size_t GetModelLen();
41 }
42} /* namespace app */
43} /* namespace arm */
alexander3c798932021-03-26 21:42:19 +000044
45using KwsClassifier = arm::app::Classifier;
46
/* Menu option codes. These are the numbers printed by the menu and compared
 * against the user's input in the main loop. */
enum opcodes
{
    MENU_OPT_RUN_INF_NEXT     = 1, /* Run on next vector. */
    MENU_OPT_RUN_INF_CHOSEN   = 2, /* Run on a user provided vector index. */
    MENU_OPT_RUN_INF_ALL      = 3, /* Run inference on all. */
    MENU_OPT_SHOW_MODEL_INFO  = 4, /* Show model info. */
    MENU_OPT_LIST_AUDIO_CLIPS = 5  /* List the current baked audio clips. */
};
55
56static void DisplayMenu()
57{
Kshitij Sisodia3c8256d2021-05-24 16:12:40 +010058 printf("\n\n");
59 printf("User input required\n");
alexander3c798932021-03-26 21:42:19 +000060 printf("Enter option number from:\n\n");
61 printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT);
62 printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN);
63 printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL);
64 printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO);
65 printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS);
66 printf(" Choice: ");
George Gekov93e59512021-08-03 11:18:41 +010067 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +000068}
69
Richard Burton4e002792022-05-04 09:45:02 +010070/** @brief Verify input and output tensor are of certain min dimensions. */
71static bool VerifyTensorDimensions(const arm::app::Model& model);
alexander3c798932021-03-26 21:42:19 +000072
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010073void main_loop()
alexander3c798932021-03-26 21:42:19 +000074{
75 /* Model wrapper objects. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000076 arm::app::MicroNetKwsModel kwsModel;
alexander3c798932021-03-26 21:42:19 +000077 arm::app::Wav2LetterModel asrModel;
78
79 /* Load the models. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010080 if (!kwsModel.Init(arm::app::tensorArena,
81 sizeof(arm::app::tensorArena),
82 arm::app::kws::GetModelPointer(),
83 arm::app::kws::GetModelLen())) {
alexander3c798932021-03-26 21:42:19 +000084 printf_err("Failed to initialise KWS model\n");
85 return;
86 }
87
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010088#if !defined(ARM_NPU)
89 /* If it is not a NPU build check if the model contains a NPU operator */
90 if (kwsModel.ContainsEthosUOperator()) {
91 printf_err("No driver support for Ethos-U operator found in the KWS model.\n");
92 return;
93 }
94#endif /* ARM_NPU */
95
alexander3c798932021-03-26 21:42:19 +000096 /* Initialise the asr model using the same allocator from KWS
97 * to re-use the tensor arena. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010098 if (!asrModel.Init(arm::app::tensorArena,
99 sizeof(arm::app::tensorArena),
100 arm::app::asr::GetModelPointer(),
101 arm::app::asr::GetModelLen(),
102 kwsModel.GetAllocator())) {
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000103 printf_err("Failed to initialise ASR model\n");
alexander3c798932021-03-26 21:42:19 +0000104 return;
Richard Burton4e002792022-05-04 09:45:02 +0100105 } else if (!VerifyTensorDimensions(asrModel)) {
106 printf_err("Model's input or output dimension verification failed\n");
107 return;
alexander3c798932021-03-26 21:42:19 +0000108 }
109
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +0100110#if !defined(ARM_NPU)
111 /* If it is not a NPU build check if the model contains a NPU operator */
112 if (asrModel.ContainsEthosUOperator()) {
113 printf_err("No driver support for Ethos-U operator found in the ASR model.\n");
114 return;
115 }
116#endif /* ARM_NPU */
117
alexander3c798932021-03-26 21:42:19 +0000118 /* Instantiate application context. */
119 arm::app::ApplicationContext caseContext;
120
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100121 arm::app::Profiler profiler{"kws_asr"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100122 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
Richard Burton4e002792022-05-04 09:45:02 +0100123 caseContext.Set<arm::app::Model&>("kwsModel", kwsModel);
124 caseContext.Set<arm::app::Model&>("asrModel", asrModel);
alexander3c798932021-03-26 21:42:19 +0000125 caseContext.Set<uint32_t>("clipIndex", 0);
126 caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */
Richard Burton4e002792022-05-04 09:45:02 +0100127 caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);
128 caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride);
129 caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000130 caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc);
131 caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);
132
Richard Burton4e002792022-05-04 09:45:02 +0100133 caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength);
134 caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride);
135 caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000136
137 KwsClassifier kwsClassifier; /* Classifier wrapper object. */
138 arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */
Richard Burton4e002792022-05-04 09:45:02 +0100139 caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier);
140 caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
alexander3c798932021-03-26 21:42:19 +0000141
142 std::vector<std::string> asrLabels;
143 arm::app::asr::GetLabelsVector(asrLabels);
144 std::vector<std::string> kwsLabels;
145 arm::app::kws::GetLabelsVector(kwsLabels);
Richard Burton4e002792022-05-04 09:45:02 +0100146 caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels);
147 caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels);
alexander3c798932021-03-26 21:42:19 +0000148
Liam Barryb5b32d32021-12-30 11:35:00 +0000149 /* KWS keyword that triggers ASR and associated checks */
Richard Burton4e002792022-05-04 09:45:02 +0100150 std::string triggerKeyword = std::string("no");
Liam Barryb5b32d32021-12-30 11:35:00 +0000151 if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
Richard Burton4e002792022-05-04 09:45:02 +0100152 caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword);
Liam Barryb5b32d32021-12-30 11:35:00 +0000153 }
154 else {
155 printf_err("Selected trigger keyword not found in labels file\n");
156 return;
157 }
alexander3c798932021-03-26 21:42:19 +0000158
159 /* Loop. */
160 bool executionSuccessful = true;
161 constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
162
163 /* Loop. */
164 do {
165 int menuOption = MENU_OPT_RUN_INF_NEXT;
166 if (bUseMenu) {
167 DisplayMenu();
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100168 menuOption = arm::app::ReadUserInputAsInt();
alexander3c798932021-03-26 21:42:19 +0000169 printf("\n");
170 }
171 switch (menuOption) {
172 case MENU_OPT_RUN_INF_NEXT:
173 executionSuccessful = ClassifyAudioHandler(
174 caseContext,
175 caseContext.Get<uint32_t>("clipIndex"),
176 false);
177 break;
178 case MENU_OPT_RUN_INF_CHOSEN: {
179 printf(" Enter the audio clip index [0, %d]: ",
180 NUMBER_OF_FILES-1);
Isabella Gottardi79d41542021-10-20 15:52:32 +0100181 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +0000182 auto clipIndex = static_cast<uint32_t>(
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100183 arm::app::ReadUserInputAsInt());
alexander3c798932021-03-26 21:42:19 +0000184 executionSuccessful = ClassifyAudioHandler(caseContext,
185 clipIndex,
186 false);
187 break;
188 }
189 case MENU_OPT_RUN_INF_ALL:
190 executionSuccessful = ClassifyAudioHandler(
191 caseContext,
192 caseContext.Get<uint32_t>("clipIndex"),
193 true);
194 break;
195 case MENU_OPT_SHOW_MODEL_INFO:
196 executionSuccessful = kwsModel.ShowModelInfoHandler();
197 executionSuccessful = asrModel.ShowModelInfoHandler();
198 break;
199 case MENU_OPT_LIST_AUDIO_CLIPS:
200 executionSuccessful = ListFilesHandler(caseContext);
201 break;
202 default:
203 printf("Incorrect choice, try again.");
204 break;
205 }
206 } while (executionSuccessful && bUseMenu);
207 info("Main loop terminated.\n");
208}
209
Richard Burton4e002792022-05-04 09:45:02 +0100210static bool VerifyTensorDimensions(const arm::app::Model& model)
alexander3c798932021-03-26 21:42:19 +0000211{
Richard Burton4e002792022-05-04 09:45:02 +0100212 /* Populate tensor related parameters. */
alexander3c798932021-03-26 21:42:19 +0000213 TfLiteTensor* inputTensor = model.GetInputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +0100214 if (!inputTensor->dims) {
215 printf_err("Invalid input tensor dims\n");
216 return false;
217 } else if (inputTensor->dims->size < 3) {
218 printf_err("Input tensor dimension should be >= 3\n");
219 return false;
alexander3c798932021-03-26 21:42:19 +0000220 }
221
222 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +0100223 if (!outputTensor->dims) {
224 printf_err("Invalid output tensor dims\n");
225 return false;
226 } else if (outputTensor->dims->size < 3) {
227 printf_err("Output tensor dimension should be >= 3\n");
228 return false;
229 }
alexander3c798932021-03-26 21:42:19 +0000230
Richard Burton4e002792022-05-04 09:45:02 +0100231 return true;
alexander3c798932021-03-26 21:42:19 +0000232}