blob: f1d97a04f6796eda31db933fa523e0d82c7f2eb6 [file] [log] [blame]
/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
alexander3c798932021-03-26 21:42:19 +000017#include "InputFiles.hpp" /* For input images. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
alexander3c798932021-03-26 21:42:19 +000019#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
20#include "Classifier.hpp" /* KWS classifier. */
21#include "AsrClassifier.hpp" /* ASR classifier. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000022#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */
alexander3c798932021-03-26 21:42:19 +000023#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
24#include "UseCaseCommonUtils.hpp" /* Utils functions. */
25#include "UseCaseHandler.hpp" /* Handlers for different user options. */
alexander31ae9f02022-02-10 16:15:54 +000026#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000027
28using KwsClassifier = arm::app::Classifier;
29
/**
 * Menu option codes. The numeric value of each enumerator is the number
 * printed next to the option in the menu and matched against the user's
 * input in the main loop.
 */
enum opcodes
{
    MENU_OPT_RUN_INF_NEXT     = 1, /* Run inference on the next vector. */
    MENU_OPT_RUN_INF_CHOSEN   = 2, /* Run inference on a user-provided vector index. */
    MENU_OPT_RUN_INF_ALL      = 3, /* Run inference on every vector. */
    MENU_OPT_SHOW_MODEL_INFO  = 4, /* Print model information. */
    MENU_OPT_LIST_AUDIO_CLIPS = 5  /* List the audio clips baked into the image. */
};
38
39static void DisplayMenu()
40{
Kshitij Sisodia3c8256d2021-05-24 16:12:40 +010041 printf("\n\n");
42 printf("User input required\n");
alexander3c798932021-03-26 21:42:19 +000043 printf("Enter option number from:\n\n");
44 printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT);
45 printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN);
46 printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL);
47 printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO);
48 printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS);
49 printf(" Choice: ");
George Gekov93e59512021-08-03 11:18:41 +010050 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +000051}
52
Richard Burton4e002792022-05-04 09:45:02 +010053/** @brief Verify input and output tensor are of certain min dimensions. */
54static bool VerifyTensorDimensions(const arm::app::Model& model);
alexander3c798932021-03-26 21:42:19 +000055
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010056void main_loop()
alexander3c798932021-03-26 21:42:19 +000057{
58 /* Model wrapper objects. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +000059 arm::app::MicroNetKwsModel kwsModel;
alexander3c798932021-03-26 21:42:19 +000060 arm::app::Wav2LetterModel asrModel;
61
62 /* Load the models. */
63 if (!kwsModel.Init()) {
64 printf_err("Failed to initialise KWS model\n");
65 return;
66 }
67
68 /* Initialise the asr model using the same allocator from KWS
69 * to re-use the tensor arena. */
70 if (!asrModel.Init(kwsModel.GetAllocator())) {
Kshitij Sisodia76a15802021-12-24 11:05:11 +000071 printf_err("Failed to initialise ASR model\n");
alexander3c798932021-03-26 21:42:19 +000072 return;
Richard Burton4e002792022-05-04 09:45:02 +010073 } else if (!VerifyTensorDimensions(asrModel)) {
74 printf_err("Model's input or output dimension verification failed\n");
75 return;
alexander3c798932021-03-26 21:42:19 +000076 }
77
alexander3c798932021-03-26 21:42:19 +000078 /* Instantiate application context. */
79 arm::app::ApplicationContext caseContext;
80
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010081 arm::app::Profiler profiler{"kws_asr"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +010082 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
Richard Burton4e002792022-05-04 09:45:02 +010083 caseContext.Set<arm::app::Model&>("kwsModel", kwsModel);
84 caseContext.Set<arm::app::Model&>("asrModel", asrModel);
alexander3c798932021-03-26 21:42:19 +000085 caseContext.Set<uint32_t>("clipIndex", 0);
86 caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */
Richard Burton4e002792022-05-04 09:45:02 +010087 caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);
88 caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride);
89 caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +000090 caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc);
91 caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);
92
Richard Burton4e002792022-05-04 09:45:02 +010093 caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength);
94 caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride);
95 caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */
alexander3c798932021-03-26 21:42:19 +000096
97 KwsClassifier kwsClassifier; /* Classifier wrapper object. */
98 arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. */
Richard Burton4e002792022-05-04 09:45:02 +010099 caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier);
100 caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
alexander3c798932021-03-26 21:42:19 +0000101
102 std::vector<std::string> asrLabels;
103 arm::app::asr::GetLabelsVector(asrLabels);
104 std::vector<std::string> kwsLabels;
105 arm::app::kws::GetLabelsVector(kwsLabels);
Richard Burton4e002792022-05-04 09:45:02 +0100106 caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels);
107 caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels);
alexander3c798932021-03-26 21:42:19 +0000108
Liam Barryb5b32d32021-12-30 11:35:00 +0000109 /* KWS keyword that triggers ASR and associated checks */
Richard Burton4e002792022-05-04 09:45:02 +0100110 std::string triggerKeyword = std::string("no");
Liam Barryb5b32d32021-12-30 11:35:00 +0000111 if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
Richard Burton4e002792022-05-04 09:45:02 +0100112 caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword);
Liam Barryb5b32d32021-12-30 11:35:00 +0000113 }
114 else {
115 printf_err("Selected trigger keyword not found in labels file\n");
116 return;
117 }
alexander3c798932021-03-26 21:42:19 +0000118
119 /* Loop. */
120 bool executionSuccessful = true;
121 constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
122
123 /* Loop. */
124 do {
125 int menuOption = MENU_OPT_RUN_INF_NEXT;
126 if (bUseMenu) {
127 DisplayMenu();
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100128 menuOption = arm::app::ReadUserInputAsInt();
alexander3c798932021-03-26 21:42:19 +0000129 printf("\n");
130 }
131 switch (menuOption) {
132 case MENU_OPT_RUN_INF_NEXT:
133 executionSuccessful = ClassifyAudioHandler(
134 caseContext,
135 caseContext.Get<uint32_t>("clipIndex"),
136 false);
137 break;
138 case MENU_OPT_RUN_INF_CHOSEN: {
139 printf(" Enter the audio clip index [0, %d]: ",
140 NUMBER_OF_FILES-1);
Isabella Gottardi79d41542021-10-20 15:52:32 +0100141 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +0000142 auto clipIndex = static_cast<uint32_t>(
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100143 arm::app::ReadUserInputAsInt());
alexander3c798932021-03-26 21:42:19 +0000144 executionSuccessful = ClassifyAudioHandler(caseContext,
145 clipIndex,
146 false);
147 break;
148 }
149 case MENU_OPT_RUN_INF_ALL:
150 executionSuccessful = ClassifyAudioHandler(
151 caseContext,
152 caseContext.Get<uint32_t>("clipIndex"),
153 true);
154 break;
155 case MENU_OPT_SHOW_MODEL_INFO:
156 executionSuccessful = kwsModel.ShowModelInfoHandler();
157 executionSuccessful = asrModel.ShowModelInfoHandler();
158 break;
159 case MENU_OPT_LIST_AUDIO_CLIPS:
160 executionSuccessful = ListFilesHandler(caseContext);
161 break;
162 default:
163 printf("Incorrect choice, try again.");
164 break;
165 }
166 } while (executionSuccessful && bUseMenu);
167 info("Main loop terminated.\n");
168}
169
Richard Burton4e002792022-05-04 09:45:02 +0100170static bool VerifyTensorDimensions(const arm::app::Model& model)
alexander3c798932021-03-26 21:42:19 +0000171{
Richard Burton4e002792022-05-04 09:45:02 +0100172 /* Populate tensor related parameters. */
alexander3c798932021-03-26 21:42:19 +0000173 TfLiteTensor* inputTensor = model.GetInputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +0100174 if (!inputTensor->dims) {
175 printf_err("Invalid input tensor dims\n");
176 return false;
177 } else if (inputTensor->dims->size < 3) {
178 printf_err("Input tensor dimension should be >= 3\n");
179 return false;
alexander3c798932021-03-26 21:42:19 +0000180 }
181
182 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +0100183 if (!outputTensor->dims) {
184 printf_err("Invalid output tensor dims\n");
185 return false;
186 } else if (outputTensor->dims->size < 3) {
187 printf_err("Output tensor dimension should be >= 3\n");
188 return false;
189 }
alexander3c798932021-03-26 21:42:19 +0000190
Richard Burton4e002792022-05-04 09:45:02 +0100191 return true;
alexander3c798932021-03-26 21:42:19 +0000192}