MLECO-3173: Add AD, KWS_ASR and Noise reduction use case APIs

Signed-off-by: Richard Burton <richard.burton@arm.com>

Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
diff --git a/source/application/main/include/BaseProcessing.hpp b/source/application/main/include/BaseProcessing.hpp
index c1c3255..c099db2 100644
--- a/source/application/main/include/BaseProcessing.hpp
+++ b/source/application/main/include/BaseProcessing.hpp
@@ -41,9 +41,6 @@
          * @return      true if successful, false otherwise.
          **/
         virtual bool DoPreProcess(const void* input, size_t inputSize) = 0;
-
-    protected:
-        Model* m_model = nullptr;
     };
 
     /**
@@ -62,9 +59,6 @@
          * @return      true if successful, false otherwise.
          **/
         virtual bool DoPostProcess() = 0;
-
-    protected:
-        Model* m_model = nullptr;
     };
 
 } /* namespace app */
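
With the unused `Model* m_model` member gone, all state lives in the concrete use-case classes and the base classes are pure interfaces. A minimal sketch of the contract they impose (DummyPreProcess and its members are illustrative only, not part of this patch):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    class BasePreProcess {
    public:
        virtual ~BasePreProcess() = default;
        virtual bool DoPreProcess(const void* input, size_t inputSize) = 0;
    };

    class DummyPreProcess : public BasePreProcess {
    public:
        DummyPreProcess(int16_t* tensorData, size_t tensorLen)
            : m_tensorData{tensorData}, m_tensorLen{tensorLen} {}

        bool DoPreProcess(const void* input, size_t inputSize) override
        {
            if (!input || inputSize < m_tensorLen * sizeof(int16_t)) {
                return false; /* Not enough data for one inference. */
            }
            std::memcpy(m_tensorData, input, m_tensorLen * sizeof(int16_t));
            return true;
        }

    private:
        int16_t* m_tensorData; /* In the real use cases this points into the tensor arena. */
        size_t   m_tensorLen;
    };
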
diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp
index 8d914c4..2195a7c 100644
--- a/source/use_case/ad/include/AdModel.hpp
+++ b/source/use_case/ad/include/AdModel.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +28,12 @@
 namespace app {
 
     class AdModel : public Model {
+
+    public:
+        /* Indices for the expected model - based on input tensor shape */
+        static constexpr uint32_t ms_inputRowsIdx = 1;
+        static constexpr uint32_t ms_inputColsIdx = 2;
+
     protected:
         /** @brief   Gets the reference to op resolver interface class */
         const tflite::MicroOpResolver& GetOpResolver() override;
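
These indices are consumed by the new AdPreProcess constructor below to read the expected feature grid straight from the input tensor shape. A short sketch of the intended usage (`inputTensor` is assumed to be a valid, allocated model input tensor):

    TfLiteIntArray* inputShape = inputTensor->dims;
    const uint32_t numRows = inputShape->data[arm::app::AdModel::ms_inputRowsIdx];
    const uint32_t numCols = inputShape->data[arm::app::AdModel::ms_inputColsIdx];
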
diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp
deleted file mode 100644
index 7eaec84..0000000
--- a/source/use_case/ad/include/AdPostProcessing.hpp
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef ADPOSTPROCESSING_HPP
-#define ADPOSTPROCESSING_HPP
-
-#include "TensorFlowLiteMicro.hpp"
-
-#include <vector>
-
-namespace arm {
-namespace app {
-
-    /** @brief      Dequantize TensorFlow Lite Micro tensor.
-     *  @param[in]  tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized.
-     *  @return     Vector with the dequantized tensor values.
-    **/
-    template<typename T>
-    std::vector<float> Dequantize(TfLiteTensor* tensor);
-
-    /**
-     * @brief   Calculates the softmax of vector in place. **/
-    void Softmax(std::vector<float>& inputVector);
-
-
-    /** @brief      Given a wav file name return AD model output index.
-     *  @param[in]  wavFileName Audio WAV filename.
-     *                          File name should be in format anything_goes_XX_here.wav
-     *                          where XX is the machine ID e.g. 00, 02, 04 or 06
-     *  @return     AD model output index as 8 bit integer.
-    **/
-    int8_t OutputIndexFromFileName(std::string wavFileName);
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* ADPOSTPROCESSING_HPP */
diff --git a/source/use_case/ad/include/AdProcessing.hpp b/source/use_case/ad/include/AdProcessing.hpp
new file mode 100644
index 0000000..9abf6f1
--- /dev/null
+++ b/source/use_case/ad/include/AdProcessing.hpp
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD_PROCESSING_HPP
+#define AD_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "AudioUtils.hpp"
+#include "AdMelSpectrogram.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for anomaly detection use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class AdPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief Constructor for AdPreProcess class objects.
+         * @param[in] inputTensor               Input tensor pointer from the tensor arena.
+         * @param[in] melSpectrogramFrameLen    MEL spectrogram's frame length.
+         * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride.
+         * @param[in] adModelTrainingMean       Training mean for the anomaly detection model being used.
+         */
+        explicit AdPreProcess(TfLiteTensor* inputTensor,
+                              uint32_t melSpectrogramFrameLen,
+                              uint32_t melSpectrogramFrameStride,
+                              float adModelTrainingMean);
+
+        ~AdPreProcess() = default;
+
+        /**
+         * @brief Function to invoke pre-processing and populate the input tensor.
+         * @param[in] input     Pointer to input data. For anomaly detection, this is the
+         *                      pointer to the audio data.
+         * @param[in] inputSize Size of the data being passed in for pre-processing.
+         * @return True if successful, false otherwise.
+         */
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+        /**
+         * @brief Getter function for audio window size computed when constructing
+         *        the class object.
+         * @return Audio window size as 32 bit unsigned integer.
+         */
+        uint32_t GetAudioWindowSize();
+
+        /**
+         * @brief Getter function for audio window stride computed when constructing
+         *        the class object.
+         * @return Audio window stride as 32 bit unsigned integer.
+         */
+        uint32_t GetAudioDataStride();
+
+        /**
+         * @brief Setter function for the current audio window index. This is only used to
+         *        evaluate whether previously computed features can be re-used from the cache.
+         * @param[in] idx Index from the audio sliding window.
+         */
+        void SetAudioWindowIndex(uint32_t idx);
+
+    private:
+        bool        m_validInstance{false}; /**< Indicates the current object is valid. */
+        uint32_t    m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */
+        uint32_t    m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */
+        uint8_t     m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */
+        uint32_t    m_numMelSpecVectorsInAudioStride{};  /**< Number of frames to move across the audio. */
+        uint32_t    m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */
+        uint32_t    m_audioDataStride{}; /**< Audio window stride computed. */
+        uint32_t    m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */
+        uint32_t    m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */
+
+        audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */
+        audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */
+        std::function<void
+            (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */
+    };
+
+    class AdPostProcess : public BasePostProcess {
+    public:
+        /**
+         * @brief Constructor for AdPostProcess object.
+         * @param[in] outputTensor Output tensor pointer.
+         */
+        explicit AdPostProcess(TfLiteTensor* outputTensor);
+
+        ~AdPostProcess() = default;
+
+        /**
+         * @brief Function to do the post-processing on the output tensor.
+         * @return True if successful, false otherwise.
+         */
+        bool DoPostProcess() override;
+
+        /**
+         * @brief Getter function for an element from the de-quantised output vector.
+         * @param[in] index Index of the element to be retrieved.
+         * @return    Element at the given index, as a 32 bit floating point number.
+         */
+        float GetOutputValue(uint32_t index);
+
+    private:
+        TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */
+        std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */
+
+        /**
+         * @brief De-quantizes and flattens the output tensor into a vector.
+         * @tparam T template parameter to indicate data type.
+         * @return True if successful, false otherwise.
+         */
+        template<typename T>
+        bool Dequantize()
+        {
+            TfLiteTensor* tensor = this->m_outputTensor;
+            if (tensor == nullptr) {
+                printf_err("Invalid output tensor.\n");
+                return false;
+            }
+            T* tensorData = tflite::GetTensorData<T>(tensor);
+
+            uint32_t totalOutputSize = 1;
+            for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
+                totalOutputSize *= tensor->dims->data[inputDim];
+            }
+
+            /* For getting the floating point values, we need quantization parameters */
+            QuantParams quantParams = GetTensorQuantParams(tensor);
+
+            this->m_dequantizedOutputVec = std::vector<float>(totalOutputSize, 0);
+
+            for (size_t i = 0; i < totalOutputSize; ++i) {
+                this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
+            }
+
+            return true;
+        }
+    };
+
+    /* Templated instances available: */
+    template bool AdPostProcess::Dequantize<int8_t>();
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns lambda function to compute features using features cache.
+     * Real features math is done by a lambda function provided as a parameter.
+     * Features are written to input tensor memory.
+     *
+     * @tparam T            feature vector type.
+     * @param inputTensor   model input tensor pointer.
+     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param compute       features calculator function.
+     * @return              lambda function to compute features.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    {
+        /* Feature cache to be captured by the lambda function. */
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex,
+                   size_t resizeScale)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from cache if cache is ready and sliding windows overlap.
+             * The overlap is at the beginning of the sliding window and is the size of the feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size() / resizeScale;
+            auto sizeBytes = sizeof(T);
+
+            /* Input should be transposed and "resized" by skipping elements. */
+            for (size_t outIndex = 0; outIndex < size; outIndex++) {
+                std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
+            }
+
+            /* Start renewing the cache as soon as the iteration moves out of the windows' overlap. */
+            if (index >= featuresOverlapIndex / resizeScale) {
+                featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
+            }
+        };
+    }
+
+    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+                        size_t cacheSize,
+                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc<float>(TfLiteTensor *inputTensor,
+                       size_t cacheSize,
+                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
+    std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+                         TfLiteTensor* inputTensor,
+                         size_t cacheSize,
+                         float trainingMean);
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* AD_PROCESSING_HPP */
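
Together the two classes replace the free functions from the deleted AdPostProcessing.hpp. A hedged usage sketch, mirroring the updated use-case handler later in this patch (`model`, `audio`, `audioLen`, `machineOutputIndex` and the `g_*` configuration values are assumed to be provided by the surrounding application):

    arm::app::AdPreProcess preProcess{
        model.GetInputTensor(0),
        g_FrameLength,    /* MEL spectrogram frame length. */
        g_FrameStride,    /* MEL spectrogram frame stride. */
        g_TrainingMean};  /* Training mean of the AD model. */

    arm::app::AdPostProcess postProcess{model.GetOutputTensor(0)};

    auto audioDataSlider = arm::app::audio::SlidingWindow<const int16_t>(
        audio, audioLen,
        preProcess.GetAudioWindowSize(),
        preProcess.GetAudioDataStride());

    float result = 0;
    while (audioDataSlider.HasNext()) {
        const int16_t* window = audioDataSlider.Next();
        preProcess.SetAudioWindowIndex(audioDataSlider.Index());
        if (!preProcess.DoPreProcess(window, preProcess.GetAudioWindowSize()) ||
                !model.RunInference() ||
                !postProcess.DoPostProcess()) {
            break;
        }
        /* The negative softmax score at the machine's index is the outlier score. */
        result -= postProcess.GetOutputValue(machineOutputIndex);
    }
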
diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc
deleted file mode 100644
index c461875..0000000
--- a/source/use_case/ad/src/AdPostProcessing.cc
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "AdPostProcessing.hpp"
-#include "log_macros.h"
-
-#include <numeric>
-#include <cmath>
-#include <string>
-
-namespace arm {
-namespace app {
-
-    template<typename T>
-    std::vector<float> Dequantize(TfLiteTensor* tensor) {
-
-        if (tensor == nullptr) {
-            printf_err("Tensor is null pointer can not dequantize.\n");
-            return std::vector<float>();
-        }
-        T* tensorData = tflite::GetTensorData<T>(tensor);
-
-        uint32_t totalOutputSize = 1;
-        for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
-            totalOutputSize *= tensor->dims->data[inputDim];
-        }
-
-        /* For getting the floating point values, we need quantization parameters */
-        QuantParams quantParams = GetTensorQuantParams(tensor);
-
-        std::vector<float> dequantizedOutput(totalOutputSize);
-
-        for (size_t i = 0; i < totalOutputSize; ++i) {
-            dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
-        }
-
-        return dequantizedOutput;
-    }
-
-    void Softmax(std::vector<float>& inputVector) {
-        auto start = inputVector.begin();
-        auto end = inputVector.end();
-
-        /* Fix for numerical stability and apply exp. */
-        float maxValue = *std::max_element(start, end);
-        for (auto it = start; it!=end; ++it) {
-            *it = std::exp((*it) - maxValue);
-        }
-
-        float sumExp = std::accumulate(start, end, 0.0f);
-
-        for (auto it = start; it!=end; ++it) {
-            *it = (*it)/sumExp;
-        }
-    }
-
-    int8_t OutputIndexFromFileName(std::string wavFileName) {
-        /* Filename is assumed in the form machine_id_00.wav */
-        std::string delimiter = "_";  /* First character used to split the file name up. */
-        size_t delimiterStart;
-        std::string subString;
-        size_t machineIdxInString = 3;  /* Which part of the file name the machine id should be at. */
-
-        for (size_t i = 0; i < machineIdxInString; ++i) {
-            delimiterStart = wavFileName.find(delimiter);
-            subString = wavFileName.substr(0, delimiterStart);
-            wavFileName.erase(0, delimiterStart + delimiter.length());
-        }
-
-        /* At this point substring should be 00.wav */
-        delimiter = ".";  /* Second character used to split the file name up. */
-        delimiterStart = subString.find(delimiter);
-        subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
-
-        auto is_number = [](const std::string& str) ->  bool
-        {
-            std::string::const_iterator it = str.begin();
-            while (it != str.end() && std::isdigit(*it)) ++it;
-            return !str.empty() && it == str.end();
-        };
-
-        const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
-
-        /* Return corresponding index in the output vector. */
-        if (machineIdx == 0) {
-            return 0;
-        } else if (machineIdx == 2) {
-            return 1;
-        } else if (machineIdx == 4) {
-            return 2;
-        } else if (machineIdx == 6) {
-            return 3;
-        } else {
-            printf_err("%d is an invalid machine index \n", machineIdx);
-            return -1;
-        }
-    }
-
-    template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor);
-    template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor);
-} /* namespace app */
-} /* namespace arm */
\ No newline at end of file
diff --git a/source/use_case/ad/src/AdProcessing.cc b/source/use_case/ad/src/AdProcessing.cc
new file mode 100644
index 0000000..a33131c
--- /dev/null
+++ b/source/use_case/ad/src/AdProcessing.cc
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "AdProcessing.hpp"
+
+#include "AdModel.hpp"
+
+namespace arm {
+namespace app {
+
+AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor,
+                           uint32_t melSpectrogramFrameLen,
+                           uint32_t melSpectrogramFrameStride,
+                           float adModelTrainingMean):
+       m_validInstance{false},
+       m_melSpectrogramFrameLen{melSpectrogramFrameLen},
+       m_melSpectrogramFrameStride{melSpectrogramFrameStride},
+        /* Model is trained on features downsampled 2x. */
+       m_inputResizeScale{2},
+        /* We are choosing to move by 20 frames across the audio for each inference. */
+       m_numMelSpecVectorsInAudioStride{20},
+       m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride},
+       m_melSpec{melSpectrogramFrameLen}
+{
+    if (!inputTensor) {
+        printf_err("Invalid input tensor provided to pre-process\n");
+        return;
+    }
+
+    TfLiteIntArray* inputShape = inputTensor->dims;
+
+    if (!inputShape) {
+        printf_err("Invalid input tensor dims\n");
+        return;
+    }
+
+    const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx];
+    const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx];
+
+    /* Deduce the data length required for 1 inference from the network parameters. */
+    this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) *
+                                    melSpectrogramFrameStride) +
+                                    melSpectrogramFrameLen;
+    this->m_numReusedFeatureVectors = kNumRows -
+                                      (this->m_numMelSpecVectorsInAudioStride /
+                                       this->m_inputResizeScale);
+    this->m_melSpec.Init();
+
+    /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
+     * "resizing" done here by multiplying stride by resize scale. */
+    this->m_melWindowSlider = audio::SlidingWindow<const int16_t>(
+            nullptr, /* to be populated later. */
+            this->m_audioDataWindowSize,
+            melSpectrogramFrameLen,
+            melSpectrogramFrameStride * this->m_inputResizeScale);
+
+    /* Construct feature calculation function. */
+    this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor,
+                                               this->m_numReusedFeatureVectors,
+                                               adModelTrainingMean);
+    this->m_validInstance = true;
+}
+
+bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize)
+{
+    /* Check that we have a valid instance. */
+    if (!this->m_validInstance) {
+        printf_err("Invalid pre-processor instance\n");
+        return false;
+    }
+
+    /* We expect to be able to traverse the size with which the MEL spectrogram
+     * sliding window was initialised. */
+    if (!input || inputSize < this->m_audioDataWindowSize) {
+        printf_err("Invalid input provided for pre-processing\n");
+        return false;
+    }
+
+    /* We moved to the next window - set the features sliding to the new address. */
+    this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input));
+
+    /* The first window does not have cache ready. */
+    const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0;
+
+    /* Start calculating features inside one audio sliding window. */
+    while (this->m_melWindowSlider.HasNext()) {
+        const int16_t* melSpecWindow = this->m_melWindowSlider.Next();
+        std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(
+                melSpecWindow,
+                melSpecWindow + this->m_melSpectrogramFrameLen);
+
+        /* Compute features for this window and write them to input tensor. */
+        this->m_featureCalc(melSpecAudioData,
+                            this->m_melWindowSlider.Index(),
+                            useCache,
+                            this->m_numMelSpecVectorsInAudioStride,
+                            this->m_inputResizeScale);
+    }
+
+    return true;
+}
+
+uint32_t AdPreProcess::GetAudioWindowSize()
+{
+    return this->m_audioDataWindowSize;
+}
+
+uint32_t AdPreProcess::GetAudioDataStride()
+{
+    return this->m_audioDataStride;
+}
+
+void AdPreProcess::SetAudioWindowIndex(uint32_t idx)
+{
+    this->m_audioWindowIndex = idx;
+}
+
+AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) :
+    m_outputTensor {outputTensor}
+{}
+
+bool AdPostProcess::DoPostProcess()
+{
+    switch (this->m_outputTensor->type) {
+        case kTfLiteInt8:
+            this->Dequantize<int8_t>();
+            break;
+        default:
+            printf_err("Unsupported tensor type");
+            return false;
+    }
+
+    math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec);
+    return true;
+}
+
+float AdPostProcess::GetOutputValue(uint32_t index)
+{
+    if (index < this->m_dequantizedOutputVec.size()) {
+        return this->m_dequantizedOutputVec[index];
+    }
+    printf_err("Invalid index for output\n");
+    return 0.0;
+}
+
+std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+                     TfLiteTensor* inputTensor,
+                     size_t cacheSize,
+                     float trainingMean)
+{
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
+
+    TfLiteQuantization quant = inputTensor->quantization;
+
+    if (kTfLiteAffineQuantization == quant.type) {
+
+        auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params);
+        const float quantScale = quantParams->scale->data[0];
+        const int quantOffset = quantParams->zero_point->data[0];
+
+        switch (inputTensor->type) {
+            case kTfLiteInt8: {
+                melSpecFeatureCalc = FeatureCalc<int8_t>(
+                        inputTensor,
+                        cacheSize,
+                        [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
+                            return melSpec.MelSpecComputeQuant<int8_t>(
+                                    audioDataWindow,
+                                    quantScale,
+                                    quantOffset,
+                                    trainingMean);
+                        }
+                );
+                break;
+            }
+            default:
+            printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
+        }
+    } else {
+        melSpecFeatureCalc = FeatureCalc<float>(
+                inputTensor,
+                cacheSize,
+                [=, &melSpec](
+                        std::vector<int16_t>& audioDataWindow) {
+                    return melSpec.ComputeMelSpec(
+                            audioDataWindow,
+                            trainingMean);
+                });
+    }
+    return melSpecFeatureCalc;
+}
+
+} /* namespace app */
+} /* namespace arm */
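
The window-size arithmetic in the AdPreProcess constructor benefits from a worked example. The values below are illustrative; the real ones come from the model's input shape and the generated audio configuration:

    constexpr uint32_t resizeScale = 2;    /* m_inputResizeScale. */
    constexpr uint32_t numCols     = 32;   /* inputShape->data[ms_inputColsIdx], assumed. */
    constexpr uint32_t frameLen    = 1024; /* melSpectrogramFrameLen, assumed. */
    constexpr uint32_t frameStride = 512;  /* melSpectrogramFrameStride, assumed. */

    /* ((2 * 32) - 1) * 512 + 1024: 63 strides of 512 samples plus one full frame. */
    constexpr uint32_t windowSize =
        (((resizeScale * numCols) - 1) * frameStride) + frameLen;
    static_assert(windowSize == 33280, "samples needed for one inference");
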
diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc
index 23d1e51..140359b 100644
--- a/source/use_case/ad/src/MainLoop.cc
+++ b/source/use_case/ad/src/MainLoop.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "hal.h"                    /* Brings in platform definitions */
 #include "InputFiles.hpp"           /* For input data */
 #include "AdModel.hpp"              /* Model class for running inference */
 #include "UseCaseCommonUtils.hpp"   /* Utils functions */
@@ -63,8 +62,8 @@
     caseContext.Set<arm::app::Profiler&>("profiler", profiler);
     caseContext.Set<arm::app::Model&>("model", model);
     caseContext.Set<uint32_t>("clipIndex", 0);
-    caseContext.Set<int>("frameLength", g_FrameLength);
-    caseContext.Set<int>("frameStride", g_FrameStride);
+    caseContext.Set<uint32_t>("frameLength", g_FrameLength);
+    caseContext.Set<uint32_t>("frameStride", g_FrameStride);
     caseContext.Set<float>("scoreThreshold", g_ScoreThreshold);
     caseContext.Set<float>("trainingMean", g_TrainingMean);
 
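
The int to uint32_t change is not cosmetic: ApplicationContext lookups are keyed by type as well as name, so the updated handler's Get<uint32_t>("frameLength") only resolves if the value was Set with exactly uint32_t. A sketch of the paired calls, assuming this reading of the context API:

    caseContext.Set<uint32_t>("frameLength", g_FrameLength);
    /* ... later, in the use-case handler: */
    const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength"); /* Type must match the Set. */
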
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc
index 5585f36..0179d6b 100644
--- a/source/use_case/ad/src/UseCaseHandler.cc
+++ b/source/use_case/ad/src/UseCaseHandler.cc
@@ -24,8 +24,8 @@
 #include "AudioUtils.hpp"
 #include "ImageUtils.hpp"
 #include "UseCaseCommonUtils.hpp"
-#include "AdPostProcessing.hpp"
 #include "log_macros.h"
+#include "AdProcessing.hpp"
 
 namespace arm {
 namespace app {
@@ -39,32 +39,17 @@
      **/
     static bool PresentInferenceResult(float result, float threshold);
 
-    /**
-     * @brief Returns a function to perform feature calculation and populates input tensor data with
-     * MelSpe data.
-     *
-     * Input tensor data type check is performed to choose correct MFCC feature data type.
-     * If tensor has an integer data type then original features are quantised.
-     *
-     * Warning: mfcc calculator provided as input must have the same life scope as returned function.
-     *
-     * @param[in]           melSpec         MFCC feature calculator.
-     * @param[in,out]       inputTensor     Input tensor pointer to store calculated features.
-     * @param[in]           cacheSize       Size of the feture vectors cache (number of feature vectors).
-     * @param[in]           trainingMean    Training mean.
-     * @return function     function to be called providing audio sample and sliding window index.
-     */
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
-    GetFeatureCalculator(audio::AdMelSpectrogram&  melSpec,
-                         TfLiteTensor*             inputTensor,
-                         size_t                    cacheSize,
-                         float                     trainingMean);
+    /** @brief      Given a WAV file name, return the AD model output index.
+     *  @param[in]  wavFileName Audio WAV filename.
+     *                          The file name should be in the format anything_goes_XX_here.wav,
+     *                          where XX is the machine ID, e.g. 00, 02, 04 or 06.
+     *  @return     AD model output index as 8 bit integer.
+    **/
+    static int8_t OutputIndexFromFileName(std::string wavFileName);
 
-    /* Vibration classification handler */
+    /* Anomaly Detection inference handler */
     bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
@@ -81,8 +66,9 @@
             return false;
         }
 
-        const auto frameLength = ctx.Get<int>("frameLength");
-        const auto frameStride = ctx.Get<int>("frameStride");
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength");
+        const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride");
         const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
         const auto trainingMean = ctx.Get<float>("trainingMean");
         auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
@@ -95,21 +81,13 @@
             return false;
         }
 
-        TfLiteIntArray* inputShape = model.GetInputShape(0);
-        const uint32_t kNumRows = inputShape->data[1];
-        const uint32_t kNumCols = inputShape->data[2];
+        AdPreProcess preProcess{
+            inputTensor,
+            melSpecFrameLength,
+            melSpecFrameStride,
+            trainingMean};
 
-        audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength);
-        melSpec.Init();
-
-        /* Deduce the data length required for 1 inference from the network parameters. */
-        const uint8_t inputResizeScale = 2;
-        const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength;
-
-        /* We are choosing to move by 20 frames across the audio for each inference. */
-        const uint8_t nMelSpecVectorsInAudioStride = 20;
-
-        auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride;
+        AdPostProcess postProcess{outputTensor};
 
         do {
             hal_lcd_clear(COLOR_BLACK);
@@ -122,29 +100,12 @@
                 return false;
             }
 
-            /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
-             * "resizing" done here by multiplying stride by resize scale. */
-            auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>(
-                    get_audio_array(currentIndex),
-                    audioDataWindowSize, frameLength,
-                    frameStride * inputResizeScale);
-
             /* Creating a sliding window through the whole audio clip. */
             auto audioDataSlider = audio::SlidingWindow<const int16_t>(
-                    get_audio_array(currentIndex),
-                    get_audio_array_size(currentIndex),
-                    audioDataWindowSize, audioDataStride);
-
-            /* Calculate number of the feature vectors in the window overlap region taking into account resizing.
-             * These feature vectors will be reused.*/
-            auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale);
-
-            /* Construct feature calculation function. */
-            auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor,
-                                                           numberOfReusedFeatureVectors, trainingMean);
-            if (!melSpecFeatureCalc){
-                return false;
-            }
+                get_audio_array(currentIndex),
+                get_audio_array_size(currentIndex),
+                preProcess.GetAudioWindowSize(),
+                preProcess.GetAudioDataStride());
 
             /* Result is an averaged sum over inferences. */
             float result = 0;
@@ -152,30 +113,18 @@
             /* Display message on the LCD - inference running. */
             std::string str_inf{"Running inference... "};
             hal_lcd_display_text(
-                    str_inf.c_str(), str_inf.size(),
-                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
-            info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex));
+                str_inf.c_str(), str_inf.size(),
+                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+
+            info("Running inference on audio clip %" PRIu32 " => %s\n",
+                currentIndex, get_filename(currentIndex));
 
             /* Start sliding through audio clip. */
             while (audioDataSlider.HasNext()) {
-                const int16_t *inferenceWindow = audioDataSlider.Next();
+                const int16_t* inferenceWindow = audioDataSlider.Next();
 
-                /* We moved to the next window - set the features sliding to the new address. */
-                audioMelSpecWindowSlider.Reset(inferenceWindow);
-
-                /* The first window does not have cache ready. */
-                bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
-                /* Start calculating features inside one audio sliding window. */
-                while (audioMelSpecWindowSlider.HasNext()) {
-                    const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next();
-                    std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow,
-                                                                                 melSpecWindow + frameLength);
-
-                    /* Compute features for this window and write them to input tensor. */
-                    melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(),
-                                       useCache, nMelSpecVectorsInAudioStride, inputResizeScale);
-                }
+                preProcess.SetAudioWindowIndex(audioDataSlider.Index());
+                preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
 
                 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                      audioDataSlider.TotalStrides() + 1);
@@ -185,13 +134,11 @@
                     return false;
                 }
 
-                /* Use the negative softmax score of the corresponding index as the outlier score */
-                std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
-                Softmax(dequantOutput);
-                result += -dequantOutput[machineOutputIndex];
+                postProcess.DoPostProcess();
+                /* Use the negative softmax score of the corresponding index as the outlier score. */
+                result -= postProcess.GetOutputValue(machineOutputIndex);
 
 #if VERIFY_TEST_OUTPUT
-                arm::app::DumpTensor(outputTensor);
+                DumpTensor(outputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
             } /* while (audioDataSlider.HasNext()) */
 
@@ -218,7 +165,6 @@
         return true;
     }
 
-
     static bool PresentInferenceResult(float result, float threshold)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -251,148 +197,47 @@
         return true;
     }
 
-    /**
-     * @brief Generic feature calculator factory.
-     *
-     * Returns lambda function to compute features using features cache.
-     * Real features math is done by a lambda function provided as a parameter.
-     * Features are written to input tensor memory.
-     *
-     * @tparam T            feature vector type.
-     * @param inputTensor   model input tensor pointer.
-     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
-     * @param compute       features calculator function.
-     * @return              lambda function to compute features.
-     */
-    template<class T>
-    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
-    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
-                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    static int8_t OutputIndexFromFileName(std::string wavFileName)
     {
-        /* Feature cache to be captured by lambda function*/
-        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+        /* Filename is assumed to be in the form machine_id_00.wav. */
+        std::string delimiter = "_";  /* First character used to split the file name up. */
+        size_t delimiterStart;
+        std::string subString;
+        size_t machineIdxInString = 3;  /* Which part of the file name the machine id should be at. */
 
-        return [=](std::vector<int16_t>& audioDataWindow,
-                   size_t index,
-                   bool useCache,
-                   size_t featuresOverlapIndex,
-                   size_t resizeScale)
-        {
-            T *tensorData = tflite::GetTensorData<T>(inputTensor);
-            std::vector<T> features;
-
-            /* Reuse features from cache if cache is ready and sliding windows overlap.
-             * Overlap is in the beginning of sliding window with a size of a feature cache. */
-            if (useCache && index < featureCache.size()) {
-                features = std::move(featureCache[index]);
-            } else {
-                features = std::move(compute(audioDataWindow));
-            }
-            auto size = features.size() / resizeScale;
-            auto sizeBytes = sizeof(T);
-
-            /* Input should be transposed and "resized" by skipping elements. */
-            for (size_t outIndex = 0; outIndex < size; outIndex++) {
-                std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
-            }
-
-            /* Start renewing cache as soon iteration goes out of the windows overlap. */
-            if (index >= featuresOverlapIndex / resizeScale) {
-                featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
-            }
-        };
-    }
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
-    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
-                        size_t cacheSize,
-                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
-    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
-    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);
-
-    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
-    FeatureCalc<float>(TfLiteTensor *inputTensor,
-                       size_t cacheSize,
-                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
-    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean)
-    {
-        std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
-
-        TfLiteQuantization quant = inputTensor->quantization;
-
-        if (kTfLiteAffineQuantization == quant.type) {
-
-            auto *quantParams = (TfLiteAffineQuantization *) quant.params;
-            const float quantScale = quantParams->scale->data[0];
-            const int quantOffset = quantParams->zero_point->data[0];
-
-            switch (inputTensor->type) {
-                case kTfLiteInt8: {
-                    melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor,
-                                                             cacheSize,
-                                                             [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
-                                                                 return melSpec.MelSpecComputeQuant<int8_t>(
-                                                                         audioDataWindow,
-                                                                         quantScale,
-                                                                         quantOffset,
-                                                                         trainingMean);
-                                                             }
-                    );
-                    break;
-                }
-                case kTfLiteUInt8: {
-                    melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
-                                                              cacheSize,
-                                                              [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
-                                                                  return melSpec.MelSpecComputeQuant<uint8_t>(
-                                                                          audioDataWindow,
-                                                                          quantScale,
-                                                                          quantOffset,
-                                                                          trainingMean);
-                                                              }
-                    );
-                    break;
-                }
-                case kTfLiteInt16: {
-                    melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor,
-                                                              cacheSize,
-                                                              [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
-                                                                  return melSpec.MelSpecComputeQuant<int16_t>(
-                                                                          audioDataWindow,
-                                                                          quantScale,
-                                                                          quantOffset,
-                                                                          trainingMean);
-                                                              }
-                    );
-                    break;
-                }
-                default:
-                printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
-            }
-
-
-        } else {
-            melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor,
-                                                                         cacheSize,
-                                                                         [=, &melSpec](
-                                                                                 std::vector<int16_t>& audioDataWindow) {
-                                                                             return melSpec.ComputeMelSpec(
-                                                                                     audioDataWindow,
-                                                                                     trainingMean);
-                                                                         });
+        for (size_t i = 0; i < machineIdxInString; ++i) {
+            delimiterStart = wavFileName.find(delimiter);
+            subString = wavFileName.substr(0, delimiterStart);
+            wavFileName.erase(0, delimiterStart + delimiter.length());
         }
-        return melSpecFeatureCalc;
+
+        /* At this point substring should be 00.wav */
+        delimiter = ".";  /* Second character used to split the file name up. */
+        delimiterStart = subString.find(delimiter);
+        subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
+
+        auto is_number = [](const std::string& str) ->  bool
+        {
+            std::string::const_iterator it = str.begin();
+            while (it != str.end() && std::isdigit(*it)) ++it;
+            return !str.empty() && it == str.end();
+        };
+
+        const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
+
+        /* Return corresponding index in the output vector. */
+        if (machineIdx == 0) {
+            return 0;
+        } else if (machineIdx == 2) {
+            return 1;
+        } else if (machineIdx == 4) {
+            return 2;
+        } else if (machineIdx == 6) {
+            return 3;
+        } else {
+            printf_err("%d is an invalid machine index \n", machineIdx);
+            return -1;
+        }
     }
 
 } /* namespace app */
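
For reference, the machine-id to output-index mapping implemented by the now file-static OutputIndexFromFileName, restated as a standalone function (the name MachineIdToOutputIndex is hypothetical):

    #include <cstdint>

    int8_t MachineIdToOutputIndex(int8_t machineIdx)
    {
        switch (machineIdx) {
            case 0: return 0;
            case 2: return 1;
            case 4: return 2;
            case 6: return 3;
            default: return -1; /* Any other machine id is invalid. */
        }
    }

    /* e.g. "anomaly_id_02_new.wav" -> third '_' token "02" -> machine id 2 -> index 1. */
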
diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp
index 0078e44..bec70ab 100644
--- a/source/use_case/asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/asr/include/Wav2LetterModel.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.rved.
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp
new file mode 100644
index 0000000..d3de3b3
--- /dev/null
+++ b/source/use_case/kws_asr/include/KwsProcessing.hpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef KWS_PROCESSING_HPP
+#define KWS_PROCESSING_HPP
+
+#include <AudioUtils.hpp>
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "Classifier.hpp"
+#include "MicroNetKwsMfcc.hpp"
+
+#include <functional>
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for Keyword Spotting use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class KwsPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief       Constructor
+         * @param[in]   inputTensor        Pointer to the TFLite Micro input Tensor.
+         * @param[in]   numFeatures        How many MFCC features to use.
+         * @param[in]   numFeatureFrames   Number of MFCC vectors that need to be calculated
+         *                                 for an inference.
+         * @param[in]   mfccFrameLength    Number of audio samples used to calculate one set of MFCC values when
+         *                                 sliding a window through the audio sample.
+         * @param[in]   mfccFrameStride    Number of audio samples between consecutive windows.
+         **/
+        explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
+                               int mfccFrameLength, int mfccFrameStride);
+
+        /**
+         * @brief       Should perform pre-processing of 'raw' input audio data and load it into
+         *              TFLite Micro input tensors ready for inference.
+         * @param[in]   input      Pointer to the data that pre-processing will work on.
+         * @param[in]   inputSize  Size of the input data.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+        size_t m_audioWindowIndex = 0;  /* Index of audio slider, used when caching features in longer clips. */
+        size_t m_audioDataWindowSize;   /* Amount of audio needed for 1 inference. */
+        size_t m_audioDataStride;       /* Amount of audio to stride across if doing >1 inference in longer clips. */
+
+    private:
+        TfLiteTensor* m_inputTensor;    /* Model input tensor. */
+        const int m_mfccFrameLength;
+        const int m_mfccFrameStride;
+        const size_t m_numMfccFrames;   /* How many sets of m_numMfccFeats. */
+
+        audio::MicroNetKwsMFCC m_mfcc;
+        audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
+        size_t m_numMfccVectorsInAudioStride;
+        size_t m_numReusedMfccVectors;
+        std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
+
+        /**
+         * @brief Returns a function to perform feature calculation and populates input tensor data with
+         * MFCC data.
+         *
+         * Input tensor data type check is performed to choose correct MFCC feature data type.
+         * If tensor has an integer data type then original features are quantised.
+         *
+         * Warning: MFCC calculator provided as input must have the same life scope as returned function.
+         *
+         * @param[in]       mfcc          MFCC feature calculator.
+         * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
+         * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
+         * @return          Function to be called providing audio sample and sliding window index.
+         */
+        std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+        GetFeatureCalculator(audio::MicroNetKwsMFCC&  mfcc,
+                             TfLiteTensor*            inputTensor,
+                             size_t                   cacheSize);
+
+        template<class T>
+        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+        FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                    std::function<std::vector<T> (std::vector<int16_t>& )> compute);
+    };
+
+    /**
+     * @brief   Post-processing class for Keyword Spotting use case.
+     *          Implements methods declared by BasePostProcess and anything else needed
+     *          to populate result vector.
+     */
+    class KwsPostProcess : public BasePostProcess {
+
+    private:
+        TfLiteTensor* m_outputTensor;                   /* Model output tensor. */
+        Classifier& m_kwsClassifier;                    /* KWS Classifier object. */
+        const std::vector<std::string>& m_labels;       /* KWS Labels. */
+        std::vector<ClassificationResult>& m_results;   /* Results vector for a single inference. */
+
+    public:
+        /**
+         * @brief           Constructor
+         * @param[in]       outputTensor   Pointer to the TFLite Micro output Tensor.
+         * @param[in]       classifier     Classifier object used to get top N results from classification.
+         * @param[in]       labels         Vector of string labels to identify each output of the model.
+         * @param[in,out]   results        Vector of classification results to store decoded outputs.
+         **/
+        KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+                       const std::vector<std::string>& labels,
+                       std::vector<ClassificationResult>& results);
+
+        /**
+         * @brief    Should perform post-processing of the result of inference and then
+         *           populate the KWS result data for any later use.
+         * @return   true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* KWS_PROCESSING_HPP */
\ No newline at end of file
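
A hedged sketch of how a handler is expected to wire these classes up (`model`, `classifier`, `labels` and `audioData` are assumed to exist; the numeric parameters are typical MicroNet KWS settings, shown for illustration only):

    std::vector<arm::app::ClassificationResult> results;

    arm::app::KwsPreProcess preProcess{
        model.GetInputTensor(0),
        10,   /* numFeatures: MFCC coefficients per frame. */
        49,   /* numFeatureFrames: MFCC vectors per inference. */
        640,  /* mfccFrameLength in samples. */
        320}; /* mfccFrameStride in samples. */

    arm::app::KwsPostProcess postProcess{
        model.GetOutputTensor(0), classifier, labels, results};

    if (preProcess.DoPreProcess(audioData, preProcess.m_audioDataWindowSize) &&
            model.RunInference() &&
            postProcess.DoPostProcess()) {
        /* 'results' now holds the decoded keyword classifications. */
    }
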
diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
index 43bd390..af6ba5f 100644
--- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +24,7 @@
 namespace audio {
 
     /* Class to provide MicroNet specific MFCC calculation requirements. */
-    class MicroNetMFCC : public MFCC {
+    class MicroNetKwsMFCC : public MFCC {
 
     public:
         static constexpr uint32_t  ms_defaultSamplingFreq = 16000;
@@ -34,14 +34,14 @@
         static constexpr bool      ms_defaultUseHtkMethod =  true;
 
 
-        explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen)
+        explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen)
             :  MFCC(MfccParams(
                         ms_defaultSamplingFreq, ms_defaultNumFbankBins,
                         ms_defaultMelLoFreq, ms_defaultMelHiFreq,
                         numFeats, frameLen, ms_defaultUseHtkMethod))
         {}
-        MicroNetMFCC()  = delete;
-        ~MicroNetMFCC() = default;
+        MicroNetKwsMFCC()  = delete;
+        ~MicroNetKwsMFCC() = default;
     };
 
 } /* namespace audio */
diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
index 7c327b3..0e1adc5 100644
--- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,14 +34,18 @@
 namespace app {
 
     class Wav2LetterModel : public Model {
-        	
+
     public:
         /* Indices for the expected model - based on input and output tensor shapes */
-        static constexpr uint32_t ms_inputRowsIdx = 1;
-        static constexpr uint32_t ms_inputColsIdx = 2;
+        static constexpr uint32_t ms_inputRowsIdx  = 1;
+        static constexpr uint32_t ms_inputColsIdx  = 2;
         static constexpr uint32_t ms_outputRowsIdx = 2;
         static constexpr uint32_t ms_outputColsIdx = 3;
 
+        /* Model specific constants. */
+        static constexpr uint32_t ms_blankTokenIdx   = 28;
+        static constexpr uint32_t ms_numMfccFeatures = 13;
+
     protected:
         /** @brief   Gets the reference to op resolver interface class. */
         const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
index 029a641..d1bc9a2 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,88 +14,95 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP
-#define KWS_ASR_WAV2LET_POSTPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP
 
-#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers */
+#include "TensorFlowLiteMicro.hpp"   /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
+#include "log_macros.h"
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
     /**
      * @brief   Helper class to manage tensor post-processing for "wav2letter"
      *          output.
      */
-    class Postprocess {
+    class AsrPostProcess : public BasePostProcess {
     public:
+        bool m_lastIteration = false;   /* Set to true when processing the last set of data for a clip. */
+
         /**
-         * @brief       Constructor
-         * @param[in]   contextLen     Left and right context length for
-         *                             output tensor.
-         * @param[in]   innerLen       This is the length of the section
-         *                             between left and right context.
-         * @param[in]   blankTokenIdx  Blank token index.
+         * @brief           Constructor
+         * @param[in]       outputTensor       Pointer to the TFLite Micro output Tensor.
+         * @param[in]       classifier         Object used to get top N results from classification.
+         * @param[in]       labels             Vector of string labels to identify each output of the model.
+         * @param[in,out]   result             Vector of classification results to store decoded outputs.
+         * @param[in]       outputContextLen   Left/right context length for output tensor.
+         * @param[in]       blankTokenIdx      Index in the labels that the "Blank token" takes.
+         * @param[in]       reductionAxis      The axis that the logits of each time step are on.
          **/
-        Postprocess(uint32_t contextLen,
-                    uint32_t innerLen,
-                    uint32_t blankTokenIdx);
-
-        Postprocess() = delete;
-        ~Postprocess() = default;
+        AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+                       const std::vector<std::string>& labels, asr::ResultVec& result,
+                       uint32_t outputContextLen,
+                       uint32_t blankTokenIdx, uint32_t reductionAxis);
 
         /**
-         * @brief       Erases the required part of the tensor based
-         *              on context lengths set up during initialisation
-         * @param[in]   tensor          Pointer to the tensor
-         * @param[in]   axisIdx         Index of the axis on which erase is
-         *                              performed.
-         * @param[in]   lastIteration   Flag to signal is this is the
-         *                              last iteration in which case
-         *                              the right context is preserved.
-         * @return      true if successful, false otherwise.
-         */
-        bool Invoke(TfLiteTensor*  tensor,
-                    uint32_t axisIdx,
-                    bool lastIteration = false);
+         * @brief    Should perform post-processing of the result of inference then
+         *           populate ASR result data for any later use.
+         * @return   true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+
+        /** @brief   Gets the output inner length for post-processing. */
+        static uint32_t GetOutputInnerLen(const TfLiteTensor* outputTensor, uint32_t outputCtxLen);
+
+        /** @brief   Gets the output context length (left/right) for post-processing. */
+        static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+        /** @brief   Gets the number of feature vectors to be computed. */
+        static uint32_t GetNumFeatureVectors(const Model& model);
 
     private:
-        uint32_t    m_contextLen;      /* Lengths of left and right contexts. */
-        uint32_t    m_innerLen;        /* Length of inner context. */
-        uint32_t    m_totalLen;        /* Total length of the required axis. */
-        uint32_t    m_countIterations; /* Current number of iterations. */
-        uint32_t    m_blankTokenIdx;   /* Index of the labels blank token. */
+        AsrClassifier& m_classifier;                /* ASR Classifier object. */
+        TfLiteTensor* m_outputTensor;               /* Model output tensor. */
+        const std::vector<std::string>& m_labels;   /* ASR Labels. */
+        asr::ResultVec& m_results;                  /* Results vector for a single inference. */
+        uint32_t m_outputContextLen;                /* Lengths of left/right contexts for output. */
+        uint32_t m_outputInnerLen;                  /* Length of output inner context. */
+        uint32_t m_totalLen;                        /* Total length of the required axis. */
+        uint32_t m_countIterations;                 /* Current number of iterations. */
+        uint32_t m_blankTokenIdx;                   /* Index of the labels blank token. */
+        uint32_t m_reductionAxisIdx;                /* Axis containing output logits for a single step. */
+
         /**
-         * @brief       Checks if the tensor and axis index are valid
-         *              inputs to the object - based on how it has been
-         *              initialised.
-         * @return      true if valid, false otherwise.
+         * @brief    Checks if the tensor and axis index are valid
+         *           inputs to the object - based on how it has been initialised.
+         * @return   true if valid, false otherwise.
          */
         bool IsInputValid(TfLiteTensor*  tensor,
-                          const uint32_t axisIdx) const;
+                          uint32_t axisIdx) const;
 
         /**
-         * @brief       Gets the tensor data element size in bytes based
-         *              on the tensor type.
-         * @return      Size in bytes, 0 if not supported.
+         * @brief    Gets the tensor data element size in bytes based
+         *           on the tensor type.
+         * @return   Size in bytes, 0 if not supported.
          */
-        uint32_t GetTensorElementSize(TfLiteTensor* tensor);
+        static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
 
         /**
-         * @brief       Erases sections from the data assuming row-wise
-         *              arrangement along the context axis.
-         * @return      true if successful, false otherwise.
+         * @brief    Erases sections from the data assuming row-wise
+         *           arrangement along the context axis.
+         * @return   true if successful, false otherwise.
          */
         bool EraseSectionsRowWise(uint8_t* ptrData,
-                                  const uint32_t strideSzBytes,
-                                  const bool lastIteration);
-
+                                  uint32_t strideSzBytes,
+                                  bool lastIteration);
     };
 
-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
 
-#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */
\ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */
\ No newline at end of file
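
A sketch of how the new AsrPostProcess is wired up, pulling the model-specific
constants from Wav2LetterModel. The classifier, labels and results containers
are assumed to exist; asrResults is an asr::ResultVec:

    const uint32_t outputCtxLen =
            arm::app::AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
    arm::app::AsrPostProcess postProcess(
            model.GetOutputTensor(0), asrClassifier, asrLabels, asrResults,
            outputCtxLen,
            arm::app::Wav2LetterModel::ms_blankTokenIdx,   /* 28, defined above. */
            arm::app::Wav2LetterModel::ms_outputRowsIdx);  /* Reduction axis for logits. */

    postProcess.m_lastIteration = !audioDataSlider.HasNext();
    if (!postProcess.DoPostProcess()) {
        /* Handle the error. */
    }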
diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
index 3609c49..1224c23 100644
--- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,56 +14,51 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef KWS_ASR_WAV2LET_PREPROC_HPP
-#define KWS_ASR_WAV2LET_PREPROC_HPP
+#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP
+#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP
 
 #include "Wav2LetterModel.hpp"
 #include "Wav2LetterMfcc.hpp"
 #include "AudioUtils.hpp"
 #include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
 #include "log_macros.h"
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
     /* Class to facilitate pre-processing calculation for Wav2Letter model
      * for ASR. */
-    using AudioWindow = SlidingWindow <const int16_t>;
+    using AudioWindow = audio::SlidingWindow<const int16_t>;
 
-    class Preprocess {
+    class AsrPreProcess : public BasePreProcess {
     public:
         /**
-         * @brief       Constructor
-         * @param[in]   numMfccFeatures   Number of MFCC features per window.
-         * @param[in]   windowLen         Number of elements in a window.
-         * @param[in]   windowStride      Stride (in number of elements) for
-         *                                moving the window.
-         * @param[in]   numMfccVectors    Number of MFCC vectors per window.
-        */
-        Preprocess(
-            uint32_t  numMfccFeatures,
-            uint32_t  windowLen,
-            uint32_t  windowStride,
-            uint32_t  numMfccVectors);
-        Preprocess() = delete;
-        ~Preprocess() = default;
+         * @brief       Constructor.
+         * @param[in]   inputTensor        Pointer to the TFLite Micro input Tensor.
+         * @param[in]   numMfccFeatures    Number of MFCC features per window.
+         * @param[in]   numFeatureFrames   Number of MFCC vectors that need to be calculated
+         *                                 for an inference.
+         * @param[in]   mfccWindowLen      Number of audio elements to calculate MFCC features per window.
+         * @param[in]   mfccWindowStride   Stride (in number of elements) for moving the MFCC window.
+         */
+        AsrPreProcess(TfLiteTensor* inputTensor,
+                      uint32_t  numMfccFeatures,
+                      uint32_t  numFeatureFrames,
+                      uint32_t  mfccWindowLen,
+                      uint32_t  mfccWindowStride);
 
         /**
          * @brief       Calculates the features required from audio data. This
          *              includes MFCC, first and second order deltas,
          *              normalisation and finally, quantisation. The tensor is
-         *              populated with feature from a given window placed along
+         *              populated with features from a given window placed along
          *              in a single row.
          * @param[in]   audioData      Pointer to the first element of audio data.
          * @param[in]   audioDataLen   Number of elements in the audio data.
-         * @param[in]   tensor         Tensor to be populated.
          * @return      true if successful, false in case of error.
          */
-        bool Invoke(const int16_t * audioData,
-                    uint32_t  audioDataLen,
-                    TfLiteTensor *  tensor);
+        bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
 
     protected:
          /**
@@ -73,49 +68,32 @@
           * @param[in]  mfcc     MFCC buffers.
           * @param[out] delta1   Result of the first diff computation.
           * @param[out] delta2   Result of the second diff computation.
-          *
-          * @return true if successful, false otherwise.
+          * @return     true if successful, false otherwise.
           */
          static bool ComputeDeltas(Array2d<float>& mfcc,
                                    Array2d<float>& delta1,
                                    Array2d<float>& delta2);
 
         /**
-         * @brief       Given a 2D vector of floats, computes the mean.
-         * @param[in]   vec   Vector of vector of floats.
-         * @return      Mean value.
-         */
-        static float GetMean(Array2d<float>& vec);
-
-        /**
-         * @brief       Given a 2D vector of floats, computes the stddev.
-         * @param[in]   vec    Vector of vector of floats.
-         * @param[in]   mean   Mean value of the vector passed in.
-         * @return      stddev value.
-         */
-        static float GetStdDev(Array2d<float>& vec,
-                               const float mean);
-
-        /**
-         * @brief           Given a 2D vector of floats, normalises it using
-         *                  the mean and the stddev
+         * @brief           Given a 2D vector of floats, rescales it to have a mean of 0 and
+         *                  a standard deviation of 1.
          * @param[in,out]   vec   Vector of vector of floats.
          */
-        static void NormaliseVec(Array2d<float>& vec);
+        static void StandardizeVecF32(Array2d<float>& vec);
 
         /**
-         * @brief       Normalises the MFCC and delta buffers.
+         * @brief   Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
          */
-        void Normalise();
+        void Standarize();
 
         /**
          * @brief       Given the quantisation and data type limits, computes
          *              the quantised values of a floating point input data.
-         * @param[in]   elem            Element to be quantised.
-         * @param[in]   quantScale      Scale.
-         * @param[in]   quantOffset     Offset.
-         * @param[in]   minVal          Numerical limit - minimum.
-         * @param[in]   maxVal          Numerical limit - maximum.
+         * @param[in]   elem          Element to be quantised.
+         * @param[in]   quantScale    Scale.
+         * @param[in]   quantOffset   Offset.
+         * @param[in]   minVal        Numerical limit - minimum.
+         * @param[in]   maxVal        Numerical limit - maximum.
          * @return      Floating point quantised value.
          */
         static float GetQuantElem(
@@ -133,44 +111,43 @@
          *              this being the convolution speed up (as we can use
          *              contiguous memory). The output, however, requires the
          *              time axis to be in column major arrangement.
-         * @param[in]   outputBuf       Pointer to the output buffer.
-         * @param[in]   outputBufSz     Output buffer's size.
-         * @param[in]   quantScale      Quantisation scale.
-         * @param[in]   quantOffset     Quantisation offset.
+         * @param[in]   outputBuf     Pointer to the output buffer.
+         * @param[in]   outputBufSz   Output buffer's size.
+         * @param[in]   quantScale    Quantisation scale.
+         * @param[in]   quantOffset   Quantisation offset.
          */
         template <typename T>
         bool Quantise(
-                T *             outputBuf,
+                T*              outputBuf,
                 const uint32_t  outputBufSz,
                 const float     quantScale,
                 const int       quantOffset)
         {
-            /* Check the output size will for everything. */
+            /* Check the output size will fit everything. */
             if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) {
                 printf_err("Tensor size too small for features\n");
                 return false;
             }
 
             /* Populate. */
-            T * outputBufMfcc = outputBuf;
-            T * outputBufD1 = outputBuf + this->m_numMfccFeats;
-            T * outputBufD2 = outputBufD1 + this->m_numMfccFeats;
+            T* outputBufMfcc = outputBuf;
+            T* outputBufD1 = outputBuf + this->m_numMfccFeats;
+            T* outputBufD2 = outputBufD1 + this->m_numMfccFeats;
             const uint32_t ptrIncr = this->m_numMfccFeats * 2;  /* (3 vectors - 1 vector) */
 
             const float minVal = std::numeric_limits<T>::min();
             const float maxVal = std::numeric_limits<T>::max();
 
-            /* We need to do a transpose while copying and concatenating
-             * the tensor. */
-            for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+            /* Need to transpose while copying and concatenating the tensor. */
+            for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
                 for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
-                    *outputBufMfcc++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                             this->m_mfccBuf(i, j), quantScale,
                             quantOffset, minVal, maxVal));
-                    *outputBufD1++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                             this->m_delta1Buf(i, j), quantScale,
                             quantOffset, minVal, maxVal));
-                    *outputBufD2++ = static_cast<T>(this->GetQuantElem(
+                    *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem(
                             this->m_delta2Buf(i, j), quantScale,
                             quantOffset, minVal, maxVal));
                 }
@@ -183,24 +160,23 @@
         }
 
     private:
-        Wav2LetterMFCC      m_mfcc;            /* MFCC instance. */
+        audio::Wav2LetterMFCC   m_mfcc;          /* MFCC instance. */
+        TfLiteTensor*           m_inputTensor;   /* Model input tensor. */
 
         /* Actual buffers to be populated. */
-        Array2d<float>      m_mfccBuf;         /* Contiguous buffer 1D: MFCC */
-        Array2d<float>      m_delta1Buf;       /* Contiguous buffer 1D: Delta 1 */
-        Array2d<float>      m_delta2Buf;       /* Contiguous buffer 1D: Delta 2 */
+        Array2d<float>   m_mfccBuf;              /* Contiguous buffer 1D: MFCC */
+        Array2d<float>   m_delta1Buf;            /* Contiguous buffer 1D: Delta 1 */
+        Array2d<float>   m_delta2Buf;            /* Contiguous buffer 1D: Delta 2 */
 
-        uint32_t            m_windowLen;       /* Window length for MFCC. */
-        uint32_t            m_windowStride;    /* Window stride len for MFCC. */
-        uint32_t            m_numMfccFeats;    /* Number of MFCC features per window. */
-        uint32_t            m_numFeatVectors;  /* Number of m_numMfccFeats. */
-        AudioWindow         m_window;          /* Sliding window. */
+        uint32_t         m_mfccWindowLen;        /* Window length for MFCC. */
+        uint32_t         m_mfccWindowStride;     /* Window stride len for MFCC. */
+        uint32_t         m_numMfccFeats;         /* Number of MFCC features per window. */
+        uint32_t         m_numFeatureFrames;     /* How many sets of m_numMfccFeats. */
+        AudioWindow      m_mfccSlidingWindow;    /* Sliding window to calculate MFCCs. */
 
     };
 
-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
 
-#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */
\ No newline at end of file
+#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */
\ No newline at end of file
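
AsrPreProcess computes the MFCCs plus first and second order deltas and writes
them into the input tensor. A construction sketch; the window length and stride
values are illustrative and the feature-frame count is derived via the
post-processing helper:

    arm::app::AsrPreProcess preProcess(
            model.GetInputTensor(0),
            arm::app::Wav2LetterModel::ms_numMfccFeatures,          /* 13 MFCC features. */
            arm::app::AsrPostProcess::GetNumFeatureVectors(model),  /* Frames per inference. */
            512,   /* MFCC window length in samples (illustrative). */
            160);  /* MFCC window stride in samples (illustrative). */

    if (!preProcess.DoPreProcess(audioWindow, audioWindowLen)) {
        /* Handle the error. */
    }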
diff --git a/source/use_case/kws_asr/src/KwsProcessing.cc b/source/use_case/kws_asr/src/KwsProcessing.cc
new file mode 100644
index 0000000..328709d
--- /dev/null
+++ b/source/use_case/kws_asr/src/KwsProcessing.cc
@@ -0,0 +1,212 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "KwsProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+#include "MicroNetKwsModel.hpp"
+
+namespace arm {
+namespace app {
+
+    KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
+                                 int mfccFrameLength, int mfccFrameStride):
+        m_inputTensor{inputTensor},
+        m_mfccFrameLength{mfccFrameLength},
+        m_mfccFrameStride{mfccFrameStride},
+        m_numMfccFrames{numMfccFrames},
+        m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
+    {
+        this->m_mfcc.Init();
+
+        /* Deduce the data length required for 1 inference from the network parameters. */
+        this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
+                (this->m_mfccFrameLength - this->m_mfccFrameStride);
+
+        /* Creating an MFCC feature sliding window for the data required for 1 inference. */
+        this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize,
+                this->m_mfccFrameLength, this->m_mfccFrameStride);
+
+        /* For longer audio clips we choose to move by half the audio window size
+         * => for a 1 second window size there is an overlap of 0.5 seconds. */
+        this->m_audioDataStride = this->m_audioDataWindowSize / 2;
+
+        /* To make previously calculated features re-usable, the audio stride must be a
+         * multiple of the MFCC window stride. Reduce the stride through the audio if needed. */
+        if (0 != this->m_audioDataStride % this->m_mfccFrameStride) {
+            this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride;
+        }
+
+        this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride;
+
+        /* Calculate the number of feature vectors in the window overlap region.
+         * These feature vectors will be reused. */
+        this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1
+                - this->m_numMfccVectorsInAudioStride;
+
+        /* Construct feature calculation function. */
+        this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
+                                                             this->m_numReusedMfccVectors);
+
+        if (!this->m_mfccFeatureCalculator) {
+            printf_err("Feature calculator not initialized.");
+        }
+    }
+
+    bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        UNUSED(inputSize);
+        if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+            return false;
+        }
+
+        /* Set the features sliding window to the new address. */
+        auto input = static_cast<const int16_t*>(data);
+        this->m_mfccSlidingWindow.Reset(input);
+
+        /* Cache is only usable if we have more than 1 inference in an audio clip. */
+        bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0;
+
+        /* Use a sliding window to calculate MFCC features frame by frame. */
+        while (this->m_mfccSlidingWindow.HasNext()) {
+            const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
+
+            std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow,
+                    mfccWindow + this->m_mfccFrameLength);
+
+            /* Compute features for this window and write them to input tensor. */
+            this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(),
+                                          useCache, this->m_numMfccVectorsInAudioStride);
+        }
+
+        debug("Input tensor populated \n");
+
+        return true;
+    }
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns a lambda function that computes features using a feature cache.
+     * The actual feature computation is done by a lambda function provided as a parameter.
+     * Features are written to the input tensor memory.
+     *
+     * @tparam T                Feature vector type.
+     * @param[in] inputTensor   Model input tensor pointer.
+     * @param[in] cacheSize     Number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param[in] compute       Features calculator function.
+     * @return                  Lambda function to compute features.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+    KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                               std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    {
+        /* Feature cache to be captured by lambda function. */
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from cache if cache is ready and sliding windows overlap.
+             * Overlap is in the beginning of sliding window with a size of a feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size();
+            auto sizeBytes = sizeof(T) * size;
+            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
+
+            /* Start renewing the cache as soon as the iteration goes out of the window overlap. */
+            if (index >= featuresOverlapIndex) {
+                featureCache[index - featuresOverlapIndex] = std::move(features);
+            }
+        };
+    }
+
+    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+    KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+                                       size_t cacheSize,
+                                       std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
+    KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
+                                      size_t cacheSize,
+                                      std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
+
+    std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+    KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+    {
+        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
+
+        TfLiteQuantization quant = inputTensor->quantization;
+
+        if (kTfLiteAffineQuantization == quant.type) {
+            auto *quantParams = (TfLiteAffineQuantization *) quant.params;
+            const float quantScale = quantParams->scale->data[0];
+            const int quantOffset = quantParams->zero_point->data[0];
+
+            switch (inputTensor->type) {
+                case kTfLiteInt8: {
+                    mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor,
+                                                          cacheSize,
+                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
+                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
+                                                                                                   quantScale,
+                                                                                                   quantOffset);
+                                                          }
+                    );
+                    break;
+                }
+                default:
+                printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
+            }
+        } else {
+            mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize,
+                    [&mfcc](std::vector<int16_t>& audioDataWindow) {
+                        return mfcc.MfccCompute(audioDataWindow);
+                    });
+        }
+        return mfccFeatureCalc;
+    }
+
+    KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
+                                   const std::vector<std::string>& labels,
+                                   std::vector<ClassificationResult>& results)
+            : m_outputTensor{outputTensor},
+             m_kwsClassifier{classifier},
+             m_labels{labels},
+             m_results{results}
+    {}
+
+    bool KwsPostProcess::DoPostProcess()
+    {
+        return this->m_kwsClassifier.GetClassificationResults(
+                this->m_outputTensor, this->m_results,
+                this->m_labels, 1, true);
+    }
+
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file
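
Worked through with typical MicroNet KWS parameters (illustrative values,
assuming SlidingWindow::TotalStrides() is (size - window) / stride), the
constructor arithmetic above resolves as follows:

    /* numMfccFrames = 49, mfccFrameLength = 640, mfccFrameStride = 320, 16 kHz audio:
     * m_audioDataWindowSize          = 49 * 320 + (640 - 320)  = 16000 samples (1 s).
     * m_audioDataStride              = 16000 / 2               = 8000 (a multiple of 320).
     * m_numMfccVectorsInAudioStride  = 8000 / 320              = 25.
     * MFCC sliding window strides    = (16000 - 640) / 320     = 48, so
     * m_numReusedMfccVectors         = 48 + 1 - 25             = 24 cached vectors. */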
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index 5c1d0e0..f1d97a0 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "hal.h"                    /* Brings in platform definitions. */
 #include "InputFiles.hpp"           /* For input images. */
 #include "Labels_micronetkws.hpp"   /* For MicroNetKws label strings. */
 #include "Labels_wav2letter.hpp"    /* For Wav2Letter label strings. */
@@ -24,8 +23,6 @@
 #include "Wav2LetterModel.hpp"      /* ASR model class for running inference. */
 #include "UseCaseCommonUtils.hpp"   /* Utils functions. */
 #include "UseCaseHandler.hpp"       /* Handlers for different user options. */
-#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */
-#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */
 #include "log_macros.h"
 
 using KwsClassifier = arm::app::Classifier;
@@ -53,19 +50,8 @@
     fflush(stdout);
 }
 
-/** @brief Gets the number of MFCC features for a single window. */
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model);
-
-/** @brief Gets the number of MFCC feature vectors to be computed. */
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);
-
-/** @brief Gets the output context length (left and right) for post-processing. */
-static uint32_t GetOutputContextLen(const arm::app::Model& model,
-                                    uint32_t inputCtxLen);
-
-/** @brief Gets the output inner length for post-processing. */
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
-                                  uint32_t outputCtxLen);
+/** @brief   Verify that the input and output tensors have the minimum expected dimensionality. */
+static bool VerifyTensorDimensions(const arm::app::Model& model);
 
 void main_loop()
 {
@@ -84,61 +70,46 @@
     if (!asrModel.Init(kwsModel.GetAllocator())) {
         printf_err("Failed to initialise ASR model\n");
         return;
+    } else if (!VerifyTensorDimensions(asrModel)) {
+        printf_err("Model's input or output dimension verification failed\n");
+        return;
     }
 
-    /* Initialise ASR pre-processing. */
-    arm::app::audio::asr::Preprocess prep(
-            GetNumMfccFeatures(asrModel),
-            arm::app::asr::g_FrameLength,
-            arm::app::asr::g_FrameStride,
-            GetNumMfccFeatureVectors(asrModel));
-
-    /* Initialise ASR post-processing. */
-    const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen);
-    const uint32_t blankTokenIdx = 28;
-    arm::app::audio::asr::Postprocess postp(
-            outputCtxLen,
-            GetOutputInnerLen(asrModel, outputCtxLen),
-            blankTokenIdx);
-
     /* Instantiate application context. */
     arm::app::ApplicationContext caseContext;
 
     arm::app::Profiler profiler{"kws_asr"};
     caseContext.Set<arm::app::Profiler&>("profiler", profiler);
-    caseContext.Set<arm::app::Model&>("kwsmodel", kwsModel);
-    caseContext.Set<arm::app::Model&>("asrmodel", asrModel);
+    caseContext.Set<arm::app::Model&>("kwsModel", kwsModel);
+    caseContext.Set<arm::app::Model&>("asrModel", asrModel);
     caseContext.Set<uint32_t>("clipIndex", 0);
     caseContext.Set<uint32_t>("ctxLen", arm::app::asr::g_ctxLen);  /* Left and right context length (MFCC feat vectors). */
-    caseContext.Set<int>("kwsframeLength", arm::app::kws::g_FrameLength);
-    caseContext.Set<int>("kwsframeStride", arm::app::kws::g_FrameStride);
-    caseContext.Set<float>("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold);  /* Normalised score threshold. */
+    caseContext.Set<int>("kwsFrameLength", arm::app::kws::g_FrameLength);
+    caseContext.Set<int>("kwsFrameStride", arm::app::kws::g_FrameStride);
+    caseContext.Set<float>("kwsScoreThreshold", arm::app::kws::g_ScoreThreshold);  /* Normalised score threshold. */
     caseContext.Set<uint32_t >("kwsNumMfcc", arm::app::kws::g_NumMfcc);
     caseContext.Set<uint32_t >("kwsNumAudioWins", arm::app::kws::g_NumAudioWins);
 
-    caseContext.Set<int>("asrframeLength", arm::app::asr::g_FrameLength);
-    caseContext.Set<int>("asrframeStride", arm::app::asr::g_FrameStride);
-    caseContext.Set<float>("asrscoreThreshold", arm::app::asr::g_ScoreThreshold);  /* Normalised score threshold. */
+    caseContext.Set<int>("asrFrameLength", arm::app::asr::g_FrameLength);
+    caseContext.Set<int>("asrFrameStride", arm::app::asr::g_FrameStride);
+    caseContext.Set<float>("asrScoreThreshold", arm::app::asr::g_ScoreThreshold);  /* Normalised score threshold. */
 
     KwsClassifier kwsClassifier;  /* Classifier wrapper object. */
     arm::app::AsrClassifier asrClassifier;  /* Classifier wrapper object. */
-    caseContext.Set<arm::app::Classifier&>("kwsclassifier", kwsClassifier);
-    caseContext.Set<arm::app::AsrClassifier&>("asrclassifier", asrClassifier);
-
-    caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
-    caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);
+    caseContext.Set<arm::app::Classifier&>("kwsClassifier", kwsClassifier);
+    caseContext.Set<arm::app::AsrClassifier&>("asrClassifier", asrClassifier);
 
     std::vector<std::string> asrLabels;
     arm::app::asr::GetLabelsVector(asrLabels);
     std::vector<std::string> kwsLabels;
     arm::app::kws::GetLabelsVector(kwsLabels);
-    caseContext.Set<const std::vector <std::string>&>("asrlabels", asrLabels);
-    caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels);
+    caseContext.Set<const std::vector <std::string>&>("asrLabels", asrLabels);
+    caseContext.Set<const std::vector <std::string>&>("kwsLabels", kwsLabels);
 
     /* KWS keyword that triggers ASR and associated checks */
-    std::string triggerKeyword = std::string("yes");
+    std::string triggerKeyword = std::string("no");
     if (std::find(kwsLabels.begin(), kwsLabels.end(), triggerKeyword) != kwsLabels.end()) {
-        caseContext.Set<const std::string &>("triggerkeyword", triggerKeyword);
+        caseContext.Set<const std::string &>("triggerKeyword", triggerKeyword);
     }
     else {
         printf_err("Selected trigger keyword not found in labels file\n");
@@ -196,50 +167,26 @@
     info("Main loop terminated.\n");
 }
 
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
+static bool VerifyTensorDimensions(const arm::app::Model& model)
 {
+    /* Populate tensor related parameters. */
     TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
-    if (0 != inputCols % 3) {
-        printf_err("Number of input columns is not a multiple of 3\n");
-    }
-    return std::max(inputCols/3, 0);
-}
-
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
-{
-    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
-    return std::max(inputRows, 0);
-}
-
-static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
-{
-    const uint32_t inputRows = GetNumMfccFeatureVectors(model);
-    const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
-    constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-
-    /* Check to make sure that the input tensor supports the above context and inner lengths. */
-    if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
-        printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
-                   inputCtxLen);
-        return 0;
+    if (!inputTensor->dims) {
+        printf_err("Invalid input tensor dims\n");
+        return false;
+    } else if (inputTensor->dims->size < 3) {
+        printf_err("Input tensor dimension should be >= 3\n");
+        return false;
     }
 
     TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-    const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
+    if (!outputTensor->dims) {
+        printf_err("Invalid output tensor dims\n");
+        return false;
+    } else if (outputTensor->dims->size < 3) {
+        printf_err("Output tensor dimension should be >= 3\n");
+        return false;
+    }
 
-    const float tensorColRatio = static_cast<float>(inputRows)/
-                                 static_cast<float>(outputRows);
-
-    return std::round(static_cast<float>(inputCtxLen)/tensorColRatio);
-}
-
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
-                                  const uint32_t outputCtxLen)
-{
-    constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-    const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-    return (outputRows - (2 * outputCtxLen));
+    return true;
 }
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1e1a400..01aefae 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -28,6 +28,7 @@
 #include "Wav2LetterMfcc.hpp"
 #include "Wav2LetterPreprocess.hpp"
 #include "Wav2LetterPostprocess.hpp"
+#include "KwsProcessing.hpp"
 #include "AsrResult.hpp"
 #include "AsrClassifier.hpp"
 #include "OutputDecode.hpp"
@@ -39,11 +40,6 @@
 namespace arm {
 namespace app {
 
-    enum AsrOutputReductionAxis {
-        AxisRow = 1,
-        AxisCol = 2
-    };
-
     struct KWSOutput {
         bool executionSuccess = false;
         const int16_t* asrAudioStart = nullptr;
@@ -51,73 +47,53 @@
     };
 
     /**
-     * @brief           Presents kws inference results using the data presentation
-     *                  object.
-     * @param[in]       results     vector of classification results to be displayed
-     * @return          true if successful, false otherwise
+     * @brief       Presents KWS inference results.
+     * @param[in]   results   Vector of KWS classification results to be displayed.
+     * @return      true if successful, false otherwise.
      **/
-    static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results);
+    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);
 
     /**
-     * @brief           Presents asr inference results using the data presentation
-     *                  object.
-     * @param[in]       platform    reference to the hal platform object
-     * @param[in]       results     vector of classification results to be displayed
-     * @return          true if successful, false otherwise
+     * @brief       Presents ASR inference results.
+     * @param[in]   results   Vector of ASR classification results to be displayed.
+     * @return      true if successful, false otherwise.
      **/
-    static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results);
+    static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);
 
     /**
-     * @brief Returns a function to perform feature calculation and populates input tensor data with
-     * MFCC data.
-     *
-     * Input tensor data type check is performed to choose correct MFCC feature data type.
-     * If tensor has an integer data type then original features are quantised.
-     *
-     * Warning: mfcc calculator provided as input must have the same life scope as returned function.
-     *
-     * @param[in]           mfcc            MFCC feature calculator.
-     * @param[in,out]       inputTensor     Input tensor pointer to store calculated features.
-     * @param[in]           cacheSize       Size of the feature vectors cache (number of feature vectors).
-     *
-     * @return function     function to be called providing audio sample and sliding window index.
+     * @brief           Performs the KWS pipeline.
+     * @param[in,out]   ctx   Pointer to the application context object.
+     * @return          Struct containing a pointer to the audio data where ASR should begin
+     *                  and how much data to process.
      **/
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetMFCC&  mfcc,
-                         TfLiteTensor*      inputTensor,
-                         size_t             cacheSize);
+    static KWSOutput doKws(ApplicationContext& ctx)
+    {
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto& kwsModel = ctx.Get<Model&>("kwsModel");
+        const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
+        const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
+        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
 
-    /**
-     * @brief Performs the KWS pipeline.
-     * @param[in,out]   ctx pointer to the application context object
-     *
-     * @return KWSOutput    struct containing pointer to audio data where ASR should begin
-     *                      and how much data to process.
-     */
-    static KWSOutput doKws(ApplicationContext& ctx) {
+        auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
         constexpr int minTensorDims = static_cast<int>(
-            (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
-             arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
+            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
+             MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);
 
-        KWSOutput output;
+        /* Output struct from doing KWS. */
+        KWSOutput output {};
 
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
         if (!kwsModel.IsInited()) {
             printf_err("KWS model has not been initialised\n");
             return output;
         }
 
-        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
-        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
-        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");
-
-        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
+        /* Get Input and Output tensors for pre/post processing. */
         TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
-
+        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
         if (!kwsInputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
             return output;
@@ -126,63 +102,32 @@
             return output;
         }
 
-        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
-        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
-
-        audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
-        kwsMfcc.Init();
-
-        /* Deduce the data length required for 1 KWS inference from the network parameters. */
-        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
-                                        (kwsFrameLength - kwsFrameStride);
-        auto kwsMfccWindowSize = kwsFrameLength;
-        auto kwsMfccWindowStride = kwsFrameStride;
-
-        /* We are choosing to move by half the window size => for a 1 second window size,
-         * this means an overlap of 0.5 seconds. */
-        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;
-
-        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
-
-        /* Stride must be multiple of mfcc features window stride to re-use features. */
-        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
-            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
-        }
-
-        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
+        /* Get input shape for feature extraction. */
+        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
+        const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
+        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
 
         /* We expect to be sampling 1 second worth of data at a time
          * NOTE: This is only used for time stamp calculation. */
-        const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
+        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
 
-        auto currentIndex = ctx.Get<uint32_t>("clipIndex");
+        /* Set up pre and post-processing. */
+        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
+                                                 kwsMfccFrameLength, kwsMfccFrameStride);
 
-        /* Creating a mfcc features sliding window for the data required for 1 inference. */
-        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
-                get_audio_array(currentIndex),
-                kwsAudioDataWindowSize, kwsMfccWindowSize,
-                kwsMfccWindowStride);
+        std::vector<ClassificationResult> singleInfResult;
+        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier &>("kwsClassifier"),
+                                                    ctx.Get<std::vector<std::string>&>("kwsLabels"),
+                                                    singleInfResult);
 
         /* Creating a sliding window through the whole audio clip. */
         auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                 get_audio_array(currentIndex),
                 get_audio_array_size(currentIndex),
-                kwsAudioDataWindowSize, kwsAudioDataStride);
+                preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
 
-        /* Calculate number of the feature vectors in the window overlap region.
-         * These feature vectors will be reused.*/
-        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
-                                              - kwsMfccVectorsInAudioStride;
-
-        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
-                                                       numberOfReusedFeatureVectors);
-
-        if (!kwsMfccFeatureCalc){
-            return output;
-        }
-
-        /* Container for KWS results. */
-        std::vector<arm::app::kws::KwsResult> kwsResults;
+        /* Declare a container to hold kws results from across the whole audio clip. */
+        std::vector<kws::KwsResult> finalResults;
 
         /* Display message on the LCD - inference running. */
         std::string str_inf{"Running KWS inference... "};
@@ -197,70 +142,56 @@
         while (audioDataSlider.HasNext()) {
             const int16_t* inferenceWindow = audioDataSlider.Next();
 
-            /* We moved to the next window - set the features sliding to the new address. */
-            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);
-
             /* The first window does not have cache ready. */
-            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
+            preProcess.m_audioWindowIndex = audioDataSlider.Index();
 
-            /* Start calculating features inside one audio sliding window. */
-            while (kwsAudioMFCCWindowSlider.HasNext()) {
-                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
-                std::vector<int16_t> kwsMfccAudioData =
-                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
+                printf_err("KWS Pre-processing failed.");
+                return output;
+            }
 
-                /* Compute features for this window and write them to input tensor. */
-                kwsMfccFeatureCalc(kwsMfccAudioData,
-                                   kwsAudioMFCCWindowSlider.Index(),
-                                   useCache,
-                                   kwsMfccVectorsInAudioStride);
+            if (!RunInference(kwsModel, profiler)) {
+                printf_err("KWS Inference failed.");
+                return output;
+            }
+
+            if (!postProcess.DoPostProcess()) {
+                printf_err("KWS Post-processing failed.");
+                return output;
             }
 
             info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                  audioDataSlider.TotalStrides() + 1);
 
-            /* Run inference over this audio clip sliding window. */
-            if (!RunInference(kwsModel, profiler)) {
-                printf_err("KWS inference failed\n");
-                return output;
-            }
+            /* Add results from this window to our final results vector. */
+            finalResults.emplace_back(
+                    kws::KwsResult(singleInfResult,
+                            audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
+                            audioDataSlider.Index(), kwsScoreThreshold));
 
-            std::vector<ClassificationResult> kwsClassificationResult;
-            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
-
-            kwsClassifier.GetClassificationResults(
-                            kwsOutputTensor, kwsClassificationResult,
-                            ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
-
-            kwsResults.emplace_back(
-                kws::KwsResult(
-                    kwsClassificationResult,
-                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
-                    audioDataSlider.Index(), kwsScoreThreshold)
-                );
-
-            /* Keyword detected. */
-            if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
-                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
+            /* Break out when trigger keyword is detected. */
+            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
+                    && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
+                output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
                 output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                         (audioDataSlider.NextWindowStartIndex() -
-                                        kwsAudioDataStride + kwsAudioDataWindowSize);
+                                        preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
                 break;
             }
 
 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(kwsOutputTensor);
+            DumpTensor(kwsOutputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
 
         } /* while (audioDataSlider.HasNext()) */
 
         /* Erase. */
         str_inf = std::string(str_inf.size(), ' ');
-        hal_lcd_display_text(
-                            str_inf.c_str(), str_inf.size(),
-                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-        if (!PresentInferenceResult(kwsResults)) {
+        if (!PresentInferenceResult(finalResults)) {
             return output;
         }
 
@@ -271,41 +202,41 @@
     }
 
     /**
-     * @brief Performs the ASR pipeline.
-     *
-     * @param[in,out] ctx   pointer to the application context object
-     * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin
-     *                      and how much data to process
-     * @return bool         true if pipeline executed without failure
-     */
-    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
+     * @brief           Performs the ASR pipeline.
+     * @param[in,out]   ctx         Pointer to the application context object.
+     * @param[in]       kwsOutput   Struct containing pointer to audio data where ASR should begin
+     *                              and how much data to process.
+     * @return          true if pipeline executed without failure.
+     **/
+    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
+    {
+        auto& asrModel = ctx.Get<Model&>("asrModel");
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
+        auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
+        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
+        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-        hal_lcd_clear(COLOR_BLACK);
-
-        /* Get model reference. */
-        auto& asrModel = ctx.Get<Model&>("asrmodel");
         if (!asrModel.IsInited()) {
             printf_err("ASR model has not been initialised\n");
             return false;
         }
 
-        /* Get score threshold to be applied for the classifier (post-inference). */
-        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");
+        hal_lcd_clear(COLOR_BLACK);
 
-        /* Dimensions of the tensor should have been verified by the callee. */
+        /* Get Input and Output tensors for pre/post processing. */
         TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
         TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
-        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
 
-        /* Populate ASR MFCC related parameters. */
-        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
-        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");
+        /* Get input shape. Dimensions of the tensor should have been verified by
+         * the callee. */
+        TfLiteIntArray* inputShape = asrModel.GetInputShape(0);
 
-        /* Populate ASR inference context and inner lengths for input. */
-        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
+        const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
         const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
 
         /* Make sure the input tensor supports the above context and inner lengths. */
@@ -316,18 +247,9 @@
         }
 
         /* Audio data stride corresponds to inputInnerLen feature vectors. */
-        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
-                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
-        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
-        const float asrAudioParamsSecondsPerSample =
-                                        (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
-
-        /* Get pre/post-processing objects */
-        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
-        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");
-
-        /* Set default reduction axis for post-processing. */
-        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
+        const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
+        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
 
         /* Get the remaining audio buffer and respective size from KWS results. */
         const int16_t* audioArr = kwsOutput.asrAudioStart;
@@ -335,9 +257,9 @@
 
         /* Audio clip must have enough samples to produce 1 MFCC feature. */
         std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
-        if (audioArrSize < asrMfccParamsWinLen) {
+        if (audioArrSize < asrMfccFrameLen) {
             printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
-                asrMfccParamsWinLen);
+                       asrMfccFrameLen);
             return false;
         }
 
@@ -345,26 +267,38 @@
         auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                 audioBuffer.data(),
                 audioBuffer.size(),
-                asrAudioParamsWinLen,
-                asrAudioParamsWinStride);
+                asrAudioDataWindowLen,
+                asrAudioDataWindowStride);
 
         /* Declare a container for results. */
-        std::vector<arm::app::asr::AsrResult> asrResults;
+        std::vector<asr::AsrResult> asrResults;
 
         /* Display message on the LCD - inference running. */
         std::string str_inf{"Running ASR inference... "};
-        hal_lcd_display_text(
-                str_inf.c_str(), str_inf.size(),
+        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-        size_t asrInferenceWindowLen = asrAudioParamsWinLen;
+        size_t asrInferenceWindowLen = asrAudioDataWindowLen;
 
+        /* Set up pre and post-processing objects. */
+        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, Wav2LetterModel::ms_numMfccFeatures,
+                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+                                                    asrMfccFrameLen, asrMfccFrameStride);
+
+        std::vector<ClassificationResult> singleInfResult;
+        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
+        AsrPostProcess asrPostProcess = AsrPostProcess(
+                asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
+                ctx.Get<std::vector<std::string>&>("asrLabels"),
+                singleInfResult, outputCtxLen,
+                Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+        );
+
         /* Start sliding through audio clip. */
         while (audioDataSlider.HasNext()) {
 
             /* If not enough audio see how much can be sent for processing. */
             size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
-            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
+            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                 asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
             }
 
@@ -373,8 +307,11 @@
             info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
 
-            /* Calculate MFCCs, deltas and populate the input tensor. */
-            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
+                printf_err("ASR pre-processing failed.");
+                return false;
+            }
 
             /* Run inference over this audio clip sliding window. */
             if (!RunInference(asrModel, profiler)) {
@@ -382,24 +319,28 @@
                 return false;
             }
 
-            /* Post-process. */
-            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
+            /* Post processing needs to know if we are on the last audio window. */
+            asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
+            if (!asrPostProcess.DoPostProcess()) {
+                printf_err("ASR post-processing failed.");
+                return false;
+            }
 
             /* Get results. */
             std::vector<ClassificationResult> asrClassificationResult;
-            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
+            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
             asrClassifier.GetClassificationResults(
                     asrOutputTensor, asrClassificationResult,
-                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
+                    ctx.Get<std::vector<std::string>&>("asrLabels"), 1);
 
             asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                 (audioDataSlider.Index() *
                                                  asrAudioParamsSecondsPerSample *
-                                                 asrAudioParamsWinStride),
+                                                 asrAudioDataWindowStride),
                                                  audioDataSlider.Index(), asrScoreThreshold));
 
 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+            DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */
 
             /* Erase */
@@ -417,7 +358,7 @@
         return true;
     }
 
-    /* Audio inference classification handler. */
+    /* KWS and ASR inference handler. */
     bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
         hal_lcd_clear(COLOR_BLACK);
@@ -434,13 +375,14 @@
         do {
             KWSOutput kwsOutput = doKws(ctx);
             if (!kwsOutput.executionSuccess) {
+                printf_err("KWS failed\n");
                 return false;
             }
 
             if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
-                info("Keyword spotted\n");
+                info("Trigger keyword spotted\n");
                 if(!doAsr(ctx, kwsOutput)) {
-                    printf_err("ASR failed");
+                    printf_err("ASR failed\n");
                     return false;
                 }
             }
@@ -452,7 +394,6 @@
         return true;
     }
 
-
     static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -464,33 +405,31 @@
         /* Display each result. */
         uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
 
-        for (uint32_t i = 0; i < results.size(); ++i) {
-
+        for (auto & result : results) {
             std::string topKeyword{"<none>"};
             float score = 0.f;
 
-            if (!results[i].m_resultVec.empty()) {
-                topKeyword = results[i].m_resultVec[0].m_label;
-                score = results[i].m_resultVec[0].m_normalisedVal;
+            if (!result.m_resultVec.empty()) {
+                topKeyword = result.m_resultVec[0].m_label;
+                score = result.m_resultVec[0].m_normalisedVal;
             }
 
             std::string resultStr =
-                    std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+                    std::string{"@"} + std::to_string(result.m_timeStamp) +
                     std::string{"s: "} + topKeyword + std::string{" ("} +
                     std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
 
-            hal_lcd_display_text(
-                        resultStr.c_str(), resultStr.size(),
-                        dataPsnTxtStartX1, rowIdx1, 0);
+            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
+                    dataPsnTxtStartX1, rowIdx1, 0);
             rowIdx1 += dataPsnTxtYIncr;
 
             info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
-                 results[i].m_timeStamp, results[i].m_inferenceNumber,
-                 results[i].m_threshold);
-            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+                 result.m_timeStamp, result.m_inferenceNumber,
+                 result.m_threshold);
+            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
-                     results[i].m_resultVec[j].m_label.c_str(),
-                     results[i].m_resultVec[j].m_normalisedVal);
+                     result.m_resultVec[j].m_label.c_str(),
+                     result.m_resultVec[j].m_normalisedVal);
             }
         }
 
@@ -523,143 +462,12 @@
 
         std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
 
-        hal_lcd_display_text(
-                    finalResultStr.c_str(), finalResultStr.size(),
-                    dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
+        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
 
         info("Final result: %s\n", finalResultStr.c_str());
         return true;
     }
 
-    /**
-     * @brief Generic feature calculator factory.
-     *
-     * Returns lambda function to compute features using features cache.
-     * Real features math is done by a lambda function provided as a parameter.
-     * Features are written to input tensor memory.
-     *
-     * @tparam T            feature vector type.
-     * @param inputTensor   model input tensor pointer.
-     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
-     * @param compute       features calculator function.
-     * @return              lambda function to compute features.
-     **/
-    template<class T>
-    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
-                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
-    {
-        /* Feature cache to be captured by lambda function. */
-        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
-        return [=](std::vector<int16_t>& audioDataWindow,
-                   size_t index,
-                   bool useCache,
-                   size_t featuresOverlapIndex)
-        {
-            T* tensorData = tflite::GetTensorData<T>(inputTensor);
-            std::vector<T> features;
-
-            /* Reuse features from cache if cache is ready and sliding windows overlap.
-             * Overlap is in the beginning of sliding window with a size of a feature cache.
-             */
-            if (useCache && index < featureCache.size()) {
-                features = std::move(featureCache[index]);
-            } else {
-                features = std::move(compute(audioDataWindow));
-            }
-            auto size = features.size();
-            auto sizeBytes = sizeof(T) * size;
-            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
-            /* Start renewing cache as soon iteration goes out of the windows overlap. */
-            if (index >= featuresOverlapIndex) {
-                featureCache[index - featuresOverlapIndex] = std::move(features);
-            }
-        };
-    }
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
-                        size_t cacheSize,
-                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
-    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
-
-    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
-    FeatureCalc<float>(TfLiteTensor* inputTensor,
-                       size_t cacheSize,
-                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
-    GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
-    {
-        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
-        TfLiteQuantization quant = inputTensor->quantization;
-
-        if (kTfLiteAffineQuantization == quant.type) {
-
-            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
-            const float quantScale = quantParams->scale->data[0];
-            const int quantOffset = quantParams->zero_point->data[0];
-
-            switch (inputTensor->type) {
-                case kTfLiteInt8: {
-                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
-                                                          cacheSize,
-                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
-                                                                                                   quantScale,
-                                                                                                   quantOffset);
-                                                          }
-                    );
-                    break;
-                }
-                case kTfLiteUInt8: {
-                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
-                                                           cacheSize,
-                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                               return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
-                                                                                                     quantScale,
-                                                                                                     quantOffset);
-                                                           }
-                    );
-                    break;
-                }
-                case kTfLiteInt16: {
-                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
-                                                           cacheSize,
-                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                               return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
-                                                                                                     quantScale,
-                                                                                                     quantOffset);
-                                                           }
-                    );
-                    break;
-                }
-                default:
-                printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
-            }
-
-
-        } else {
-            mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
-                                                                   cacheSize,
-                                                                   [&mfcc](std::vector<int16_t>& audioDataWindow) {
-                                                                       return mfcc.MfccCompute(audioDataWindow);
-                                                                   });
-        }
-        return mfccFeatureCalc;
-    }
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
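The refactored handler above now follows the same three-step shape as the other use cases: populate the input tensor, invoke the interpreter, then decode the output tensor. A minimal sketch of that contract (a stand-in, not the project's helper; Pre and Post are any types exposing the DoPreProcess/DoPostProcess methods used above, and infer stands in for RunInference(model, profiler)):

    #include <cstddef>

    /* One sliding-window iteration: pre-process the window into the input
     * tensor, run inference, then post-process the output tensor. */
    template <typename Pre, typename Post, typename InferenceFn>
    bool RunSingleWindow(Pre& pre, Post& post,
                         const void* window, size_t windowLen, InferenceFn&& infer)
    {
        return pre.DoPreProcess(window, windowLen)
            && infer()
            && post.DoPostProcess();
    }

In the ASR loop above, pre would be asrPreProcess, post would be asrPostProcess, and infer a lambda wrapping RunInference.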
diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
index 2a76b1b..42f434e 100644
--- a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,62 +15,71 @@
  * limitations under the License.
  */
 #include "Wav2LetterPostprocess.hpp"
+
 #include "Wav2LetterModel.hpp"
 #include "log_macros.h"
 
+#include <cmath>
+
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
-    Postprocess::Postprocess(const uint32_t contextLen,
-                             const uint32_t innerLen,
-                             const uint32_t blankTokenIdx)
-        :   m_contextLen(contextLen),
-            m_innerLen(innerLen),
-            m_totalLen(2 * this->m_contextLen + this->m_innerLen),
+    AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+            const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
+            const uint32_t outputContextLen,
+            const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
+            ):
+            m_classifier(classifier),
+            m_outputTensor(outputTensor),
+            m_labels{labels},
+            m_results(results),
+            m_outputContextLen(outputContextLen),
             m_countIterations(0),
-            m_blankTokenIdx(blankTokenIdx)
-    {}
+            m_blankTokenIdx(blankTokenIdx),
+            m_reductionAxisIdx(reductionAxisIdx)
+    {
+        this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+        this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
+    }
 
-    bool Postprocess::Invoke(TfLiteTensor*  tensor,
-                            const uint32_t  axisIdx,
-                            const bool      lastIteration)
+    bool AsrPostProcess::DoPostProcess()
     {
         /* Basic checks. */
-        if (!this->IsInputValid(tensor, axisIdx)) {
+        if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
             return false;
         }
 
         /* Irrespective of tensor type, we use unsigned "byte" */
-        uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
-        const uint32_t elemSz = this->GetTensorElementSize(tensor);
+        auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
+        const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor);
 
         /* Other sanity checks. */
         if (0 == elemSz) {
             printf_err("Tensor type not supported for post processing\n");
             return false;
-        } else if (elemSz * this->m_totalLen > tensor->bytes) {
+        } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
             printf_err("Insufficient number of tensor bytes\n");
             return false;
         }
 
         /* Which axis do we need to process? */
-        switch (axisIdx) {
-            case arm::app::Wav2LetterModel::ms_outputRowsIdx:
-                return this->EraseSectionsRowWise(ptrData,
-                                                  elemSz *
-                                                  tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
-                                                  lastIteration);
+        switch (this->m_reductionAxisIdx) {
+            case Wav2LetterModel::ms_outputRowsIdx:
+                this->EraseSectionsRowWise(
+                        ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
+                        this->m_lastIteration);
+                break;
             default:
-                printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
+                printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
+                return false;
         }
+        this->m_classifier.GetClassificationResults(this->m_outputTensor,
+                this->m_results, this->m_labels, 1);
 
-        return false;
+        return true;
     }
 
-    bool Postprocess::IsInputValid(TfLiteTensor*  tensor,
-                                   const uint32_t axisIdx) const
+    bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
     {
         if (nullptr == tensor) {
             return false;
@@ -84,25 +93,23 @@
 
         if (static_cast<int>(this->m_totalLen) !=
                              tensor->dims->data[axisIdx]) {
-            printf_err("Unexpected tensor dimension for axis %d, \n",
-                tensor->dims->data[axisIdx]);
+            printf_err("Unexpected tensor dimension for axis %d, got %d, \n",
+                axisIdx, tensor->dims->data[axisIdx]);
             return false;
         }
 
         return true;
     }
 
-    uint32_t Postprocess::GetTensorElementSize(TfLiteTensor*  tensor)
+    uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
     {
         switch(tensor->type) {
             case kTfLiteUInt8:
-                return 1;
             case kTfLiteInt8:
                 return 1;
             case kTfLiteInt16:
                 return 2;
             case kTfLiteInt32:
-                return 4;
             case kTfLiteFloat32:
                 return 4;
             default:
@@ -113,30 +120,30 @@
         return 0;
     }
 
-    bool Postprocess::EraseSectionsRowWise(
-                        uint8_t*         ptrData,
-                        const uint32_t   strideSzBytes,
-                        const bool       lastIteration)
+    bool AsrPostProcess::EraseSectionsRowWise(
+            uint8_t*         ptrData,
+            const uint32_t   strideSzBytes,
+            const bool       lastIteration)
     {
         /* In this case, the "zero-ing" is quite simple as the region
          * to be zeroed sits in contiguous memory (row-major). */
-        const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
+        const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;
 
         /* Erase left context? */
         if (this->m_countIterations > 0) {
             /* Set output of each classification window to the blank token. */
             std::memset(ptrData, 0, eraseLen);
-            for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+            for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
                 ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
             }
         }
 
         /* Erase right context? */
         if (false == lastIteration) {
-            uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
+            uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
             /* Set output of each classification window to the blank token. */
             std::memset(rightCtxPtr, 0, eraseLen);
-            for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+            for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
                 rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
             }
         }
@@ -150,7 +157,58 @@
         return true;
     }
 
-} /* namespace asr */
-} /* namespace audio */
+    uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model)
+    {
+        TfLiteTensor* inputTensor = model.GetInputTensor(0);
+        const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
+        if (inputRows == 0) {
+            printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
+                    Wav2LetterModel::ms_inputRowsIdx);
+        }
+        return inputRows;
+    }
+
+    uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
+    {
+        const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+        if (outputRows == 0) {
+            printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+                    Wav2LetterModel::ms_outputRowsIdx);
+        }
+
+        /* Watch for underflow: do the subtraction in signed arithmetic. */
+        const int innerLen = static_cast<int>(outputRows) - 2 * static_cast<int>(outputCtxLen);
+
+        return std::max(innerLen, 0);
+    }
+
+    uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+    {
+        const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model);
+        const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+        constexpr uint32_t outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
+
+        /* Check to make sure that the input tensor supports the above
+         * context and inner lengths. */
+        if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
+            printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
+                       inputCtxLen);
+            return 0;
+        }
+
+        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+        const uint32_t outputRows = std::max(outputTensor->dims->data[outputRowsIdx], 0);
+        if (outputRows == 0) {
+            printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+                       Wav2LetterModel::ms_outputRowsIdx);
+            return 0;
+        }
+
+        const float inOutRowRatio = static_cast<float>(inputRows) /
+                                     static_cast<float>(outputRows);
+
+        return static_cast<uint32_t>(std::round(static_cast<float>(inputCtxLen) / inOutRowRatio));
+    }
+
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
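To make the context arithmetic in GetOutputContextLen concrete, here is the same mapping in isolation; the numbers below are representative, not taken from this patch:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    /* The input context length shrinks by the input:output row ratio, as in
     * AsrPostProcess::GetOutputContextLen. */
    uint32_t OutputContextLen(uint32_t inputRows, uint32_t outputRows, uint32_t inputCtxLen)
    {
        const float inOutRowRatio = static_cast<float>(inputRows) /
                                    static_cast<float>(outputRows);
        return static_cast<uint32_t>(
                std::round(static_cast<float>(inputCtxLen) / inOutRowRatio));
    }

    int main()
    {
        /* e.g. 296 input feature frames and 148 output rows give a ratio of 2,
         * so an input context of 98 frames maps to 49 output rows of context on
         * each side, leaving 148 - 2 * 49 = 50 inner rows per inference. */
        std::printf("output context = %u\n",
                    static_cast<unsigned>(OutputContextLen(296, 148, 98)));
        return 0;
    }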
diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
index d3f3579..92b0631 100644
--- a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,41 +20,35 @@
 #include "TensorFlowLiteMicro.hpp"
 
 #include <algorithm>
-#include <math.h>
+#include <cmath>
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
-    Preprocess::Preprocess(
-        const uint32_t  numMfccFeatures,
-        const uint32_t  windowLen,
-        const uint32_t  windowStride,
-        const uint32_t  numMfccVectors):
-            m_mfcc(numMfccFeatures, windowLen),
-            m_mfccBuf(numMfccFeatures, numMfccVectors),
-            m_delta1Buf(numMfccFeatures, numMfccVectors),
-            m_delta2Buf(numMfccFeatures, numMfccVectors),
-            m_windowLen(windowLen),
-            m_windowStride(windowStride),
+    AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+                                 const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+                                 const uint32_t mfccWindowStride
+            ):
+            m_mfcc(numMfccFeatures, mfccWindowLen),
+            m_inputTensor(inputTensor),
+            m_mfccBuf(numMfccFeatures, numFeatureFrames),
+            m_delta1Buf(numMfccFeatures, numFeatureFrames),
+            m_delta2Buf(numMfccFeatures, numFeatureFrames),
+            m_mfccWindowLen(mfccWindowLen),
+            m_mfccWindowStride(mfccWindowStride),
             m_numMfccFeats(numMfccFeatures),
-            m_numFeatVectors(numMfccVectors),
-            m_window()
+            m_numFeatureFrames(numFeatureFrames)
     {
-        if (numMfccFeatures > 0 && windowLen > 0) {
+        if (numMfccFeatures > 0 && mfccWindowLen > 0) {
             this->m_mfcc.Init();
         }
     }
 
-    bool Preprocess::Invoke(
-                const int16_t*  audioData,
-                const uint32_t  audioDataLen,
-                TfLiteTensor*   tensor)
+    bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
     {
-        this->m_window = SlidingWindow<const int16_t>(
-                            audioData, audioDataLen,
-                            this->m_windowLen, this->m_windowStride);
+        this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
+                static_cast<const int16_t*>(audioData), audioDataLen,
+                this->m_mfccWindowLen, this->m_mfccWindowStride);
 
         uint32_t mfccBufIdx = 0;
 
@@ -62,12 +56,12 @@
         std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f);
         std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f);
 
-        /* While we can slide over the window. */
-        while (this->m_window.HasNext()) {
-            const int16_t*  mfccWindow = this->m_window.Next();
+        /* While we can slide over the audio. */
+        while (this->m_mfccSlidingWindow.HasNext()) {
+            const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
             auto mfccAudioData = std::vector<int16_t>(
                                         mfccWindow,
-                                        mfccWindow + this->m_windowLen);
+                                        mfccWindow + this->m_mfccWindowLen);
             auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData);
             for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) {
                 this->m_mfccBuf(i, mfccBufIdx) = mfcc[i];
@@ -76,11 +70,11 @@
         }
 
         /* Pad MFCC if needed by adding MFCC for zeros. */
-        if (mfccBufIdx != this->m_numFeatVectors) {
-            std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0);
+        if (mfccBufIdx != this->m_numFeatureFrames) {
+            std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0);
             std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow);
 
-            while (mfccBufIdx != this->m_numFeatVectors) {
+            while (mfccBufIdx != this->m_numFeatureFrames) {
                 memcpy(&this->m_mfccBuf(0, mfccBufIdx),
                        mfccZeros.data(), sizeof(float) * m_numMfccFeats);
                 ++mfccBufIdx;
@@ -88,41 +82,39 @@
         }
 
         /* Compute first and second order deltas from MFCCs. */
-        this->ComputeDeltas(this->m_mfccBuf,
-                            this->m_delta1Buf,
-                            this->m_delta2Buf);
+        AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
 
-        /* Normalise. */
-        this->Normalise();
+        /* Standardize calculated features. */
+        this->Standarize();
 
         /* Quantise. */
-        QuantParams quantParams = GetTensorQuantParams(tensor);
+        QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor);
 
         if (0 == quantParams.scale) {
             printf_err("Quantisation scale can't be 0\n");
             return false;
         }
 
-        switch(tensor->type) {
+        switch(this->m_inputTensor->type) {
             case kTfLiteUInt8:
                 return this->Quantise<uint8_t>(
-                        tflite::GetTensorData<uint8_t>(tensor), tensor->bytes,
+                        tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
                         quantParams.scale, quantParams.offset);
             case kTfLiteInt8:
                 return this->Quantise<int8_t>(
-                        tflite::GetTensorData<int8_t>(tensor), tensor->bytes,
+                        tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
                         quantParams.scale, quantParams.offset);
             default:
                 printf_err("Unsupported tensor type %s\n",
-                    TfLiteTypeGetName(tensor->type));
+                    TfLiteTypeGetName(this->m_inputTensor->type));
         }
 
         return false;
     }
 
-    bool Preprocess::ComputeDeltas(Array2d<float>& mfcc,
-                                   Array2d<float>& delta1,
-                                   Array2d<float>& delta2)
+    bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc,
+                                      Array2d<float>& delta1,
+                                      Array2d<float>& delta2)
     {
         const std::vector <float> delta1Coeffs =
             {6.66666667e-02,  5.00000000e-02,  3.33333333e-02,
@@ -148,11 +140,11 @@
         /* Iterate through features in MFCC vector. */
         for (size_t i = 0; i < numFeatures; ++i) {
             /* For each feature, iterate through time (t) samples representing feature evolution and
-             * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels.
+             * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels.
              * Convolution padding = valid, result size is `time length - kernel length + 1`.
              * The result is padded with 0 from both sides to match the size of initial time samples data.
              *
-             * For the small filter, conv1d implementation as a simple loop is efficient enough.
+             * For the small filter, conv1D implementation as a simple loop is efficient enough.
              * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32.
              */
 
@@ -175,20 +167,10 @@
         return true;
     }
 
-    float Preprocess::GetMean(Array2d<float>& vec)
+    void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec)
     {
-        return math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
-    }
-
-    float Preprocess::GetStdDev(Array2d<float>& vec, const float mean)
-    {
-        return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
-    }
-
-    void Preprocess::NormaliseVec(Array2d<float>& vec)
-    {
-        auto mean = Preprocess::GetMean(vec);
-        auto stddev = Preprocess::GetStdDev(vec, mean);
+        auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
+        auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
 
         debug("Mean: %f, Stddev: %f\n", mean, stddev);
         if (stddev == 0) {
@@ -204,14 +186,14 @@
         }
     }
 
-    void Preprocess::Normalise()
+    void AsrPreProcess::Standarize()
     {
-        Preprocess::NormaliseVec(this->m_mfccBuf);
-        Preprocess::NormaliseVec(this->m_delta1Buf);
-        Preprocess::NormaliseVec(this->m_delta2Buf);
+        AsrPreProcess::StandardizeVecF32(this->m_mfccBuf);
+        AsrPreProcess::StandardizeVecF32(this->m_delta1Buf);
+        AsrPreProcess::StandardizeVecF32(this->m_delta2Buf);
     }
 
-    float Preprocess::GetQuantElem(
+    float AsrPreProcess::GetQuantElem(
                 const float     elem,
                 const float     quantScale,
                 const int       quantOffset,
@@ -222,7 +204,5 @@
         return std::min<float>(std::max<float>(val, minVal), maxVal);
     }
 
-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
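The Standarize step normalises each of the three feature buffers (MFCCs and both delta buffers) to zero mean and unit variance. The per-buffer operation in isolation, as a sketch using plain loops instead of math::MathUtils and keeping a guard against a zero deviation:

    #include <cmath>
    #include <vector>

    /* Replace every element with (x - mean) / stddev. */
    void Standardize(std::vector<float>& v)
    {
        if (v.empty()) {
            return;
        }

        float mean = 0.f;
        for (float x : v) { mean += x; }
        mean /= static_cast<float>(v.size());

        float var = 0.f;
        for (float x : v) { var += (x - mean) * (x - mean); }
        const float stddev = std::sqrt(var / static_cast<float>(v.size()));

        if (stddev == 0.f) { /* Avoid dividing by zero. */
            return;
        }
        for (float& x : v) { x = (x - mean) / stddev; }
    }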
diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake
index b3fe020..40df4d7 100644
--- a/source/use_case/kws_asr/usecase.cmake
+++ b/source/use_case/kws_asr/usecase.cmake
@@ -1,5 +1,5 @@
 #----------------------------------------------------------------------------
-#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 #  SPDX-License-Identifier: Apache-2.0
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
@@ -59,7 +59,7 @@
     STRING)
 
 USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_KWS "Specify the score threshold [0.0, 1.0) that must be applied to the KWS results for a label to be deemed valid."
-    0.9
+    0.7
     STRING)
 
 USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [0.0, 1.0) that must be applied to the ASR results for a label to be deemed valid."
diff --git a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
similarity index 97%
rename from source/use_case/noise_reduction/include/RNNoiseProcess.hpp
rename to source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
index c188e42..cbf0e4e 100644
--- a/source/use_case/noise_reduction/include/RNNoiseProcess.hpp
+++ b/source/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#ifndef RNNOISE_FEATURE_PROCESSOR_HPP
+#define RNNOISE_FEATURE_PROCESSOR_HPP
+
 #include "PlatformMath.hpp"
 #include <cstdint>
 #include <vector>
@@ -47,11 +50,11 @@
      *          - https://jmvalin.ca/demo/rnnoise/
      *          - https://arxiv.org/abs/1709.08243
      **/
-    class RNNoiseProcess {
+    class RNNoiseFeatureProcessor {
     /* Public interface */
     public:
-        RNNoiseProcess();
-        ~RNNoiseProcess() = default;
+        RNNoiseFeatureProcessor();
+        ~RNNoiseFeatureProcessor() = default;
 
         /**
          * @brief        Calculates the features from a given audio buffer ready to be sent to RNNoise model.
@@ -328,10 +331,11 @@
         const std::array <uint32_t, NB_BANDS> m_eband5ms {
             0,  1,  2,  3,  4,  5,  6,  7,  8, 10,  12,
             14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100};
-
     };
 
 
 } /* namespace rnn */
-} /* namspace app */
+} /* namespace app */
 } /* namespace arm */
+
+#endif /* RNNOISE_FEATURE_PROCESSOR_HPP */
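With the rename, this class keeps all of the RNNoise DSP in one place while tensor-facing work moves into the new RNNoiseProcessing classes below. A sketch of the per-frame call sequence it supports (inference elided; assumes FrameFeatures is default-constructible, as its use as a plain data holder suggests):

    #include "RNNoiseFeatureProcessor.hpp"

    using arm::app::rnn::vec1D32F;

    /* Features flow out through PreprocessFrame; model gains flow back in
     * through PostProcessFrame; one FrameFeatures instance links the two. */
    void DenoiseFrameFlow(arm::app::rnn::RNNoiseFeatureProcessor& fp,
                          const vec1D32F& audioIn,
                          vec1D32F& modelGains,  /* De-quantised model output. */
                          vec1D32F& audioOut)
    {
        arm::app::rnn::FrameFeatures features;
        fp.PreprocessFrame(audioIn.data(), audioIn.size(), features);
        /* ... quantise features.m_featuresVec, run inference, de-quantise the
         * result into modelGains ... */
        fp.PostProcessFrame(modelGains, features, audioOut);
    }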
diff --git a/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp
new file mode 100644
index 0000000..15e62d9
--- /dev/null
+++ b/source/use_case/noise_reduction/include/RNNoiseProcessing.hpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef RNNOISE_PROCESSING_HPP
+#define RNNOISE_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for Noise Reduction use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class RNNoisePreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief           Constructor
+         * @param[in]       inputTensor        Pointer to the TFLite Micro input Tensor.
+         * @param[in,out]   featureProcessor   RNNoise specific feature extractor object.
+         * @param[in,out]   frameFeatures      RNNoise specific features shared between pre & post-processing.
+         **/
+        explicit RNNoisePreProcess(TfLiteTensor* inputTensor,
+                                   std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+                                   std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+        /**
+         * @brief       Should perform pre-processing of 'raw' input audio data and load it into
+         *              TFLite Micro input tensors ready for inference.
+         * @param[in]   input      Pointer to the data that pre-processing will work on.
+         * @param[in]   inputSize  Size of the input data.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+    private:
+        TfLiteTensor* m_inputTensor;                        /* Model input tensor. */
+        std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor;   /* RNNoise feature processor shared between pre & post-processing. */
+        std::shared_ptr<rnn::FrameFeatures> m_frameFeatures;                /* RNNoise features shared between pre & post-processing. */
+        rnn::vec1D32F m_audioFrame;                         /* Audio frame cast to FP32. */
+
+        /**
+         * @brief            Quantize the given features and populate the input Tensor.
+         * @param[in]        inputFeatures   Vector of floating point features to quantize.
+         * @param[in]        quantScale      Quantization scale for the inputTensor.
+         * @param[in]        quantOffset     Quantization offset for the inputTensor.
+         * @param[in,out]    inputTensor     TFLite micro tensor to populate.
+         **/
+        static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+                float quantScale, int quantOffset,
+                TfLiteTensor* inputTensor);
+    };
+
+    /**
+     * @brief   Post-processing class for Noise Reduction use case.
+     *          Implements methods declared by BasePostProcess and anything else needed
+     *          to populate result vector.
+     */
+    class RNNoisePostProcess : public BasePostProcess {
+
+    public:
+        /**
+         * @brief           Constructor
+         * @param[in]       outputTensor         Pointer to the TFLite Micro output Tensor.
+         * @param[out]      denoisedAudioFrame   Vector to store the final denoised audio frame.
+         * @param[in,out]   featureProcessor     RNNoise specific feature extractor object.
+         * @param[in,out]   frameFeatures        RNNoise specific features shared between pre & post-processing.
+         **/
+        RNNoisePostProcess(TfLiteTensor* outputTensor,
+                           std::vector<int16_t>& denoisedAudioFrame,
+                           std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+                           std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+        /**
+         * @brief       Should perform post-processing of the result of inference then
+         *              populate result data for any later use.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+
+    private:
+        TfLiteTensor* m_outputTensor;                       /* Model output tensor. */
+        std::vector<int16_t>& m_denoisedAudioFrame;         /* Vector to store the final denoised frame. */
+        rnn::vec1D32F m_denoisedAudioFrameFloat;            /* Internal vector to store the final denoised frame (FP32). */
+        std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor;   /* RNNoise feature processor shared between pre & post-processing. */
+        std::shared_ptr<rnn::FrameFeatures> m_frameFeatures;                /* RNNoise features shared between pre & post-processing. */
+        std::vector<float> m_modelOutputFloat;              /* Internal vector to store de-quantized model output. */
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* RNNOISE_PROCESSING_HPP */
\ No newline at end of file
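A hedged wiring sketch for the two classes declared above; the tensors would come from an initialised RNNoiseModel (not shown), and the model invocation is elided:

    #include "RNNoiseProcessing.hpp"

    #include <cstdint>
    #include <memory>
    #include <vector>

    /* Both objects share one feature processor and one FrameFeatures instance,
     * so FFT/pitch state computed during pre-processing is still available when
     * the denoised frame is synthesised during post-processing. */
    bool DenoiseOneFrame(TfLiteTensor* inputTensor, TfLiteTensor* outputTensor,
                         const int16_t* frame, size_t frameLen,
                         std::vector<int16_t>& denoisedFrame)
    {
        auto featureProcessor = std::make_shared<arm::app::rnn::RNNoiseFeatureProcessor>();
        auto frameFeatures    = std::make_shared<arm::app::rnn::FrameFeatures>();

        arm::app::RNNoisePreProcess  preProcess(inputTensor, featureProcessor, frameFeatures);
        arm::app::RNNoisePostProcess postProcess(outputTensor, denoisedFrame,
                                                 featureProcessor, frameFeatures);

        if (!preProcess.DoPreProcess(frame, frameLen)) {
            return false;
        }
        /* Run the model here, e.g. via RunInference(model, profiler). */
        return postProcess.DoPostProcess();
    }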
diff --git a/source/use_case/noise_reduction/src/MainLoop.cc b/source/use_case/noise_reduction/src/MainLoop.cc
index 5fd7823..fd72127 100644
--- a/source/use_case/noise_reduction/src/MainLoop.cc
+++ b/source/use_case/noise_reduction/src/MainLoop.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,12 +14,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "hal.h"                    /* Brings in platform definitions. */
 #include "UseCaseHandler.hpp"       /* Handlers for different user options. */
 #include "UseCaseCommonUtils.hpp"   /* Utils functions. */
 #include "RNNoiseModel.hpp"         /* Model class for running inference. */
 #include "InputFiles.hpp"           /* For input audio clips. */
-#include "RNNoiseProcess.hpp"       /* Pre-processing class */
 #include "log_macros.h"
 
 enum opcodes
diff --git a/source/use_case/noise_reduction/src/RNNoiseProcess.cc b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
similarity index 92%
rename from source/use_case/noise_reduction/src/RNNoiseProcess.cc
rename to source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
index 4c568fa..036894c 100644
--- a/source/use_case/noise_reduction/src/RNNoiseProcess.cc
+++ b/source/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
 #include "log_macros.h"
 
 #include <algorithm>
@@ -33,7 +33,7 @@
     }                                               \
 } while(0)
 
-RNNoiseProcess::RNNoiseProcess() :
+RNNoiseFeatureProcessor::RNNoiseFeatureProcessor() :
         m_halfWindow(FRAME_SIZE, 0),
         m_dctTable(NB_BANDS * NB_BANDS),
         m_analysisMem(FRAME_SIZE, 0),
@@ -54,9 +54,9 @@
     this->InitTables();
 }
 
-void RNNoiseProcess::PreprocessFrame(const float*   audioData,
-                                     const size_t   audioLen,
-                                     FrameFeatures& features)
+void RNNoiseFeatureProcessor::PreprocessFrame(const float*   audioData,
+                                              const size_t   audioLen,
+                                              FrameFeatures& features)
 {
     /* Note audioWindow is modified in place */
     const arrHp aHp {-1.99599, 0.99600 };
@@ -68,7 +68,7 @@
     this->ComputeFrameFeatures(audioWindow, features);
 }
 
-void RNNoiseProcess::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame)
+void RNNoiseFeatureProcessor::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame)
 {
     std::vector<float> outputBands = modelOutput;
     std::vector<float> gain(FREQ_SIZE, 0);
@@ -92,7 +92,7 @@
     FrameSynthesis(outFrame, features.m_fftX);
 }
 
-void RNNoiseProcess::InitTables()
+void RNNoiseFeatureProcessor::InitTables()
 {
     constexpr float pi = M_PI;
     constexpr float halfPi = M_PI / 2;
@@ -111,7 +111,7 @@
     }
 }
 
-void RNNoiseProcess::BiQuad(
+void RNNoiseFeatureProcessor::BiQuad(
         const arrHp& bHp,
         const arrHp& aHp,
         arrHp& memHpX,
@@ -126,8 +126,8 @@
     }
 }
 
-void RNNoiseProcess::ComputeFrameFeatures(vec1D32F& audioWindow,
-                                          FrameFeatures& features)
+void RNNoiseFeatureProcessor::ComputeFrameFeatures(vec1D32F& audioWindow,
+                                                   FrameFeatures& features)
 {
     this->FrameAnalysis(audioWindow,
                         features.m_fftX,
@@ -264,7 +264,7 @@
     features.m_featuresVec[NB_BANDS + 3 * NB_DELTA_CEPS + 1] = specVariability / CEPS_MEM - 2.1;
 }
 
-void RNNoiseProcess::FrameAnalysis(
+void RNNoiseFeatureProcessor::FrameAnalysis(
     const vec1D32F& audioWindow,
     vec1D32F& fft,
     vec1D32F& energy,
@@ -289,7 +289,7 @@
     ComputeBandEnergy(fft, energy);
 }
 
-void RNNoiseProcess::ApplyWindow(vec1D32F& x)
+void RNNoiseFeatureProcessor::ApplyWindow(vec1D32F& x)
 {
     if (WINDOW_SIZE != x.size()) {
         printf_err("Invalid size for vector to be windowed\n");
@@ -305,7 +305,7 @@
     }
 }
 
-void RNNoiseProcess::ForwardTransform(
+void RNNoiseFeatureProcessor::ForwardTransform(
     vec1D32F& x,
     vec1D32F& fft)
 {
@@ -327,7 +327,7 @@
      * first half of the FFT's. The conjugates are not present. */
 }
 
-void RNNoiseProcess::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE)
+void RNNoiseFeatureProcessor::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE)
 {
     bandE = vec1D32F(NB_BANDS, 0);
 
@@ -351,7 +351,7 @@
     bandE[NB_BANDS - 1] *= 2;
 }
 
-void RNNoiseProcess::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC)
+void RNNoiseFeatureProcessor::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC)
 {
     bandC = vec1D32F(NB_BANDS, 0);
     VERIFY(this->m_eband5ms.size() >= NB_BANDS);
@@ -374,7 +374,7 @@
     bandC[NB_BANDS - 1] *= 2;
 }
 
-void RNNoiseProcess::DCT(vec1D32F& input, vec1D32F& output)
+void RNNoiseFeatureProcessor::DCT(vec1D32F& input, vec1D32F& output)
 {
     VERIFY(this->m_dctTable.size() >= NB_BANDS * NB_BANDS);
     for (uint32_t i = 0; i < NB_BANDS; ++i) {
@@ -387,7 +387,7 @@
     }
 }
 
-void RNNoiseProcess::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) {
+void RNNoiseFeatureProcessor::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) {
     for (size_t i = 1; i < (pitchBufSz >> 1); ++i) {
         pitchBuf[i] = 0.5 * (
                         0.5 * (this->m_pitchBuf[2 * i - 1] + this->m_pitchBuf[2 * i + 1])
@@ -431,7 +431,7 @@
     this->Fir5(lpc2, pitchBufSz >> 1, pitchBuf);
 }
 
-int RNNoiseProcess::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) {
+int RNNoiseFeatureProcessor::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) {
     uint32_t lag = len + maxPitch;
     vec1D32F xLp4(len >> 2, 0);
     vec1D32F yLp4(lag >> 2, 0);
@@ -488,7 +488,7 @@
     return 2*bestPitch[0] - offset;
 }
 
-arrHp RNNoiseProcess::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch)
+arrHp RNNoiseFeatureProcessor::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch)
 {
     float Syy = 1;
     arrHp bestNum {-1, -1};
@@ -527,7 +527,7 @@
     return bestPitch;
 }
 
-int RNNoiseProcess::RemoveDoubling(
+int RNNoiseFeatureProcessor::RemoveDoubling(
     vec1D32F& pitchBuf,
     uint32_t maxPeriod,
     uint32_t minPeriod,
@@ -679,12 +679,12 @@
     return this->m_lastPeriod;
 }
 
-float RNNoiseProcess::ComputePitchGain(float xy, float xx, float yy)
+float RNNoiseFeatureProcessor::ComputePitchGain(float xy, float xx, float yy)
 {
     return xy / math::MathUtils::SqrtF32(1+xx*yy);
 }
 
-void RNNoiseProcess::AutoCorr(
+void RNNoiseFeatureProcessor::AutoCorr(
     const vec1D32F& x,
     vec1D32F& ac,
     size_t lag,
@@ -711,7 +711,7 @@
 }
 
 
-void RNNoiseProcess::PitchXCorr(
+void RNNoiseFeatureProcessor::PitchXCorr(
     const vec1D32F& x,
     const vec1D32F& y,
     vec1D32F& xCorr,
@@ -728,7 +728,7 @@
 }
 
 /* Linear predictor coefficients */
-void RNNoiseProcess::LPC(
+void RNNoiseFeatureProcessor::LPC(
     const vec1D32F& correlation,
     int32_t p,
     vec1D32F& lpc)
@@ -766,7 +766,7 @@
     }
 }
 
-void RNNoiseProcess::Fir5(
+void RNNoiseFeatureProcessor::Fir5(
     const vec1D32F &num,
     uint32_t N,
     vec1D32F &x)
@@ -794,7 +794,7 @@
     }
 }
 
-void RNNoiseProcess::PitchFilter(FrameFeatures &features, vec1D32F &gain) {
+void RNNoiseFeatureProcessor::PitchFilter(FrameFeatures &features, vec1D32F &gain) {
     std::vector<float> r(NB_BANDS, 0);
     std::vector<float> rf(FREQ_SIZE, 0);
     std::vector<float> newE(NB_BANDS);
@@ -835,7 +835,7 @@
     }
 }
 
-void RNNoiseProcess::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
+void RNNoiseFeatureProcessor::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
     std::vector<float> x(WINDOW_SIZE, 0);
     InverseTransform(x, fftY);
     ApplyWindow(x);
@@ -845,7 +845,7 @@
     memcpy((m_synthesisMem.data()), &x[FRAME_SIZE], FRAME_SIZE*sizeof(float));
 }
 
-void RNNoiseProcess::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
+void RNNoiseFeatureProcessor::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
     for (size_t i = 0; i < NB_BANDS - 1; i++) {
         int bandSize = (m_eband5ms[i + 1] - m_eband5ms[i]) << FRAME_SIZE_SHIFT;
         for (int j = 0; j < bandSize; j++) {
@@ -855,7 +855,7 @@
     }
 }
 
-void RNNoiseProcess::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) {
+void RNNoiseFeatureProcessor::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) {
 
     std::vector<float> x(WINDOW_SIZE * 2);  /* This is complex. */
     vec1D32F newFFT;  /* This is complex. */
diff --git a/source/use_case/noise_reduction/src/RNNoiseProcessing.cc b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc
new file mode 100644
index 0000000..f6a3ec4
--- /dev/null
+++ b/source/use_case/noise_reduction/src/RNNoiseProcessing.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "RNNoiseProcessing.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    RNNoisePreProcess::RNNoisePreProcess(TfLiteTensor* inputTensor,
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+    :   m_inputTensor{inputTensor},
+        m_featureProcessor{featureProcessor},
+        m_frameFeatures{frameFeatures}
+    {}
+
+    bool RNNoisePreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        if (data == nullptr) {
+            printf_err("Data pointer is null");
+            return false;
+        }
+
+        auto input = static_cast<const int16_t*>(data);
+        this->m_audioFrame = rnn::vec1D32F(input, input + inputSize);
+        m_featureProcessor->PreprocessFrame(this->m_audioFrame.data(), inputSize, *this->m_frameFeatures);
+
+        QuantizeAndPopulateInput(this->m_frameFeatures->m_featuresVec,
+                this->m_inputTensor->params.scale, this->m_inputTensor->params.zero_point,
+                this->m_inputTensor);
+
+        debug("Input tensor populated \n");
+
+        return true;
+    }
+
+    void RNNoisePreProcess::QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+            const float quantScale, const int quantOffset,
+            TfLiteTensor* inputTensor)
+    {
+        const float minVal = std::numeric_limits<int8_t>::min();
+        const float maxVal = std::numeric_limits<int8_t>::max();
+
+        auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
+
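+        /* Affine int8 quantisation: q = x / scale + zeroPoint, clamped to [-128, 127]. */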
+        for (size_t i=0; i < inputFeatures.size(); ++i) {
+            float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
+            inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
+        }
+    }
+
+    RNNoisePostProcess::RNNoisePostProcess(TfLiteTensor* outputTensor,
+            std::vector<int16_t>& denoisedAudioFrame,
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+    :   m_outputTensor{outputTensor},
+        m_denoisedAudioFrame{denoisedAudioFrame},
+        m_featureProcessor{featureProcessor},
+        m_frameFeatures{frameFeatures}
+    {
+        this->m_denoisedAudioFrameFloat.reserve(denoisedAudioFrame.size());
+        this->m_modelOutputFloat.resize(outputTensor->bytes);
+    }
+
+    bool RNNoisePostProcess::DoPostProcess()
+    {
+        const auto* outputData = tflite::GetTensorData<int8_t>(this->m_outputTensor);
+        auto outputQuantParams = GetTensorQuantParams(this->m_outputTensor);
+
+        for (size_t i = 0; i < this->m_outputTensor->bytes; ++i) {
+            this->m_modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
+                                  * outputQuantParams.scale;
+        }
+
+        this->m_featureProcessor->PostProcessFrame(this->m_modelOutputFloat,
+                *this->m_frameFeatures, this->m_denoisedAudioFrameFloat);
+
+        for (size_t i = 0; i < this->m_denoisedAudioFrame.size(); ++i) {
+            this->m_denoisedAudioFrame[i] = static_cast<int16_t>(
+                    std::roundf(this->m_denoisedAudioFrameFloat[i]));
+        }
+
+        return true;
+    }
+
+} /* namespace app */
+} /* namespace arm */
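
The pre/post-processing pair above is a standard int8 affine quantisation
round trip. As a minimal standalone sketch (helper names are illustrative,
not part of this patch):

    #include <algorithm>
    #include <cstdint>

    /* q = clamp(x / scale + zeroPoint, -128, 127), truncated toward zero
     * exactly as QuantizeAndPopulateInput does above. */
    static int8_t QuantizeF32(float x, float scale, int zeroPoint)
    {
        const float q = (x / scale) + zeroPoint;
        return static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, q)));
    }

    /* x = (q - offset) * scale, as in RNNoisePostProcess::DoPostProcess. */
    static float DequantizeS8(int8_t q, float scale, int zeroPoint)
    {
        return (static_cast<float>(q) - zeroPoint) * scale;
    }
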
diff --git a/source/use_case/noise_reduction/src/UseCaseHandler.cc b/source/use_case/noise_reduction/src/UseCaseHandler.cc
index acb8ba7..53bb43e 100644
--- a/source/use_case/noise_reduction/src/UseCaseHandler.cc
+++ b/source/use_case/noise_reduction/src/UseCaseHandler.cc
@@ -21,12 +21,10 @@
 #include "ImageUtils.hpp"
 #include "InputFiles.hpp"
 #include "RNNoiseModel.hpp"
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+#include "RNNoiseProcessing.hpp"
 #include "log_macros.h"
 
-#include <cmath>
-#include <algorithm>
-
 namespace arm {
 namespace app {
 
@@ -36,17 +34,6 @@
     **/
     static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
 
-    /**
-    * @brief            Quantize the given features and populate the input Tensor.
-    * @param[in]        inputFeatures   Vector of floating point features to quantize.
-    * @param[in]        quantScale      Quantization scale for the inputTensor.
-    * @param[in]        quantOffset     Quantization offset for the inputTensor.
-    * @param[in,out]    inputTensor     TFLite micro tensor to populate.
-    **/
-    static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
-                                         float quantScale, int quantOffset,
-                                         TfLiteTensor* inputTensor);
-
     /* Noise reduction inference handler. */
     bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll)
     {
@@ -57,7 +44,7 @@
         size_t memDumpMaxLen = 0;
         uint8_t* memDumpBaseAddr = nullptr;
         size_t undefMemDumpBytesWritten = 0;
-        size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten;
+        size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
         if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
             memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
             memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
@@ -74,8 +61,8 @@
         }
 
         /* Populate Pre-Processing related parameters. */
-        auto audioParamsWinLen = ctx.Get<uint32_t>("frameLength");
-        auto audioParamsWinStride = ctx.Get<uint32_t>("frameStride");
+        auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
+        auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
         auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");
 
         TfLiteTensor* inputTensor = model.GetInputTensor(0);
@@ -103,7 +90,7 @@
         if (ctx.Has("featureFileNames")) {
             audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
         }
-        do{
+        do {
             hal_lcd_clear(COLOR_BLACK);
 
             auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
@@ -112,32 +99,38 @@
             /* Creating a sliding window through the audio. */
             auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                     audioAccessorFunc(currentIndex),
-                    audioSizeAccessorFunc(currentIndex), audioParamsWinLen,
-                    audioParamsWinStride);
+                    audioSizeAccessorFunc(currentIndex), audioFrameLen,
+                    audioFrameStride);
 
             info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex,
                  audioFileAccessorFunc(currentIndex));
 
             memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
-                 (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen,
+                                                           (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
                  memDumpBaseAddr + memDumpBytesWritten,
                  memDumpMaxLen - memDumpBytesWritten);
 
-            rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess();
-            rnn::vec1D32F audioFrame(audioParamsWinLen);
-            rnn::vec1D32F inputFeatures(nrNumInputFeatures);
-            rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen);
-            std::vector<int16_t> denoisedAudioFrame(audioParamsWinLen);
+            /* Set up pre and post-processing. */
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor =
+                    std::make_shared<rnn::RNNoiseFeatureProcessor>();
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures =
+                    std::make_shared<rnn::FrameFeatures>();
 
-            std::vector<float> modelOutputFloat(outputTensor->bytes);
-            rnn::FrameFeatures frameFeatures;
+            RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
+
+            std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
+            RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame,
+                    featureProcessor, frameFeatures);
+
             bool resetGRU = true;
 
             while (audioDataSlider.HasNext()) {
                 const int16_t* inferenceWindow = audioDataSlider.Next();
-                audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen);
 
-                featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures);
+                if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) {
+                    printf_err("Pre-processing failed.");
+                    return false;
+                }
 
                 /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
                 if (resetGRU){
@@ -148,53 +141,35 @@
                     model.CopyGruStates();
                 }
 
-                QuantizeAndPopulateInput(frameFeatures.m_featuresVec,
-                        inputTensor->params.scale, inputTensor->params.zero_point,
-                        inputTensor);
-
                 /* Strings for presentation/logging. */
                 std::string str_inf{"Running inference... "};
 
                 /* Display message on the LCD - inference running. */
-                hal_lcd_display_text(
-                            str_inf.c_str(), str_inf.size(),
-                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+                hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                        dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
                 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1);
 
                 /* Run inference over this feature sliding window. */
-                profiler.StartProfiling("Inference");
-                bool success = model.RunInference();
-                profiler.StopProfiling();
-                resetGRU = false;
-
-                if (!success) {
+                if (!RunInference(model, profiler)) {
+                    printf_err("Inference failed.");
                     return false;
                 }
+                resetGRU = false;
 
-                /* De-quantize main model output ready for post-processing. */
-                const auto* outputData = tflite::GetTensorData<int8_t>(outputTensor);
-                auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor);
-
-                for (size_t i = 0; i < outputTensor->bytes; ++i) {
-                    modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
-                            * outputQuantParams.scale;
-                }
-
-                /* Round and cast the post-processed results for dumping to wav. */
-                featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat);
-                for (size_t i = 0; i < audioParamsWinLen; ++i) {
-                    denoisedAudioFrame[i] = static_cast<int16_t>(std::roundf(denoisedAudioFrameFloat[i]));
+                /* Carry out post-processing. */
+                if (!postProcess.DoPostProcess()) {
+                    printf_err("Post-processing failed.");
+                    return false;
                 }
 
                 /* Erase. */
                 str_inf = std::string(str_inf.size(), ' ');
-                hal_lcd_display_text(
-                                str_inf.c_str(), str_inf.size(),
-                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+                hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+                        dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
                 if (memDumpMaxLen > 0) {
-                    /* Dump output tensors to memory. */
+                    /* Dump final post-processed output to memory. */
                     memDumpBytesWritten += DumpOutputDenoisedAudioFrame(
                             denoisedAudioFrame,
                             memDumpBaseAddr + memDumpBytesWritten,
@@ -209,6 +184,7 @@
                      valMemDumpBytesWritten, startDumpAddress);
             }
 
+            /* Finish by dumping the footer. */
             DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten);
 
             info("All inferences for audio clip complete.\n");
@@ -216,15 +192,13 @@
             IncrementAppCtxClipIdx(ctx);
 
             std::string clearString{' '};
-            hal_lcd_display_text(
-                    clearString.c_str(), clearString.size(),
+            hal_lcd_display_text(clearString.c_str(), clearString.size(),
                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
             std::string completeMsg{"Inference complete!"};
 
             /* Display message on the LCD - inference complete. */
-            hal_lcd_display_text(
-                    completeMsg.c_str(), completeMsg.size(),
+            hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(),
                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
         } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -233,7 +207,7 @@
     }
 
     size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize,
-                                   uint8_t *memAddress, size_t memSize){
+                                   uint8_t* memAddress, size_t memSize) {
 
         if (memAddress == nullptr){
             return 0;
@@ -284,7 +258,7 @@
         return numBytesWritten;
     }
 
-    size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){
+    size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize) {
         if ((memAddress == nullptr) || (memSize < 4)) {
             return 0;
         }
@@ -294,8 +268,8 @@
         return sizeof(int32_t);
      }
 
-    size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t> &audioFrame,
-                                        uint8_t *memAddress, size_t memSize)
+    size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame,
+                                        uint8_t* memAddress, size_t memSize)
     {
         if (memAddress == nullptr) {
             return 0;
@@ -324,7 +298,7 @@
             const TfLiteTensor* tensor = model.GetOutputTensor(i);
             const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(tensor);
+            DumpTensor(tensor);
 #endif /* VERIFY_TEST_OUTPUT */
             /* Ensure that we don't overflow the allowed limit. */
             if (numBytesWritten + tensor->bytes <= memSize) {
@@ -360,20 +334,5 @@
         ctx.Set<uint32_t>("clipIndex", curClipIdx);
     }
 
-    void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
-            const float quantScale, const int quantOffset, TfLiteTensor* inputTensor)
-    {
-        const float minVal = std::numeric_limits<int8_t>::min();
-        const float maxVal = std::numeric_limits<int8_t>::max();
-
-        auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
-
-        for (size_t i=0; i < inputFeatures.size(); ++i) {
-            float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
-            inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
-        }
-    }
-
-
 } /* namespace app */
 } /* namespace arm */
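
With quantisation folded into the pre/post-processing classes, the handler
body reduces to the pre-process / inference / post-process pattern shared by
the other use-case APIs. Schematically (a sketch of the loop above, with the
GRU state handling and LCD updates elided):

    while (audioDataSlider.HasNext()) {
        const int16_t* window = audioDataSlider.Next();

        if (!preProcess.DoPreProcess(window, audioFrameLen) ||
            !RunInference(model, profiler) ||
            !postProcess.DoPostProcess()) {
            return false;  /* Any failing stage aborts the clip. */
        }
    }
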
diff --git a/tests/use_case/ad/PostProcessTests.cc b/tests/use_case/ad/PostProcessTests.cc
deleted file mode 100644
index 62fa9e7..0000000
--- a/tests/use_case/ad/PostProcessTests.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "AdPostProcessing.hpp"
-#include <catch.hpp>
-
-TEST_CASE("Softmax_vector") {
-
-    std::vector<float> testVec = {1, 2, 3, 4, 1, 2, 3};
-    arm::app::Softmax(testVec);
-    CHECK((testVec[0] - 0.024) == Approx(0.0).margin(0.001));
-    CHECK((testVec[1] - 0.064) == Approx(0.0).margin(0.001));
-    CHECK((testVec[2] - 0.175) == Approx(0.0).margin(0.001));
-    CHECK((testVec[3] - 0.475) == Approx(0.0).margin(0.001));
-    CHECK((testVec[4] - 0.024) == Approx(0.0).margin(0.001));
-    CHECK((testVec[5] - 0.064) == Approx(0.0).margin(0.001));
-    CHECK((testVec[6] - 0.175) == Approx(0.0).margin(0.001));
-}
-
-TEST_CASE("Output machine index") {
-
-    auto index = arm::app::OutputIndexFromFileName("test_id_00.wav");
-    CHECK(index == 0);
-
-    auto index1 = arm::app::OutputIndexFromFileName("test_id_02.wav");
-    CHECK(index1 == 1);
-
-    auto index2 = arm::app::OutputIndexFromFileName("test_id_4.wav");
-    CHECK(index2 == 2);
-
-    auto index3 = arm::app::OutputIndexFromFileName("test_id_6.wav");
-    CHECK(index3 == 3);
-
-    auto index4 = arm::app::OutputIndexFromFileName("test_id_id_00.wav");
-    CHECK(index4 == -1);
-
-    auto index5 = arm::app::OutputIndexFromFileName("test_id_7.wav");
-    CHECK(index5 == -1);
-}
\ No newline at end of file
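
The deleted cases covered the free-standing Softmax and OutputIndexFromFileName
helpers, removed along with AdPostProcessing.hpp. For reference, the expected
values above follow from the usual softmax; a numerically stable in-place
formulation (the standard algorithm the test exercised, not the removed
implementation itself) is:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    /* In-place softmax; assumes a non-empty input vector. */
    void SoftmaxInPlace(std::vector<float>& v)
    {
        const float maxVal = *std::max_element(v.begin(), v.end());
        float sum = 0.0f;
        for (float& x : v) {
            x = std::exp(x - maxVal);  /* Shift by the max for stability. */
            sum += x;
        }
        for (float& x : v) {
            x /= sum;
        }
    }

For {1, 2, 3, 4, 1, 2, 3} this yields roughly {0.024, 0.064, 0.175, 0.475,
0.024, 0.064, 0.175}, matching the margins checked above.
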
diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc
index 3ebdcf4..883c215 100644
--- a/tests/use_case/kws_asr/MfccTests.cc
+++ b/tests/use_case/kws_asr/MfccTests.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,13 +93,13 @@
     -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
 };
 
-arm::app::audio::MicroNetMFCC GetMFCCInstance() {
-    const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq;
+arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() {
+    const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
     const int frameLenMs = 40;
     const int frameLenSamples = sampFreq * frameLenMs * 0.001;
     const int numMfccFeats = 10;
 
-   return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples);
+   return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples);
 }
 
 template <class T>
diff --git a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc
index 6fd7df3..e343b66 100644
--- a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc
+++ b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,15 +16,17 @@
  */
 #include "Wav2LetterPostprocess.hpp"
 #include "Wav2LetterModel.hpp"
+#include "ClassificationResult.hpp"
 
 #include <algorithm>
 #include <catch.hpp>
 #include <limits>
 
 template <typename T>
-static TfLiteTensor GetTestTensor(std::vector <int>& shape,
-                                  T                  initVal,
-                                  std::vector<T>&    vectorBuf)
+static TfLiteTensor GetTestTensor(
+        std::vector<int>&      shape,
+        T                      initVal,
+        std::vector<T>&        vectorBuf)
 {
     REQUIRE(0 != shape.size());
 
@@ -38,91 +40,112 @@
     vectorBuf = std::vector<T>(sizeInBytes, initVal);
     TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data());
     return tflite::testing::CreateQuantizedTensor(
-                                vectorBuf.data(), dims,
-                                1, 0, "test-tensor");
+            vectorBuf.data(), dims,
+            1, 0, "test-tensor");
 }
 
 TEST_CASE("Checking return value")
 {
     SECTION("Mismatched post processing parameters and tensor size")
     {
-        const uint32_t ctxLen = 5;
-        const uint32_t innerLen = 3;
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
-
+        const uint32_t outputCtxLen = 5;
+        arm::app::AsrClassifier classifier;
+        arm::app::Wav2LetterModel model;
+        model.Init();
+        std::vector<std::string> dummyLabels = {"a", "b", "$"};
+        const uint32_t blankTokenIdx = 2;
+        std::vector<arm::app::ClassificationResult> dummyResult;
         std::vector <int> tensorShape = {1, 1, 1, 13};
         std::vector <int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-                                tensorShape, 100, tensorVec);
-        REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+                tensorShape, 100, tensorVec);
+
+        arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+                                      blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
+
+        REQUIRE(!post.DoPostProcess());
     }
 
     SECTION("Post processing succeeds")
     {
-        const uint32_t ctxLen = 5;
-        const uint32_t innerLen = 3;
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
-
-        std::vector <int> tensorShape = {1, 1, 13, 1};
-        std::vector <int8_t> tensorVec;
+        const uint32_t outputCtxLen = 5;
+        arm::app::AsrClassifier classifier;
+        arm::app::Wav2LetterModel model;
+        model.Init();
+        std::vector<std::string> dummyLabels = {"a", "b", "$"};
+        const uint32_t blankTokenIdx = 2;
+        std::vector<arm::app::ClassificationResult> dummyResult;
+        std::vector<int> tensorShape = {1, 1, 13, 1};
+        std::vector<int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-                                tensorShape, 100, tensorVec);
+                tensorShape, 100, tensorVec);
+
+        arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+                                      blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
-        std::vector <int8_t> originalVec = tensorVec;
+        std::vector<int8_t> originalVec = tensorVec;
 
         /* This step should not erase anything. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        REQUIRE(post.DoPostProcess());
     }
 }
 
+
 TEST_CASE("Postprocessing - erasing required elements")
 {
-    constexpr uint32_t ctxLen = 5;
+    constexpr uint32_t outputCtxLen = 5;
     constexpr uint32_t innerLen = 3;
-    constexpr uint32_t nRows = 2*ctxLen + innerLen;
+    constexpr uint32_t nRows = 2*outputCtxLen + innerLen;
     constexpr uint32_t nCols = 10;
     constexpr uint32_t blankTokenIdx = nCols - 1;
-    std::vector <int> tensorShape = {1, 1, nRows, nCols};
+    std::vector<int> tensorShape = {1, 1, nRows, nCols};
+    arm::app::AsrClassifier classifier;
+    arm::app::Wav2LetterModel model;
+    model.Init();
+    std::vector<std::string> dummyLabels = {"a", "b", "$"};
+    std::vector<arm::app::ClassificationResult> dummyResult;
 
     SECTION("First and last iteration")
     {
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-        std::vector <int8_t> tensorVec;
-        TfLiteTensor tensor = GetTestTensor<int8_t>(
-                                tensorShape, 100, tensorVec);
+        std::vector<int8_t> tensorVec;
+        TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
+        arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+                                      blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
-        std::vector <int8_t> originalVec = tensorVec;
+        std::vector<int8_t> originalVec = tensorVec;
 
         /* This step should not erase anything. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+        post.m_lastIteration = true;
+        REQUIRE(post.DoPostProcess());
         REQUIRE(originalVec == tensorVec);
     }
 
     SECTION("Right context erase")
     {
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
         std::vector <int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-                                tensorShape, 100, tensorVec);
+                tensorShape, 100, tensorVec);
+        arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+                                      blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
-        std::vector <int8_t> originalVec = tensorVec;
+        std::vector<int8_t> originalVec = tensorVec;
 
         /* This step should erase the right context only. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        post.m_lastIteration = false;
+        REQUIRE(post.DoPostProcess());
         REQUIRE(originalVec != tensorVec);
 
-        /* The last ctxLen * 10 elements should be gone. */
+        /* The last outputCtxLen * nCols elements should be gone. */
-        for (size_t i = 0; i < ctxLen; ++i) {
+        for (size_t i = 0; i < outputCtxLen; ++i) {
             for (size_t j = 0; j < nCols; ++j) {
-                /* Check right context elements are zeroed. */
+                /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */
                 if (j == blankTokenIdx) {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
                 } else {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
                 }
 
                 /* Check left context is preserved. */
@@ -131,45 +154,47 @@
         }
 
         /* Check inner elements are preserved. */
-        for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+        for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
             CHECK(tensorVec[i] == originalVec[i]);
         }
     }
 
     SECTION("Left and right context erase")
     {
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
         std::vector <int8_t> tensorVec;
-        TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
+        TfLiteTensor tensor = GetTestTensor<int8_t>(
+                tensorShape, 100, tensorVec);
+        arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+                                      blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
         std::vector <int8_t> originalVec = tensorVec;
 
         /* This step should erase right context. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        post.m_lastIteration = false;
+        REQUIRE(post.DoPostProcess());
 
         /* Calling it the second time should erase the left context. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        REQUIRE(post.DoPostProcess());
 
         REQUIRE(originalVec != tensorVec);
 
-        /* The first and last ctxLen * 10 elements should be gone. */
+        /* The first and last outputCtxLen * nCols elements should be gone. */
-        for (size_t i = 0; i < ctxLen; ++i) {
+        for (size_t i = 0; i < outputCtxLen; ++i) {
             for (size_t j = 0; j < nCols; ++j) {
                 /* Check left and right context elements are zeroed. */
                 if (j == blankTokenIdx) {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 1);
-                    CHECK(tensorVec[i * nCols + j] == 1);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
+                    CHECK(tensorVec[i*nCols + j] == 1);
                 } else {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 0);
-                    CHECK(tensorVec[i * nCols + j] == 0);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
+                    CHECK(tensorVec[i*nCols + j] == 0);
                 }
             }
         }
 
         /* Check inner elements are preserved. */
-        for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+        for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
-            /* Check left context is preserved. */
+            /* Inner elements should remain untouched. */
             CHECK(tensorVec[i] == originalVec[i]);
         }
@@ -177,18 +202,21 @@
 
     SECTION("Try left context erase")
     {
-        /* Should not be able to erase the left context if it is the first iteration. */
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
         std::vector <int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-                                tensorShape, 100, tensorVec);
+                tensorShape, 100, tensorVec);
+
+        /* Should not be able to erase the left context if it is the first iteration. */
+        arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
+                                      blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
         std::vector <int8_t> originalVec = tensorVec;
 
-        /* Calling it the second time should erase the left context. */
+        /* Nothing should be erased as this is both the first and the last iteration. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+        post.m_lastIteration = true;
+        REQUIRE(post.DoPostProcess());
+
         REQUIRE(originalVec == tensorVec);
     }
-}
\ No newline at end of file
+}
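
The indexing asserted in these tests follows from the row layout the
post-process operates on: [left context | inner | right context]. With
outputCtxLen = 5, innerLen = 3 and nCols = 10, the offsets reduce to:

    const uint32_t leftCtxEnd    = outputCtxLen * nCols;               /* 50: left context is rows [0, 5).   */
    const uint32_t rightCtxStart = (outputCtxLen + innerLen) * nCols;  /* 80: right context is rows [8, 13). */

Erasing a context block zeroes every column except the blank-token one, which
is set to 1, so the erased rows decode to blanks.
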
diff --git a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc
index 26ddb24..372152d 100644
--- a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc
+++ b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,64 +16,54 @@
  */
 #include "Wav2LetterPreprocess.hpp"
 
-#include <algorithm>
-#include <catch.hpp>
 #include <limits>
+#include <catch.hpp>
 
 constexpr uint32_t numMfccFeatures = 13;
 constexpr uint32_t numMfccVectors  = 10;
 
 /* Test vector output: generated using test-asr-preprocessing.py. */
-int8_t expectedResult[numMfccVectors][numMfccFeatures*3] = {
-    /* Feature vec 0. */
-    -32,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,    /* MFCCs.   */
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,    /* Delta 1. */
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,    /* Delta 2. */
-
-    /* Feature vec 1. */
-    -31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 2. */
-    -31,   4,  -9,  -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 3. */
-    -31,   4,  -9,  -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 4 : this should have valid delta 1 and delta 2. */
-    -31,   4,  -9,  -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
-    -38, -29,  -9,   1,  -2,  -7,  -8,  -8, -12, -16, -14,  -5,   5,
-    -68, -50, -13,   5,   0,  -9,  -9,  -8, -13, -20, -19,  -3,  15,
-
-    /* Feature vec 5 : this should have valid delta 1 and delta 2. */
-    -31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
-    -62, -45, -11,   5,   0,  -8,  -9,  -8, -12, -19, -17,  -3,  13,
-    -27, -22, -13,  -9, -11, -12, -12, -11, -11, -13, -13, -10,  -6,
-
-    /* Feature vec 6. */
-    -31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 7. */
-    -32,   4,  -9,  -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 8. */
-    -32,   4,  -9,  -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 9. */
-    -31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10
+int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = {
+        /* Feature vec 0. */
+        {-32,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,    /* MFCCs.   */
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,     /* Delta 1. */
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},    /* Delta 2. */
+        /* Feature vec 1. */
+        {-31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+        /* Feature vec 2. */
+        {-31,   4,  -9,  -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+        /* Feature vec 3. */
+        {-31,   4,  -9,  -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+        /* Feature vec 4 : this should have valid delta 1 and delta 2. */
+        {-31,   4,  -9,  -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+                -38, -29,  -9,   1,  -2,  -7,  -8,  -8, -12, -16, -14,  -5,   5,
+                -68, -50, -13,   5,   0,  -9,  -9,  -8, -13, -20, -19,  -3,  15},
+        /* Feature vec 5 : this should have valid delta 1 and delta 2. */
+        {-31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+                -62, -45, -11,   5,   0,  -8,  -9,  -8, -12, -19, -17,  -3,  13,
+                -27, -22, -13,  -9, -11, -12, -12, -11, -11, -13, -13, -10,  -6},
+        /* Feature vec 6. */
+        {-31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+        /* Feature vec 7. */
+        {-32,   4,  -9,  -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+        /* Feature vec 8. */
+        {-32,   4,  -9,  -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+        /* Feature vec 9. */
+        {-31,   4,  -9,  -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+                -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+                -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}
 };
 
 void PopulateTestWavVector(std::vector<int16_t>& vec)
@@ -97,17 +87,17 @@
 
 TEST_CASE("Preprocessing calculation INT8")
 {
-
     /* Constants. */
-    const uint32_t  windowLen       = 512;
-    const uint32_t  windowStride    = 160;
-    int             dimArray[]      = {3, 1, numMfccFeatures * 3, numMfccVectors};
-    const float     quantScale      = 0.1410219967365265;
-    const int       quantOffset     = -11;
+    const uint32_t  mfccWindowLen      = 512;
+    const uint32_t  mfccWindowStride   = 160;
+    int             dimArray[]         = {3, 1, numMfccFeatures * 3, numMfccVectors};
+    const float     quantScale         = 0.1410219967365265;
+    const int       quantOffset        = -11;
 
     /* Test wav memory. */
-    std::vector <int16_t> testWav((windowStride * numMfccVectors) +
-                                  (windowLen - windowStride));
+    std::vector<int16_t> testWav((mfccWindowStride * numMfccVectors) +
+                                 (mfccWindowLen - mfccWindowStride));
 
     /* Populate with dummy input. */
     PopulateTestWavVector(testWav);
@@ -117,20 +107,20 @@
 
     /* Initialise dimensions and the test tensor. */
     TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
-    TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor(
-        tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
+    TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor(
+            tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
 
     /* Initialise pre-processing module. */
-    arm::app::audio::asr::Preprocess prep{
-        numMfccFeatures, windowLen, windowStride, numMfccVectors};
+    arm::app::AsrPreProcess prep{&inputTensor,
+                                 numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride};
 
     /* Invoke pre-processing. */
-    REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor));
+    REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size()));
 
     /* Wrap the tensor with a std::vector for ease. */
-    int8_t * tensorData = tflite::GetTensorData<int8_t>(&tensor);
+    auto* tensorData = tflite::GetTensorData<int8_t>(&inputTensor);
     std::vector <int8_t> vecResults =
-        std::vector<int8_t>(tensorData, tensorData + tensor.bytes);
+            std::vector<int8_t>(tensorData, tensorData + inputTensor.bytes);
 
     /* Check sizes. */
     REQUIRE(vecResults.size() == sizeof(expectedResult));
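
As a sanity check on the sizing above: the test wav is built so the MFCC
sliding window yields exactly numMfccVectors frames. With
numFrames = (samples - mfccWindowLen) / mfccWindowStride + 1, a buffer of
160 * 10 + (512 - 160) = 1952 samples gives (1952 - 512) / 160 + 1 = 10 frames,
one per expected feature vector.
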
diff --git a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp
index e28a6da..ca5aab1 100644
--- a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp
+++ b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "RNNoiseProcess.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
 #include <catch.hpp>
 #include <limits>
 
@@ -208,7 +208,7 @@
 {
     SECTION("FP32")
     {
-        arm::app::rnn::RNNoiseProcess rnnoiseProcessor;
+        arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor;
         arm::app::rnn::FrameFeatures features;
 
         rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), features);
@@ -223,7 +223,7 @@
 
 TEST_CASE("RNNoise postprocessing test", "[RNNoise]")
 {
-    arm::app::rnn::RNNoiseProcess rnnoiseProcessor;
+    arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor;
     arm::app::rnn::FrameFeatures p;
     rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), p);
     std::vector<float> denoised(testWav0.size());