MLECO-3075: Add KWS use case API

Removed some of the templates for feature calculation that we are unlikely to ever use.
We might be able to refactor the feature caching and feature calculator code in the future
to better integrate it with the PreProcess API.

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: Ic0c0c581c71e2553d41ff72cd1ed3b3efa64fa92
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index e04cefc..350d34b 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -20,15 +20,14 @@
 #include "Classifier.hpp"
 #include "MicroNetKwsModel.hpp"
 #include "hal.h"
-#include "MicroNetKwsMfcc.hpp"
 #include "AudioUtils.hpp"
 #include "ImageUtils.hpp"
 #include "UseCaseCommonUtils.hpp"
 #include "KwsResult.hpp"
 #include "log_macros.h"
+#include "KwsProcessing.hpp"
 
 #include <vector>
-#include <functional>
 
 using KwsClassifier = arm::app::Classifier;
 
@@ -37,36 +36,27 @@
 
 
     /**
-     * @brief           Presents inference results using the data presentation
-     *                  object.
-     * @param[in]       results     Vector of classification results to be displayed.
+     * @brief           Presents KWS inference results.
+     * @param[in]       results     Vector of KWS classification results to be displayed.
      * @return          true if successful, false otherwise.
      **/
     static bool PresentInferenceResult(const std::vector<arm::app::kws::KwsResult>& results);
 
-    /**
-     * @brief Returns a function to perform feature calculation and populates input tensor data with
-     * MFCC data.
-     *
-     * Input tensor data type check is performed to choose correct MFCC feature data type.
-     * If tensor has an integer data type then original features are quantised.
-     *
-     * Warning: MFCC calculator provided as input must have the same life scope as returned function.
-     *
-     * @param[in]       mfcc          MFCC feature calculator.
-     * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
-     * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
-     * @return          Function to be called providing audio sample and sliding window index.
-     */
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-            GetFeatureCalculator(audio::MicroNetKwsMFCC&  mfcc,
-                                 TfLiteTensor*      inputTensor,
-                                 size_t             cacheSize);
-
-    /* Audio inference handler. */
+    /* KWS inference handler. */
     bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
         auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto& model = ctx.Get<Model&>("model");
+        const auto mfccFrameLength = ctx.Get<int>("frameLength");
+        const auto mfccFrameStride = ctx.Get<int>("frameStride");
+        const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+        /* If the request has a valid size, set the audio index. */
+        if (clipIndex < NUMBER_OF_FILES) {
+            if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
+                return false;
+            }
+        }
+        auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
 
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
@@ -74,27 +64,13 @@
             (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
              arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
 
-        auto& model = ctx.Get<Model&>("model");
 
-        /* If the request has a valid size, set the audio index. */
-        if (clipIndex < NUMBER_OF_FILES) {
-            if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
-                return false;
-            }
-        }
         if (!model.IsInited()) {
             printf_err("Model is not initialised! Terminating processing.\n");
             return false;
         }
 
-        const auto frameLength = ctx.Get<int>("frameLength");
-        const auto frameStride = ctx.Get<int>("frameStride");
-        const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
-        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
-
-        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
         TfLiteTensor* inputTensor = model.GetInputTensor(0);
-
         if (!inputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
             return false;
@@ -103,130 +79,89 @@
             return false;
         }
 
+        /* Get input shape for feature extraction. */
         TfLiteIntArray* inputShape = model.GetInputShape(0);
-        const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx];
-        const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
-
-        audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength);
-        mfcc.Init();
-
-        /* Deduce the data length required for 1 inference from the network parameters. */
-        auto audioDataWindowSize = kNumRows * frameStride + (frameLength - frameStride);
-        auto mfccWindowSize = frameLength;
-        auto mfccWindowStride = frameStride;
-
-        /* We choose to move by half the window size => for a 1 second window size
-         * there is an overlap of 0.5 seconds. */
-        auto audioDataStride = audioDataWindowSize / 2;
-
-        /* To have the previously calculated features re-usable, stride must be multiple
-         * of MFCC features window stride. */
-        if (0 != audioDataStride % mfccWindowStride) {
-
-            /* Reduce the stride. */
-            audioDataStride -= audioDataStride % mfccWindowStride;
-        }
-
-        auto nMfccVectorsInAudioStride = audioDataStride/mfccWindowStride;
+        const uint32_t numMfccFeatures = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx];
 
         /* We expect to be sampling 1 second worth of data at a time.
          * NOTE: This is only used for time stamp calculation. */
-        const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
+        const float secondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
+
+        /* Set up pre and post-processing. */
+        KWSPreProcess preprocess = KWSPreProcess(&model, numMfccFeatures, mfccFrameLength, mfccFrameStride);
+
+        std::vector<ClassificationResult> singleInfResult;
+        KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier &>("classifier"), &model,
+                                                    ctx.Get<std::vector<std::string>&>("labels"),
+                                                    singleInfResult, scoreThreshold);
+
+        UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
 
         do {
             hal_lcd_clear(COLOR_BLACK);
 
             auto currentIndex = ctx.Get<uint32_t>("clipIndex");
 
-            /* Creating a mfcc features sliding window for the data required for 1 inference. */
-            auto audioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
-                                            get_audio_array(currentIndex),
-                                            audioDataWindowSize, mfccWindowSize,
-                                            mfccWindowStride);
-
             /* Creating a sliding window through the whole audio clip. */
             auto audioDataSlider = audio::SlidingWindow<const int16_t>(
-                                        get_audio_array(currentIndex),
-                                        get_audio_array_size(currentIndex),
-                                        audioDataWindowSize, audioDataStride);
+                    get_audio_array(currentIndex),
+                    get_audio_array_size(currentIndex),
+                    preprocess.m_audioDataWindowSize, preprocess.m_audioDataStride);
 
-            /* Calculate number of the feature vectors in the window overlap region.
-             * These feature vectors will be reused.*/
-            auto numberOfReusedFeatureVectors = audioMFCCWindowSlider.TotalStrides() + 1
-                                                - nMfccVectorsInAudioStride;
-
-            /* Construct feature calculation function. */
-            auto mfccFeatureCalc = GetFeatureCalculator(mfcc, inputTensor,
-                                                        numberOfReusedFeatureVectors);
-
-            if (!mfccFeatureCalc){
-                return false;
-            }
-
-            /* Declare a container for results. */
-            std::vector<arm::app::kws::KwsResult> results;
+            /* Declare a container to hold results from across the whole audio clip. */
+            std::vector<kws::KwsResult> finalResults;
 
             /* Display message on the LCD - inference running. */
             std::string str_inf{"Running inference... "};
-            hal_lcd_display_text(
-                                str_inf.c_str(), str_inf.size(),
-                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+            hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
             info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
                  get_filename(currentIndex));
 
             /* Start sliding through audio clip. */
             while (audioDataSlider.HasNext()) {
-                const int16_t *inferenceWindow = audioDataSlider.Next();
-
-                /* We moved to the next window - set the features sliding to the new address. */
-                audioMFCCWindowSlider.Reset(inferenceWindow);
+                const int16_t* inferenceWindow = audioDataSlider.Next();
 
                 /* The first window does not have cache ready. */
-                bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
-                /* Start calculating features inside one audio sliding window. */
-                while (audioMFCCWindowSlider.HasNext()) {
-                    const int16_t *mfccWindow = audioMFCCWindowSlider.Next();
-                    std::vector<int16_t> mfccAudioData = std::vector<int16_t>(mfccWindow,
-                                                            mfccWindow + mfccWindowSize);
-                    /* Compute features for this window and write them to input tensor. */
-                    mfccFeatureCalc(mfccAudioData,
-                                    audioMFCCWindowSlider.Index(),
-                                    useCache,
-                                    nMfccVectorsInAudioStride);
-                }
+                preprocess.m_audioWindowIndex = audioDataSlider.Index();
 
                 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                      audioDataSlider.TotalStrides() + 1);
 
-                /* Run inference over this audio clip sliding window. */
-                if (!RunInference(model, profiler)) {
+                /* Run the pre-processing, inference and post-processing. */
+                if (!runner.PreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
                     return false;
                 }
 
-                std::vector<ClassificationResult> classificationResult;
-                auto& classifier = ctx.Get<KwsClassifier&>("classifier");
-                classifier.GetClassificationResults(outputTensor, classificationResult,
-                                                    ctx.Get<std::vector<std::string>&>("labels"), 1, true);
+                profiler.StartProfiling("Inference");
+                if (!runner.RunInference()) {
+                    return false;
+                }
+                profiler.StopProfiling();
 
-                results.emplace_back(kws::KwsResult(classificationResult,
-                    audioDataSlider.Index() * secondsPerSample * audioDataStride,
-                    audioDataSlider.Index(), scoreThreshold));
+                if (!runner.PostProcess()) {
+                    return false;
+                }
+
+                /* Add results from this window to our final results vector. */
+                finalResults.emplace_back(kws::KwsResult(singleInfResult,
+                        audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride,
+                        audioDataSlider.Index(), postprocess.m_scoreThreshold));
 
 #if VERIFY_TEST_OUTPUT
+                TfLiteTensor* outputTensor = model.GetOutputTensor(0);
                 arm::app::DumpTensor(outputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
             } /* while (audioDataSlider.HasNext()) */
 
             /* Erase. */
             str_inf = std::string(str_inf.size(), ' ');
-            hal_lcd_display_text(
-                                str_inf.c_str(), str_inf.size(),
-                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+            hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-            ctx.Set<std::vector<arm::app::kws::KwsResult>>("results", results);
+            ctx.Set<std::vector<kws::KwsResult>>("results", finalResults);
 
-            if (!PresentInferenceResult(results)) {
+            if (!PresentInferenceResult(finalResults)) {
                 return false;
             }
 
@@ -234,58 +169,11 @@
 
             IncrementAppCtxIfmIdx(ctx,"clipIndex");
 
-        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
+        } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
 
         return true;
     }
 
-    /**
-     * @brief Generic feature calculator factory.
-     *
-     * Returns lambda function to compute features using features cache.
-     * Real features math is done by a lambda function provided as a parameter.
-     * Features are written to input tensor memory.
-     *
-     * @tparam T                Feature vector type.
-     * @param[in] inputTensor   Model input tensor pointer.
-     * @param[in] cacheSize     Number of feature vectors to cache. Defined by the sliding window overlap.
-     * @param[in] compute       Features calculator function.
-     * @return                  Lambda function to compute features.
-     */
-    template<class T>
-    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
-                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
-    {
-        /* Feature cache to be captured by lambda function. */
-        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
-        return [=](std::vector<int16_t>& audioDataWindow,
-                                     size_t index,
-                                     bool useCache,
-                                     size_t featuresOverlapIndex)
-        {
-            T *tensorData = tflite::GetTensorData<T>(inputTensor);
-            std::vector<T> features;
-
-            /* Reuse features from cache if cache is ready and sliding windows overlap.
-             * Overlap is in the beginning of sliding window with a size of a feature cache. */
-            if (useCache && index < featureCache.size()) {
-                features = std::move(featureCache[index]);
-            } else {
-                features = std::move(compute(audioDataWindow));
-            }
-            auto size = features.size();
-            auto sizeBytes = sizeof(T) * size;
-            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
-            /* Start renewing cache as soon iteration goes out of the windows overlap. */
-            if (index >= featuresOverlapIndex) {
-                featureCache[index - featuresOverlapIndex] = std::move(features);
-            }
-        };
-    }
-
     static bool PresentInferenceResult(const std::vector<arm::app::kws::KwsResult>& results)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -299,40 +187,39 @@
         /* Display each result */
         uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
 
-        for (uint32_t i = 0; i < results.size(); ++i) {
+        for (const auto & result : results) {
 
             std::string topKeyword{"<none>"};
             float score = 0.f;
-            if (!results[i].m_resultVec.empty()) {
-                topKeyword = results[i].m_resultVec[0].m_label;
-                score = results[i].m_resultVec[0].m_normalisedVal;
+            if (!result.m_resultVec.empty()) {
+                topKeyword = result.m_resultVec[0].m_label;
+                score = result.m_resultVec[0].m_normalisedVal;
             }
 
             std::string resultStr =
-                    std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+                    std::string{"@"} + std::to_string(result.m_timeStamp) +
                     std::string{"s: "} + topKeyword + std::string{" ("} +
                     std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
 
-            hal_lcd_display_text(
-                    resultStr.c_str(), resultStr.size(),
+            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
                     dataPsnTxtStartX1, rowIdx1, false);
             rowIdx1 += dataPsnTxtYIncr;
 
-            if (results[i].m_resultVec.empty()) {
+            if (result.m_resultVec.empty()) {
                 info("For timestamp: %f (inference #: %" PRIu32
                              "); label: %s; threshold: %f\n",
-                     results[i].m_timeStamp, results[i].m_inferenceNumber,
+                     result.m_timeStamp, result.m_inferenceNumber,
                      topKeyword.c_str(),
-                     results[i].m_threshold);
+                     result.m_threshold);
             } else {
-                for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+                for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                     info("For timestamp: %f (inference #: %" PRIu32
                                  "); label: %s, score: %f; threshold: %f\n",
-                         results[i].m_timeStamp,
-                         results[i].m_inferenceNumber,
-                         results[i].m_resultVec[j].m_label.c_str(),
-                         results[i].m_resultVec[j].m_normalisedVal,
-                         results[i].m_threshold);
+                         result.m_timeStamp,
+                         result.m_inferenceNumber,
+                         result.m_resultVec[j].m_label.c_str(),
+                         result.m_resultVec[j].m_normalisedVal,
+                         result.m_threshold);
                 }
             }
         }
@@ -340,88 +227,5 @@
         return true;
     }
 
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-        FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
-                            size_t cacheSize,
-                            std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-        FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
-                             size_t cacheSize,
-                             std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-        FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
-                             size_t cacheSize,
-                             std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
-        FeatureCalc<float>(TfLiteTensor* inputTensor,
-                           size_t cacheSize,
-                           std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
-    {
-        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
-        TfLiteQuantization quant = inputTensor->quantization;
-
-        if (kTfLiteAffineQuantization == quant.type) {
-
-            auto *quantParams = (TfLiteAffineQuantization *) quant.params;
-            const float quantScale = quantParams->scale->data[0];
-            const int quantOffset = quantParams->zero_point->data[0];
-
-            switch (inputTensor->type) {
-                case kTfLiteInt8: {
-                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
-                                                          cacheSize,
-                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
-                                                                                                   quantScale,
-                                                                                                   quantOffset);
-                                                          }
-                    );
-                    break;
-                }
-                case kTfLiteUInt8: {
-                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
-                                                           cacheSize,
-                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                               return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
-                                                                                                     quantScale,
-                                                                                                     quantOffset);
-                                                           }
-                    );
-                    break;
-                }
-                case kTfLiteInt16: {
-                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
-                                                           cacheSize,
-                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                               return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
-                                                                                                     quantScale,
-                                                                                                     quantOffset);
-                                                           }
-                    );
-                    break;
-                }
-                default:
-                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
-            }
-
-
-        } else {
-            mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
-                                                                   cacheSize,
-                                                                   [&mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                                       return mfcc.MfccCompute(audioDataWindow);
-                                                                   });
-        }
-        return mfccFeatureCalc;
-    }
-
 } /* namespace app */
-} /* namespace arm */
\ No newline at end of file
+} /* namespace arm */