MLECO-3077: Add ASR use case API

* Minor adjustments to doc strings in KWS
* Remove unused score threshold in KWS

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9
diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc
index 0392061..e3e1999 100644
--- a/source/use_case/asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,67 +15,71 @@
  * limitations under the License.
  */
 #include "Wav2LetterPostprocess.hpp"
+
 #include "Wav2LetterModel.hpp"
 #include "log_macros.h"
 
+#include <cmath>
+
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
-    Postprocess::Postprocess(const uint32_t contextLen,
-                             const uint32_t innerLen,
-                             const uint32_t blankTokenIdx)
-        :   m_contextLen(contextLen),
-            m_innerLen(innerLen),
-            m_totalLen(2 * this->m_contextLen + this->m_innerLen),
+    /* Stores the collaborators used by DoPostProcess() and derives the inner/total lengths (in output rows) from
+     * the output tensor shape. A null tensor is not dereferenced here; IsInputValid() rejects it later on. */
+    ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+            const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
+            const uint32_t outputContextLen, const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx):
+            m_classifier(classifier),
+            m_outputTensor(outputTensor),
+            m_labels{labels},
+            m_results(results),
+            m_outputContextLen(outputContextLen),
             m_countIterations(0),
-            m_blankTokenIdx(blankTokenIdx)
-    {}
+            m_blankTokenIdx(blankTokenIdx),
+            m_reductionAxisIdx(reductionAxisIdx)
+    {
+        this->m_outputInnerLen = (nullptr == outputTensor) ? 0U : ASRPostProcess::GetOutputInnerLen(outputTensor, outputContextLen);
+        this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
+    }
 
-    bool Postprocess::Invoke(TfLiteTensor*  tensor,
-                            const uint32_t  axisIdx,
-                            const bool      lastIteration)
+    bool ASRPostProcess::DoPostProcess()  /* Erase context rows of m_outputTensor, then classify it into m_results. Returns false on invalid input, an unsupported axis or a failed erase. */
     {
         /* Basic checks. */
-        if (!this->IsInputValid(tensor, axisIdx)) {
+        if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
             return false;
         }
 
         /* Irrespective of tensor type, we use unsigned "byte" */
-        uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
-        const uint32_t elemSz = this->GetTensorElementSize(tensor);
+        auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
+        const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor);
 
         /* Other sanity checks. */
         if (0 == elemSz) {
             printf_err("Tensor type not supported for post processing\n");
             return false;
-        } else if (elemSz * this->m_totalLen > tensor->bytes) {
+        } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
             printf_err("Insufficient number of tensor bytes\n");
             return false;
         }
 
         /* Which axis do we need to process? */
-        switch (axisIdx) {
-            case arm::app::Wav2LetterModel::ms_outputRowsIdx:
-                return this->EraseSectionsRowWise(ptrData,
-                                                  elemSz *
-                                                  tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
-                                                  lastIteration);
-            case arm::app::Wav2LetterModel::ms_outputColsIdx:
-                return this->EraseSectionsColWise(ptrData,
-                                                  elemSz *
-                                                  tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx],
-                                                  lastIteration);
+        switch (this->m_reductionAxisIdx) {
+            case Wav2LetterModel::ms_outputRowsIdx:
+                if (!this->EraseSectionsRowWise(
+                        ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
+                        this->m_lastIteration)) { return false; }  /* Propagate the bool result instead of discarding it. NOTE(review): m_lastIteration is not assigned here - presumably set by the caller; confirm. */
+                break;
             default:
-                printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
+                printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
+                return false;
         }
+        this->m_classifier.GetClassificationResults(this->m_outputTensor,
+                this->m_results, this->m_labels, 1);  /* NOTE(review): any result of this call is discarded; confirm the classifier cannot report failure. */
 
-        return false;
+        return true;
     }
 
-    bool Postprocess::IsInputValid(TfLiteTensor*  tensor,
-                                   const uint32_t axisIdx) const
+    bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
     {
         if (nullptr == tensor) {
             return false;
@@ -89,15 +93,15 @@
 
         if (static_cast<int>(this->m_totalLen) !=
                              tensor->dims->data[axisIdx]) {
-            printf_err("Unexpected tensor dimension for axis %d, \n",
-                tensor->dims->data[axisIdx]);
+            printf_err("Unexpected tensor dimension for axis %d, got %d, \n",
+                axisIdx, tensor->dims->data[axisIdx]);
             return false;
         }
 
         return true;
     }
 
-    uint32_t Postprocess::GetTensorElementSize(TfLiteTensor*  tensor)
+    uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
     {
         switch(tensor->type) {
             case kTfLiteUInt8:
@@ -116,30 +120,30 @@
         return 0;
     }
 
-    bool Postprocess::EraseSectionsRowWise(
-                        uint8_t*         ptrData,
-                        const uint32_t   strideSzBytes,
-                        const bool       lastIteration)
+    bool ASRPostProcess::EraseSectionsRowWise(
+            uint8_t*         ptrData,
+            const uint32_t   strideSzBytes,
+            const bool       lastIteration)
     {
         /* In this case, the "zero-ing" is quite simple as the region
          * to be zeroed sits in contiguous memory (row-major). */
-        const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
+        const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;
 
         /* Erase left context? */
         if (this->m_countIterations > 0) {
             /* Set output of each classification window to the blank token. */
             std::memset(ptrData, 0, eraseLen);
-            for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+            for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
                 ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
             }
         }
 
         /* Erase right context? */
         if (false == lastIteration) {
-            uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
+            uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
             /* Set output of each classification window to the blank token. */
             std::memset(rightCtxPtr, 0, eraseLen);
-            for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+            for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
                 rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
             }
         }
@@ -153,19 +157,56 @@
         return true;
     }
 
-    bool Postprocess::EraseSectionsColWise(
-                        const uint8_t*         ptrData,
-                        const uint32_t   strideSzBytes,
-                        const bool       lastIteration)
+    uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model)
     {
-        /* Not implemented. */
-        UNUSED(ptrData);
-        UNUSED(strideSzBytes);
-        UNUSED(lastIteration);
-        return false;
+        /* Row count of input tensor 0 (the feature vectors); logs and yields 0 if it is not positive. */
+        const int rows = std::max(model.GetInputTensor(0)->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
+        if (0 == rows) {
+            printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
+                    Wav2LetterModel::ms_inputRowsIdx);
+        }
+        return static_cast<uint32_t>(rows);
     }
 
-} /* namespace asr */
-} /* namespace audio */
+    /* Output rows left for the inner section after removing both context windows; 0 on error or oversized context. */
+    uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
+    {
+        const uint32_t outputRows = (nullptr == outputTensor) ? 0 : std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+        if (outputRows == 0) {
+            printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", Wav2LetterModel::ms_outputRowsIdx);
+        }
+        /* Compare before subtracting: unsigned arithmetic would otherwise wrap instead of going negative. */
+        const uint64_t ctxRows = 2 * static_cast<uint64_t>(outputCtxLen);
+        return (outputRows > ctxRows) ? static_cast<uint32_t>(outputRows - ctxRows) : 0U;
+    }
+
+    /* Scales an input-domain context length (in input rows) to the output domain using the input/output
+     * row ratio, rounded to the nearest row. Returns 0 (after logging) when the shapes are incompatible. */
+    uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+    {
+        const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model);
+        /* Widen before doubling so a very large context cannot wrap around in 32-bit arithmetic. */
+        const uint64_t inputCtxRows = 2 * static_cast<uint64_t>(inputCtxLen);
+
+        /* The input must hold both context windows plus a non-empty inner section.
+         * NOTE(review): a zero context is also rejected - confirm this is intended. */
+        if (inputRows <= inputCtxRows || 0 == inputCtxLen) {
+            printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", inputCtxLen);
+            return 0;
+        }
+
+        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+        const uint32_t outputRows = (nullptr == outputTensor) ? 0 : std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+        if (outputRows == 0) {
+            printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", Wav2LetterModel::ms_outputRowsIdx);
+            return 0;
+        }
+
+        const float inOutRowRatio = static_cast<float>(inputRows) / static_cast<float>(outputRows);
+
+        /* Explicit cast: std::round returns float, while the result is a whole, non-negative row count. */
+        return static_cast<uint32_t>(std::round(static_cast<float>(inputCtxLen) / inOutRowRatio));
+    }
+
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file