blob: a1a95402331c3c7c89dd7199b697cc907a5f8b96 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
alexander3c798932021-03-26 21:42:19 +000017#include "Labels.hpp" /* For label strings. */
18#include "UseCaseHandler.hpp" /* Handlers for different user options. */
19#include "Wav2LetterModel.hpp" /* Model class for running inference. */
20#include "UseCaseCommonUtils.hpp" /* Utils functions. */
21#include "AsrClassifier.hpp" /* Classifier. */
22#include "InputFiles.hpp" /* Generated audio clip header. */
alexander31ae9f02022-02-10 16:15:54 +000023#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000024
/** Option numbers the user types at the interactive menu prompt. */
enum opcodes
{
    MENU_OPT_RUN_INF_NEXT     = 1, /* Classify the next clip in sequence. */
    MENU_OPT_RUN_INF_CHOSEN   = 2, /* Classify the clip at an index the user enters. */
    MENU_OPT_RUN_INF_ALL      = 3, /* Classify every clip. */
    MENU_OPT_SHOW_MODEL_INFO  = 4, /* Print information about the model. */
    MENU_OPT_LIST_AUDIO_CLIPS = 5  /* List the baked-in audio clips. */
};
33
34static void DisplayMenu()
35{
Kshitij Sisodia3c8256d2021-05-24 16:12:40 +010036 printf("\n\n");
37 printf("User input required\n");
alexander3c798932021-03-26 21:42:19 +000038 printf("Enter option number from:\n\n");
39 printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT);
40 printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN);
41 printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL);
42 printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO);
43 printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS);
44 printf(" Choice: ");
George Gekov93e59512021-08-03 11:18:41 +010045 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +000046}
47
Richard Burtonc2911442022-04-22 09:08:21 +010048/** @brief Verify input and output tensor are of certain min dimensions. */
alexander3c798932021-03-26 21:42:19 +000049static bool VerifyTensorDimensions(const arm::app::Model& model);
50
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010051void main_loop()
alexander3c798932021-03-26 21:42:19 +000052{
53 arm::app::Wav2LetterModel model; /* Model wrapper object. */
54
55 /* Load the model. */
56 if (!model.Init()) {
57 printf_err("Failed to initialise model\n");
58 return;
59 } else if (!VerifyTensorDimensions(model)) {
60 printf_err("Model's input or output dimension verification failed\n");
61 return;
62 }
63
alexander3c798932021-03-26 21:42:19 +000064 /* Instantiate application context. */
65 arm::app::ApplicationContext caseContext;
66 std::vector <std::string> labels;
67 GetLabelsVector(labels);
68 arm::app::AsrClassifier classifier; /* Classifier wrapper object. */
69
Kshitij Sisodia4cc40212022-04-08 09:54:53 +010070 arm::app::Profiler profiler{"asr"};
Isabella Gottardi8df12f32021-04-07 17:15:31 +010071 caseContext.Set<arm::app::Profiler&>("profiler", profiler);
alexander3c798932021-03-26 21:42:19 +000072 caseContext.Set<arm::app::Model&>("model", model);
73 caseContext.Set<uint32_t>("clipIndex", 0);
74 caseContext.Set<uint32_t>("frameLength", g_FrameLength);
75 caseContext.Set<uint32_t>("frameStride", g_FrameStride);
76 caseContext.Set<float>("scoreThreshold", g_ScoreThreshold); /* Score threshold. */
77 caseContext.Set<uint32_t>("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */
78 caseContext.Set<const std::vector <std::string>&>("labels", labels);
79 caseContext.Set<arm::app::AsrClassifier&>("classifier", classifier);
alexander3c798932021-03-26 21:42:19 +000080
81 bool executionSuccessful = true;
82 constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
83
84 /* Loop. */
85 do {
86 int menuOption = MENU_OPT_RUN_INF_NEXT;
87 if (bUseMenu) {
88 DisplayMenu();
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010089 menuOption = arm::app::ReadUserInputAsInt();
alexander3c798932021-03-26 21:42:19 +000090 printf("\n");
91 }
92 switch (menuOption) {
93 case MENU_OPT_RUN_INF_NEXT:
94 executionSuccessful = ClassifyAudioHandler(
95 caseContext,
96 caseContext.Get<uint32_t>("clipIndex"),
97 false);
98 break;
99 case MENU_OPT_RUN_INF_CHOSEN: {
100 printf(" Enter the audio clip index [0, %d]: ",
101 NUMBER_OF_FILES-1);
Isabella Gottardi79d41542021-10-20 15:52:32 +0100102 fflush(stdout);
alexander3c798932021-03-26 21:42:19 +0000103 auto clipIndex = static_cast<uint32_t>(
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100104 arm::app::ReadUserInputAsInt());
alexander3c798932021-03-26 21:42:19 +0000105 executionSuccessful = ClassifyAudioHandler(caseContext,
106 clipIndex,
107 false);
108 break;
109 }
110 case MENU_OPT_RUN_INF_ALL:
111 executionSuccessful = ClassifyAudioHandler(
112 caseContext,
113 caseContext.Get<uint32_t>("clipIndex"),
114 true);
115 break;
116 case MENU_OPT_SHOW_MODEL_INFO:
117 executionSuccessful = model.ShowModelInfoHandler();
118 break;
119 case MENU_OPT_LIST_AUDIO_CLIPS:
120 executionSuccessful = ListFilesHandler(caseContext);
121 break;
122 default:
123 printf("Incorrect choice, try again.");
124 break;
125 }
126 } while (executionSuccessful && bUseMenu);
127 info("Main loop terminated.\n");
128}
129
130static bool VerifyTensorDimensions(const arm::app::Model& model)
131{
132 /* Populate tensor related parameters. */
133 TfLiteTensor* inputTensor = model.GetInputTensor(0);
134 if (!inputTensor->dims) {
135 printf_err("Invalid input tensor dims\n");
136 return false;
137 } else if (inputTensor->dims->size < 3) {
138 printf_err("Input tensor dimension should be >= 3\n");
139 return false;
140 }
141
142 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
143 if (!outputTensor->dims) {
144 printf_err("Invalid output tensor dims\n");
145 return false;
146 } else if (outputTensor->dims->size < 3) {
147 printf_err("Output tensor dimension should be >= 3\n");
148 return false;
149 }
150
151 return true;
152}