MLECO-3077: Add ASR use case API

* Minor adjustments to doc strings in KWS
* Remove unused score threshold in KWS
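
The new pre/post-processing API is driven from the ASR handler roughly as
below (a minimal sketch using only the names introduced in this change;
model, classifier, labels, MFCC parameters and the sliding-window loop are
assumed to be set up as in UseCaseHandler.cc):

    /* Assumed: model, classifier, labels and MFCC parameters fetched from
     * the application context beforehand. */
    ASRPreProcess preProcess(model.GetInputTensor(0),
            Wav2LetterModel::ms_numMfccFeatures,
            inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
            mfccFrameLen, mfccFrameStride);

    std::vector<ClassificationResult> singleInfResult;
    const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
    ASRPostProcess postProcess(classifier, model.GetOutputTensor(0), labels,
            singleInfResult, outputCtxLen,
            Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx);

    UseCaseRunner runner(&preProcess, &postProcess, &model);

    /* For each window over the audio clip: */
    runner.PreProcess(inferenceWindow, inferenceWindowLen);
    runner.RunInference();
    postProcess.m_lastIteration = !audioDataSlider.HasNext();
    runner.PostProcess(); /* Decoded tokens are written to singleInfResult. */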

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 420f725..7fe959b 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -20,7 +20,6 @@
 #include "AsrClassifier.hpp"
 #include "Wav2LetterModel.hpp"
 #include "hal.h"
-#include "Wav2LetterMfcc.hpp"
 #include "AudioUtils.hpp"
 #include "ImageUtils.hpp"
 #include "UseCaseCommonUtils.hpp"
@@ -34,68 +33,63 @@
 namespace app {
 
     /**
-     * @brief           Presents inference results using the data presentation
-     *                  object.
-     * @param[in]       results     Vector of classification results to be displayed.
+     * @brief           Presents ASR inference results.
+     * @param[in]       results     Vector of ASR classification results to be displayed.
      * @return          true if successful, false otherwise.
      **/
-    static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results);
+    static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
 
-    /* Audio inference classification handler. */
+    /* ASR inference handler. */
     bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
-        constexpr uint32_t dataPsnTxtInfStartX = 20;
-        constexpr uint32_t dataPsnTxtInfStartY = 40;
-
-        hal_lcd_clear(COLOR_BLACK);
-
+        auto& model = ctx.Get<Model&>("model");
         auto& profiler = ctx.Get<Profiler&>("profiler");
-
+        auto mfccFrameLen = ctx.Get<uint32_t>("frameLength");
+        auto mfccFrameStride = ctx.Get<uint32_t>("frameStride");
+        auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+        auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
         /* If the request has a valid size, set the audio index. */
         if (clipIndex < NUMBER_OF_FILES) {
             if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
                 return false;
             }
         }
+        auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
+        constexpr uint32_t dataPsnTxtInfStartX = 20;
+        constexpr uint32_t dataPsnTxtInfStartY = 40;
 
-        /* Get model reference. */
-        auto& model = ctx.Get<Model&>("model");
         if (!model.IsInited()) {
             printf_err("Model is not initialised! Terminating processing.\n");
             return false;
         }
 
-        /* Get score threshold to be applied for the classifier (post-inference). */
-        auto scoreThreshold = ctx.Get<float>("scoreThreshold");
-
-        /* Get tensors. Dimensions of the tensor should have been verified by
+        /* Get input shape. Dimensions of the tensor should have been verified by
          * the callee. */
-        TfLiteTensor* inputTensor = model.GetInputTensor(0);
-        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-        const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
+        TfLiteIntArray* inputShape = model.GetInputShape(0);
 
-        /* Populate MFCC related parameters. */
-        auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength");
-        auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride");
-
-        /* Populate ASR inference context and inner lengths for input. */
-        auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
-        const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+        const uint32_t inputRowsSize = inputShape->data[Wav2LetterModel::ms_inputRowsIdx];
+        const uint32_t inputInnerLen = inputRowsSize - (2 * inputCtxLen);
 
         /* Audio data stride corresponds to inputInnerLen feature vectors. */
-        const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen);
-        const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride;
-        const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
+        const uint32_t audioDataWindowLen = (inputRowsSize - 1) * mfccFrameStride + (mfccFrameLen);
+        const uint32_t audioDataWindowStride = inputInnerLen * mfccFrameStride;
 
-        /* Get pre/post-processing objects. */
-        auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess");
-        auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess");
+        /* NOTE: This is only used for timestamp calculation. */
+        const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
 
-        /* Set default reduction axis for post-processing. */
-        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+        /* Set up pre- and post-processing objects. */
+        ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
+                inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
 
-        /* Audio clip start index. */
-        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
+        std::vector<ClassificationResult> singleInfResult;
+        const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
+        ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
+                model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+                singleInfResult, outputCtxLen,
+                Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+                );
+
+        UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
 
         /* Loop to process audio clips. */
         do {
@@ -109,44 +103,41 @@
             const uint32_t audioArrSize = get_audio_array_size(currentIndex);
 
             if (!audioArr) {
-                printf_err("Invalid audio array pointer\n");
+                printf_err("Invalid audio array pointer.\n");
                 return false;
             }
 
-            /* Audio clip must have enough samples to produce 1 MFCC feature. */
-            if (audioArrSize < mfccParamsWinLen) {
+            /* Audio clip needs enough samples to produce at least 1 MFCC feature. */
+            if (audioArrSize < mfccFrameLen) {
                 printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
-                    mfccParamsWinLen);
+                           mfccFrameLen);
                 return false;
             }
 
-            /* Initialise an audio slider. */
+            /* Create a sliding window through the whole audio clip. */
             auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
-                                        audioArr,
-                                        audioArrSize,
-                                        audioParamsWinLen,
-                                        audioParamsWinStride);
+                    audioArr, audioArrSize,
+                    audioDataWindowLen, audioDataWindowStride);
 
-            /* Declare a container for results. */
-            std::vector<arm::app::asr::AsrResult> results;
+            /* Declare a container for final results. */
+            std::vector<asr::AsrResult> finalResults;
 
             /* Display message on the LCD - inference running. */
             std::string str_inf{"Running inference... "};
-            hal_lcd_display_text(
-                                str_inf.c_str(), str_inf.size(),
-                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+            hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
 
             info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
                  get_filename(currentIndex));
 
-            size_t inferenceWindowLen = audioParamsWinLen;
+            size_t inferenceWindowLen = audioDataWindowLen;
 
             /* Start sliding through audio clip. */
             while (audioDataSlider.HasNext()) {
 
-                /* If not enough audio see how much can be sent for processing. */
+                /* If not enough audio, see how much can be sent for processing. */
                 size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
-                if (nextStartIndex + audioParamsWinLen > audioArrSize) {
+                if (nextStartIndex + audioDataWindowLen > audioArrSize) {
                     inferenceWindowLen = audioArrSize - nextStartIndex;
                 }
 
@@ -155,46 +146,40 @@
                 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                      static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
 
-                /* Calculate MFCCs, deltas and populate the input tensor. */
-                prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);
+                /* Run the pre-processing, inference and post-processing. */
+                runner.PreProcess(inferenceWindow, inferenceWindowLen);
 
-                /* Run inference over this audio clip sliding window. */
-                if (!RunInference(model, profiler)) {
+                profiler.StartProfiling("Inference");
+                if (!runner.RunInference()) {
+                    return false;
+                }
+                profiler.StopProfiling();
+
+                postProcess.m_lastIteration = !audioDataSlider.HasNext();
+                if (!runner.PostProcess()) {
                     return false;
                 }
 
-                /* Post-process. */
-                postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());
-
-                /* Get results. */
-                std::vector<ClassificationResult> classificationResult;
-                auto& classifier = ctx.Get<AsrClassifier&>("classifier");
-                classifier.GetClassificationResults(
-                            outputTensor, classificationResult,
-                            ctx.Get<std::vector<std::string>&>("labels"), 1);
-
-                results.emplace_back(asr::AsrResult(classificationResult,
-                                                    (audioDataSlider.Index() *
-                                                    audioParamsSecondsPerSample *
-                                                    audioParamsWinStride),
-                                                    audioDataSlider.Index(), scoreThreshold));
+                /* Add results from this window to our final results vector. */
+                finalResults.emplace_back(asr::AsrResult(singleInfResult,
+                        (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride),
+                        audioDataSlider.Index(), scoreThreshold));
 
 #if VERIFY_TEST_OUTPUT
-                arm::app::DumpTensor(outputTensor,
-                    outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+                TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+                armDumpTensor(outputTensor,
+                    outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */
-
-            }
+            } /* while (audioDataSlider.HasNext()) */
 
             /* Erase. */
             str_inf = std::string(str_inf.size(), ' ');
-            hal_lcd_display_text(
-                                str_inf.c_str(), str_inf.size(),
-                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+            hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
 
-            ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results);
+            ctx.Set<std::vector<asr::AsrResult>>("results", finalResults);
 
-            if (!PresentInferenceResult(results)) {
+            if (!PresentInferenceResult(finalResults)) {
                 return false;
             }
 
@@ -202,13 +187,13 @@
 
             IncrementAppCtxIfmIdx(ctx,"clipIndex");
 
-        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
+        } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
 
         return true;
     }
 
 
-    static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results)
+    static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
         constexpr uint32_t dataPsnTxtStartY1 = 60;
@@ -219,15 +204,15 @@
         info("Final results:\n");
         info("Total number of inferences: %zu\n", results.size());
         /* Results from multiple inferences should be combined before processing. */
-        std::vector<arm::app::ClassificationResult> combinedResults;
-        for (auto& result : results) {
+        std::vector<ClassificationResult> combinedResults;
+        for (const auto& result : results) {
             combinedResults.insert(combinedResults.end(),
                                    result.m_resultVec.begin(),
                                    result.m_resultVec.end());
         }
 
         /* Get each inference result string using the decoder. */
-        for (const auto & result : results) {
+        for (const auto& result : results) {
             std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
 
             info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
@@ -238,10 +223,9 @@
         /* Get the decoded result for the combined result. */
         std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
 
-        hal_lcd_display_text(
-                            finalResultStr.c_str(), finalResultStr.size(),
-                            dataPsnTxtStartX1, dataPsnTxtStartY1,
-                            allow_multiple_lines);
+        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+                dataPsnTxtStartX1, dataPsnTxtStartY1,
+                allow_multiple_lines);
 
         info("Complete recognition: %s\n", finalResultStr.c_str());
         return true;