/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "hal.h" /* Brings in platform definitions. */
18#include "InputFiles.hpp" /* For input images. */
19#include "Labels_dscnn.hpp" /* For DS-CNN label strings. */
20#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
21#include "Classifier.hpp" /* KWS classifier. */
22#include "AsrClassifier.hpp" /* ASR classifier. */
23#include "DsCnnModel.hpp" /* KWS model class for running inference. */
24#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
25#include "UseCaseCommonUtils.hpp" /* Utils functions. */
26#include "UseCaseHandler.hpp" /* Handlers for different user options. */
27#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */
28#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */
29
using KwsClassifier = arm::app::Classifier;

enum opcodes
{
    MENU_OPT_RUN_INF_NEXT = 1,  /* Run on next vector. */
    MENU_OPT_RUN_INF_CHOSEN,    /* Run on a user provided vector index. */
    MENU_OPT_RUN_INF_ALL,       /* Run inference on all. */
    MENU_OPT_SHOW_MODEL_INFO,   /* Show model info. */
    MENU_OPT_LIST_AUDIO_CLIPS   /* List the current baked audio clips. */
};

static void DisplayMenu()
{
    printf("\n\nUser input required\n");
    printf("Enter option number from:\n\n");
    printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT);
    printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN);
    printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL);
    printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO);
    printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS);
    printf(" Choice: ");
}

/** @brief Gets the number of MFCC features for a single window. */
static uint32_t GetNumMfccFeatures(const arm::app::Model& model);

/** @brief Gets the number of MFCC feature vectors to be computed. */
static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);

/** @brief Gets the output context length (left and right) for post-processing. */
static uint32_t GetOutputContextLen(const arm::app::Model& model,
                                    uint32_t inputCtxLen);

/** @brief Gets the output inner length for post-processing. */
static uint32_t GetOutputInnerLen(const arm::app::Model& model,
                                  uint32_t outputCtxLen);

void main_loop(hal_platform& platform)
{
    /* Model wrapper objects. */
    arm::app::DsCnnModel kwsModel;
    arm::app::Wav2LetterModel asrModel;

    /* Load the models. */
    if (!kwsModel.Init()) {
        printf_err("Failed to initialise KWS model\n");
        return;
    }

    /* Initialise the ASR model using the same allocator from KWS
     * to re-use the tensor arena. */
    if (!asrModel.Init(kwsModel.GetAllocator())) {
        printf_err("Failed to initialise ASR model\n");
        return;
    }

    /* Initialise ASR pre-processing. */
    arm::app::audio::asr::Preprocess prep(
        GetNumMfccFeatures(asrModel),
        arm::app::asr::g_FrameLength,
        arm::app::asr::g_FrameStride,
        GetNumMfccFeatureVectors(asrModel));

    /* Initialise ASR post-processing. */
    const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen);
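    /* Index of the CTC blank token in the Wav2Letter output vocabulary
     * (assumed here to be the last of the model's output classes). */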
    const uint32_t blankTokenIdx = 28;
    arm::app::audio::asr::Postprocess postp(
        outputCtxLen,
        GetOutputInnerLen(asrModel, outputCtxLen),
        blankTokenIdx);

    /* Instantiate application context. */
    arm::app::ApplicationContext caseContext;

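    /* The use-case handlers look up the objects registered below by these
     * string keys, so the names act as a simple contract between this file
     * and the handlers. */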
    caseContext.Set<hal_platform&>("platform", platform);
    caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel);
    caseContext.Set<arm::app::Model&>("asrmodel", asrModel);
    caseContext.Set<uint32_t>("clipIndex", 0);
    caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen);  /* Left and right context length (MFCC feat vectors). */
    caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength);
    caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride);
    caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold);  /* Normalised score threshold. */
    caseContext.Set<uint32_t>("kwsNumMfcc", arm::app::kws::g_NumMfcc);
    caseContext.Set<uint32_t>("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);

    caseContext.Set<int>("asrframeLength", arm::app::asr::g_FrameLength);
    caseContext.Set<int>("asrframeStride", arm::app::asr::g_FrameStride);
    caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold);  /* Normalised score threshold. */

    KwsClassifier kwsClassifier;            /* Classifier wrapper object. */
    arm::app::AsrClassifier asrClassifier;  /* Classifier wrapper object. */
    caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier);
    caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier);

    caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
    caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);

    std::vector<std::string> asrLabels;
    arm::app::asr::GetLabelsVector(asrLabels);
    std::vector<std::string> kwsLabels;
    arm::app::kws::GetLabelsVector(kwsLabels);
    caseContext.Set<const std::vector<std::string>&>("asrlabels", asrLabels);
    caseContext.Set<const std::vector<std::string>&>("kwslabels", kwsLabels);

    /* Index of the KWS output class (keyword) that triggers the ASR stage;
     * which keyword this maps to depends on the order of the generated KWS label list. */
    caseContext.Set<uint32_t>("keywordindex", 2);

    bool executionSuccessful = true;

    /* Use the interactive menu only when more than one audio clip is baked in;
     * otherwise the loop below runs a single pass. */
    constexpr bool bUseMenu = NUMBER_OF_FILES > 1;

    /* Main program loop. */
    do {
        int menuOption = MENU_OPT_RUN_INF_NEXT;
        if (bUseMenu) {
            DisplayMenu();
            menuOption = arm::app::ReadUserInputAsInt(platform);
            printf("\n");
        }
        switch (menuOption) {
            case MENU_OPT_RUN_INF_NEXT:
                executionSuccessful = ClassifyAudioHandler(
                    caseContext,
                    caseContext.Get<uint32_t>("clipIndex"),
                    false);
                break;
            case MENU_OPT_RUN_INF_CHOSEN: {
                printf(" Enter the audio clip index [0, %d]: ",
                       NUMBER_OF_FILES - 1);
                auto clipIndex = static_cast<uint32_t>(
                    arm::app::ReadUserInputAsInt(platform));
                executionSuccessful = ClassifyAudioHandler(caseContext,
                                                           clipIndex,
                                                           false);
                break;
            }
            case MENU_OPT_RUN_INF_ALL:
                executionSuccessful = ClassifyAudioHandler(
                    caseContext,
                    caseContext.Get<uint32_t>("clipIndex"),
                    true);
                break;
            case MENU_OPT_SHOW_MODEL_INFO:
                /* Show both models' info; report failure if either call fails. */
                executionSuccessful = kwsModel.ShowModelInfoHandler();
                executionSuccessful = asrModel.ShowModelInfoHandler() && executionSuccessful;
                break;
            case MENU_OPT_LIST_AUDIO_CLIPS:
                executionSuccessful = ListFilesHandler(caseContext);
                break;
            default:
180 printf("Incorrect choice, try again.");
181 break;
182 }
    } while (executionSuccessful && bUseMenu);
    info("Main loop terminated.\n");
}

static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
{
    TfLiteTensor* inputTensor = model.GetInputTensor(0);
    const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
    if (0 != inputCols % 3) {
        printf_err("Number of input columns is not a multiple of 3\n");
    }
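    /* The input feature vector is expected to pack the MFCC coefficients together
     * with their first and second derivatives, hence the division by 3. */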
    return std::max(inputCols / 3, 0);
}

static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
{
    TfLiteTensor* inputTensor = model.GetInputTensor(0);
    const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
    return std::max(inputRows, 0);
}

static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
{
    const uint32_t inputRows = GetNumMfccFeatureVectors(model);
    const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
    constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;

    /* Check to make sure that the input tensor supports the above context and inner lengths. */
    if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
        printf_err("Input rows not compatible with ctx of %u\n",
                   inputCtxLen);
        return 0;
    }

    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
    const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);

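    /* Scale the input context length by the input-to-output row (time step)
     * ratio to get the equivalent context length at the network output. */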
    const float tensorColRatio = static_cast<float>(inputRows) /
                                 static_cast<float>(outputRows);

    return std::round(static_cast<float>(inputCtxLen) / tensorColRatio);
}

static uint32_t GetOutputInnerLen(const arm::app::Model& model,
                                  const uint32_t outputCtxLen)
{
    constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
    const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
    return (outputRows - (2 * outputCtxLen));
233}