blob: a59a0b0e6d67bda903c017e2646441dce7023aea [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Richard Burtonf32a86a2022-11-15 11:46:11 +00002 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
alexander3c798932021-03-26 21:42:19 +000017#include "InputFiles.hpp" /* For input images. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
alexander3c798932021-03-26 21:42:19 +000019#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
Richard Burtonec5e99b2022-10-05 11:00:37 +010020#include "KwsClassifier.hpp" /* KWS classifier. */
alexander3c798932021-03-26 21:42:19 +000021#include "AsrClassifier.hpp" /* ASR classifier. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000022#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */
alexander3c798932021-03-26 21:42:19 +000023#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
24#include "UseCaseCommonUtils.hpp" /* Utils functions. */
25#include "UseCaseHandler.hpp" /* Handlers for different user options. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010026#include "log_macros.h" /* Logging functions */
27#include "BufAttributes.hpp" /* Buffer attributes to be applied */
28
29namespace arm {
30namespace app {
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010031
32 namespace asr {
33 extern uint8_t* GetModelPointer();
34 extern size_t GetModelLen();
Liam Barry213a5432022-05-09 17:06:19 +010035 } /* namespace asr */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010036
37 namespace kws {
38 extern uint8_t* GetModelPointer();
39 extern size_t GetModelLen();
Liam Barry213a5432022-05-09 17:06:19 +010040 } /* namespace kws */
41 static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010042} /* namespace app */
43} /* namespace arm */
alexander3c798932021-03-26 21:42:19 +000044
/**
 * @brief Options offered by the interactive menu.
 *
 * The numeric value of each enumerator is the number the user types to pick
 * that option (it is compared against the integer read from user input).
 */
enum opcodes
{
    MENU_OPT_RUN_INF_NEXT     = 1, /* Run on next vector. */
    MENU_OPT_RUN_INF_CHOSEN   = 2, /* Run on a user provided vector index. */
    MENU_OPT_RUN_INF_ALL      = 3, /* Run inference on all. */
    MENU_OPT_SHOW_MODEL_INFO  = 4, /* Show model info. */
    MENU_OPT_LIST_AUDIO_CLIPS = 5  /* List the current baked audio clips. */
};
53
54static void DisplayMenu()
55{
Kshitij Sisodia3c8256d2021-05-24 16:12:40 +010056 printf("\n\n");
57 printf("User input required\n");
alexander3c798932021-03-26 21:42:19 +000058 printf("Enter option number from:\n\n");
59 printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT);
60 printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN);
61 printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL);
62 printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO);
63 printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS);
64 printf(" Choice: ");
George Gekov93e59512021-08-03 11:18:41 +010065 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +000066}
67
Richard Burton4e002792022-05-04 09:45:02 +010068/** @brief Verify input and output tensor are of certain min dimensions. */
69static bool VerifyTensorDimensions(const arm::app::Model& model);
alexander3c798932021-03-26 21:42:19 +000070
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010071void main_loop()
alexander3c798932021-03-26 21:42:19 +000072{
73 /* Model wrapper objects. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000074 arm::app::MicroNetKwsModel kwsModel;
alexander3c798932021-03-26 21:42:19 +000075 arm::app::Wav2LetterModel asrModel;
76
77 /* Load the models. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010078 if (!kwsModel.Init(arm::app::tensorArena,
79 sizeof(arm::app::tensorArena),
80 arm::app::kws::GetModelPointer(),
81 arm::app::kws::GetModelLen())) {
alexander3c798932021-03-26 21:42:19 +000082 printf_err("Failed to initialise KWS model\n");
83 return;
84 }
85
86 /* Initialise the asr model using the same allocator from KWS
87 * to re-use the tensor arena. */
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010088 if (!asrModel.Init(arm::app::tensorArena,
89 sizeof(arm::app::tensorArena),
90 arm::app::asr::GetModelPointer(),
91 arm::app::asr::GetModelLen(),
92 kwsModel.GetAllocator())) {
Kshitij Sisodia76a15802021-12-24 11:05:11 +000093 printf_err("Failed to initialise ASR model\n");
alexander3c798932021-03-26 21:42:19 +000094 return;
Richard Burton4e002792022-05-04 09:45:02 +010095 } else if (!VerifyTensorDimensions(asrModel)) {
96 printf_err("Model's input or output dimension verification failed\n");
97 return;
alexander3c798932021-03-26 21:42:19 +000098 }
99
alexander3c798932021-03-26 21:42:19 +0000100 /* Instantiate application context. */
101 arm::app::ApplicationContext caseContext;
102
Kshitij Sisodia4cc40212022-04-08 09:54:53 +0100103 arm::app::Profiler profiler{"kws_asr"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100104 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
Richard Burton4e002792022-05-04 09:45:02 +0100105 caseContext.Set<arm::app::Model&>("kwsModel", kwsModel);
106 caseContext.Set<arm::app::Model&>("asrModel", asrModel);
alexander3c798932021-03-26 21:42:19 +0000107 caseContext.Set<uint32_t>("clipIndex", 0);
108 caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */
Richard Burton4e002792022-05-04 09:45:02 +0100109 caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);
110 caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride);
111 caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000112 caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc);
113 caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);
114
Richard Burton4e002792022-05-04 09:45:02 +0100115 caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength);
116 caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride);
117 caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +0000118
Richard Burtonec5e99b2022-10-05 11:00:37 +0100119 arm::app::KwsClassifier kwsClassifier; /* Classifier wrapper object. */
alexander3c798932021-03-26 21:42:19 +0000120 arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */
Richard Burtonec5e99b2022-10-05 11:00:37 +0100121 caseContext.Set<arm::app::KwsClassifier&>("kwsClassifier", kwsClassifier);
Richard Burton4e002792022-05-04 09:45:02 +0100122 caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
alexander3c798932021-03-26 21:42:19 +0000123
124 std::vector<std::string> asrLabels;
125 arm::app::asr::GetLabelsVector(asrLabels);
126 std::vector<std::string> kwsLabels;
127 arm::app::kws::GetLabelsVector(kwsLabels);
Richard Burton4e002792022-05-04 09:45:02 +0100128 caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels);
129 caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels);
alexander3c798932021-03-26 21:42:19 +0000130
Liam Barryb5b32d32021-12-30 11:35:00 +0000131 /* KWS keyword that triggers ASR and associated checks */
Richard Burton4e002792022-05-04 09:45:02 +0100132 std::string triggerKeyword = std::string("no");
Liam Barryb5b32d32021-12-30 11:35:00 +0000133 if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
Richard Burton4e002792022-05-04 09:45:02 +0100134 caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword);
Liam Barryb5b32d32021-12-30 11:35:00 +0000135 }
136 else {
137 printf_err("Selected trigger keyword not found in labels file\n");
138 return;
139 }
alexander3c798932021-03-26 21:42:19 +0000140
141 /* Loop. */
142 bool executionSuccessful = true;
143 constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
144
145 /* Loop. */
146 do {
147 int menuOption = MENU_OPT_RUN_INF_NEXT;
148 if (bUseMenu) {
149 DisplayMenu();
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100150 menuOption = arm::app::ReadUserInputAsInt();
alexander3c798932021-03-26 21:42:19 +0000151 printf("\n");
152 }
153 switch (menuOption) {
154 case MENU_OPT_RUN_INF_NEXT:
155 executionSuccessful = ClassifyAudioHandler(
156 caseContext,
157 caseContext.Get<uint32_t>("clipIndex"),
158 false);
159 break;
160 case MENU_OPT_RUN_INF_CHOSEN: {
161 printf(" Enter the audio clip index [0, %d]: ",
162 NUMBER_OF_FILES-1);
Isabella Gottardi79d41542021-10-20 15:52:32 +0100163 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +0000164 auto clipIndex = static_cast<uint32_t>(
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100165 arm::app::ReadUserInputAsInt());
alexander3c798932021-03-26 21:42:19 +0000166 executionSuccessful = ClassifyAudioHandler(caseContext,
167 clipIndex,
168 false);
169 break;
170 }
171 case MENU_OPT_RUN_INF_ALL:
172 executionSuccessful = ClassifyAudioHandler(
173 caseContext,
174 caseContext.Get<uint32_t>("clipIndex"),
175 true);
176 break;
177 case MENU_OPT_SHOW_MODEL_INFO:
178 executionSuccessful = kwsModel.ShowModelInfoHandler();
179 executionSuccessful = asrModel.ShowModelInfoHandler();
180 break;
181 case MENU_OPT_LIST_AUDIO_CLIPS:
182 executionSuccessful = ListFilesHandler(caseContext);
183 break;
184 default:
185 printf("Incorrect choice, try again.");
186 break;
187 }
188 } while (executionSuccessful && bUseMenu);
189 info("Main loop terminated.\n");
190}
191
Richard Burton4e002792022-05-04 09:45:02 +0100192static bool VerifyTensorDimensions(const arm::app::Model& model)
alexander3c798932021-03-26 21:42:19 +0000193{
Richard Burton4e002792022-05-04 09:45:02 +0100194 /* Populate tensor related parameters. */
alexander3c798932021-03-26 21:42:19 +0000195 TfLiteTensor* inputTensor = model.GetInputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +0100196 if (!inputTensor->dims) {
197 printf_err("Invalid input tensor dims\n");
198 return false;
199 } else if (inputTensor->dims->size < 3) {
200 printf_err("Input tensor dimension should be >= 3\n");
201 return false;
alexander3c798932021-03-26 21:42:19 +0000202 }
203
204 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +0100205 if (!outputTensor->dims) {
206 printf_err("Invalid output tensor dims\n");
207 return false;
208 } else if (outputTensor->dims->size < 3) {
209 printf_err("Output tensor dimension should be >= 3\n");
210 return false;
211 }
alexander3c798932021-03-26 21:42:19 +0000212
Richard Burton4e002792022-05-04 09:45:02 +0100213 return true;
alexander3c798932021-03-26 21:42:19 +0000214}