MLECO-3173: Add AD, KWS_ASR and noise reduction use case APIs

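For the anomaly detection (AD) use case, pre- and post-processing are now
exposed through the AdPreProcess and AdPostProcess classes, replacing the free
functions previously declared in AdPostProcessing.hpp. A minimal usage sketch,
following the updated UseCaseHandler (model setup and the inference invocation
are elided, and machineOutputIndex stands in for the value obtained from
OutputIndexFromFileName):

    AdPreProcess preProcess{inputTensor, frameLength, frameStride, trainingMean};
    AdPostProcess postProcess{outputTensor};

    /* Slide over the whole audio clip using the window size/stride deduced
     * by the pre-processing object from the model's input shape. */
    auto audioDataSlider = audio::SlidingWindow<const int16_t>(
        audioData, audioDataSize,
        preProcess.GetAudioWindowSize(), preProcess.GetAudioDataStride());

    float result = 0;
    while (audioDataSlider.HasNext()) {
        const int16_t* inferenceWindow = audioDataSlider.Next();
        preProcess.SetAudioWindowIndex(audioDataSlider.Index());
        preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
        /* ... run inference here ... */
        postProcess.DoPostProcess();
        result += -postProcess.GetOutputValue(machineOutputIndex);
    }
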
Signed-off-by: Richard Burton <richard.burton@arm.com>

Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp
index 8d914c4..2195a7c 100644
--- a/source/use_case/ad/include/AdModel.hpp
+++ b/source/use_case/ad/include/AdModel.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +28,12 @@
 namespace app {
 
     class AdModel : public Model {
+
+    public:
+        /* Indices for the expected model - based on input tensor shape */
+        static constexpr uint32_t ms_inputRowsIdx = 1;
+        static constexpr uint32_t ms_inputColsIdx = 2;
+
     protected:
         /** @brief   Gets the reference to op resolver interface class */
         const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp
deleted file mode 100644
index 7eaec84..0000000
--- a/source/use_case/ad/include/AdPostProcessing.hpp
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef ADPOSTPROCESSING_HPP
-#define ADPOSTPROCESSING_HPP
-
-#include "TensorFlowLiteMicro.hpp"
-
-#include <vector>
-
-namespace arm {
-namespace app {
-
-    /** @brief      Dequantize TensorFlow Lite Micro tensor.
-     *  @param[in]  tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized.
-     *  @return     Vector with the dequantized tensor values.
-    **/
-    template<typename T>
-    std::vector<float> Dequantize(TfLiteTensor* tensor);
-
-    /**
-     * @brief   Calculates the softmax of vector in place. **/
-    void Softmax(std::vector<float>& inputVector);
-
-
-    /** @brief      Given a wav file name return AD model output index.
-     *  @param[in]  wavFileName Audio WAV filename.
-     *                          File name should be in format anything_goes_XX_here.wav
-     *                          where XX is the machine ID e.g. 00, 02, 04 or 06
-     *  @return     AD model output index as 8 bit integer.
-    **/
-    int8_t OutputIndexFromFileName(std::string wavFileName);
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* ADPOSTPROCESSING_HPP */
diff --git a/source/use_case/ad/include/AdProcessing.hpp b/source/use_case/ad/include/AdProcessing.hpp
new file mode 100644
index 0000000..9abf6f1
--- /dev/null
+++ b/source/use_case/ad/include/AdProcessing.hpp
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD_PROCESSING_HPP
+#define AD_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "AudioUtils.hpp"
+#include "AdMelSpectrogram.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for the anomaly detection use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class AdPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief Constructor for AdPreProcess class objects.
+         * @param[in] inputTensor               Input tensor pointer from the tensor arena.
+         * @param[in] melSpectrogramFrameLen    MEL spectrogram frame length.
+         * @param[in] melSpectrogramFrameStride MEL spectrogram frame stride.
+         * @param[in] adModelTrainingMean       Training mean for the anomaly detection model being used.
+         */
+        explicit AdPreProcess(TfLiteTensor* inputTensor,
+                              uint32_t melSpectrogramFrameLen,
+                              uint32_t melSpectrogramFrameStride,
+                              float adModelTrainingMean);
+
+        ~AdPreProcess() = default;
+
+        /**
+         * @brief Function to invoke pre-processing and populate the input tensor.
+         * @param[in] input     Pointer to input data. For anomaly detection, this is the
+         *                      pointer to the audio data.
+         * @param[in] inputSize Size of the data being passed in for pre-processing.
+         * @return True if successful, false otherwise.
+         */
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+        /**
+         * @brief Getter function for audio window size computed when constructing
+         *        the class object.
+         * @return Audio window size as a 32-bit unsigned integer.
+         */
+        uint32_t GetAudioWindowSize();
+
+        /**
+         * @brief Getter function for audio window stride computed when constructing
+         *        the class object.
+         * @return Audio window stride as a 32-bit unsigned integer.
+         */
+        uint32_t GetAudioDataStride();
+
+        /**
+         * @brief     Setter function for the current audio window index. This is only used to
+         *            evaluate whether previously computed features can be re-used from the cache.
+         * @param[in] idx Current audio window index (from the audio sliding window).
+         */
+        void SetAudioWindowIndex(uint32_t idx);
+
+    private:
+        bool        m_validInstance{false}; /**< Indicates the current object is valid. */
+        uint32_t    m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */
+        uint32_t    m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */
+        uint8_t     m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */
+        uint32_t    m_numMelSpecVectorsInAudioStride{};  /**< Number of frames to move across the audio. */
+        uint32_t    m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */
+        uint32_t    m_audioDataStride{}; /**< Audio window stride computed. */
+        uint32_t    m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */
+        uint32_t    m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */
+
+        audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */
+        audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */
+        std::function<void
+            (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */
+    };
+
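+    /**
+     * @brief   Post-processing class for the anomaly detection use case.
+     *          Implements methods declared by BasePostProcess to de-quantise the
+     *          output tensor and apply a softmax to it.
+     */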
+    class AdPostProcess : public BasePostProcess {
+    public:
+        /**
+         * @brief Constructor for AdPostProcess object.
+         * @param[in] outputTensor Output tensor pointer.
+         */
+        explicit AdPostProcess(TfLiteTensor* outputTensor);
+
+        ~AdPostProcess() = default;
+
+        /**
+         * @brief Function to do the post-processing on the output tensor.
+         * @return True if successful, false otherwise.
+         */
+        bool DoPostProcess() override;
+
+        /**
+         * @brief Getter function for an element from the de-quantised output vector.
+         * @param[in] index Index of the element to be retrieved.
+         * @return    De-quantised output value at the given index, as a 32-bit float.
+         */
+        float GetOutputValue(uint32_t index);
+
+    private:
+        TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */
+        std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */
+
+        /**
+         * @brief De-quantizes and flattens the output tensor into a vector.
+         * @tparam T template parameter to indicate data type.
+         * @return True if successful, false otherwise.
+         */
+        template<typename T>
+        bool Dequantize()
+        {
+            TfLiteTensor* tensor = this->m_outputTensor;
+            if (tensor == nullptr) {
+                printf_err("Invalid output tensor.\n");
+                return false;
+            }
+            T* tensorData = tflite::GetTensorData<T>(tensor);
+
+            uint32_t totalOutputSize = 1;
+            for (int dim = 0; dim < tensor->dims->size; ++dim) {
+                totalOutputSize *= tensor->dims->data[dim];
+            }
+
+            /* For getting the floating point values, we need quantization parameters */
+            QuantParams quantParams = GetTensorQuantParams(tensor);
+
+            this->m_dequantizedOutputVec = std::vector<float>(totalOutputSize, 0);
+
+            for (size_t i = 0; i < totalOutputSize; ++i) {
+                this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
+            }
+
+            return true;
+        }
+    };
+
+    /* Templated instances available: */
+    template bool AdPostProcess::Dequantize<int8_t>();
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns a lambda function that computes features using a feature cache.
+     * The actual feature computation is done by the function provided as a parameter.
+     * Features are written to the input tensor memory.
+     *
+     * @tparam T            feature vector type.
+     * @param inputTensor   model input tensor pointer.
+     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param compute       features calculator function.
+     * @return              lambda function to compute features.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    {
+        /* Feature cache to be captured by the lambda function. */
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex,
+                   size_t resizeScale)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
+             * The overlap is at the beginning of the sliding window, with a size of the feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size() / resizeScale;
+            auto sizeBytes = sizeof(T);
+
+            /* Input should be transposed and "resized" by skipping elements. */
+            for (size_t outIndex = 0; outIndex < size; outIndex++) {
+                std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
+            }
+
+            /* Start renewing the cache as soon as the iteration goes beyond the window overlap. */
+            if (index >= featuresOverlapIndex / resizeScale) {
+                featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
+            }
+        };
+    }
+
+    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+                        size_t cacheSize,
+                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc<float>(TfLiteTensor *inputTensor,
+                       size_t cacheSize,
+                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
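+    /**
+     * @brief       Gets the feature calculation function used to populate the input tensor
+     *              with MEL spectrogram data.
+     *
+     *              The input tensor's data type is checked to choose the correct feature
+     *              data type; if the tensor has an integer type, the computed features are
+     *              quantised. Note: the MEL spectrogram calculator passed in must outlive
+     *              the returned function.
+     *
+     * @param[in]     melSpec       MEL spectrogram feature calculator.
+     * @param[in,out] inputTensor   Input tensor pointer to store calculated features.
+     * @param[in]     cacheSize     Size of the feature vectors cache (number of feature vectors).
+     * @param[in]     trainingMean  Training mean.
+     * @return        Function to be called providing an audio sample and sliding window index.
+     */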
+    std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+                         TfLiteTensor* inputTensor,
+                         size_t cacheSize,
+                         float trainingMean);
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* AD_PROCESSING_HPP */
diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc
deleted file mode 100644
index c461875..0000000
--- a/source/use_case/ad/src/AdPostProcessing.cc
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "AdPostProcessing.hpp"
-#include "log_macros.h"
-
-#include <numeric>
-#include <cmath>
-#include <string>
-
-namespace arm {
-namespace app {
-
-    template<typename T>
-    std::vector<float> Dequantize(TfLiteTensor* tensor) {
-
-        if (tensor == nullptr) {
-            printf_err("Tensor is null pointer can not dequantize.\n");
-            return std::vector<float>();
-        }
-        T* tensorData = tflite::GetTensorData<T>(tensor);
-
-        uint32_t totalOutputSize = 1;
-        for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
-            totalOutputSize *= tensor->dims->data[inputDim];
-        }
-
-        /* For getting the floating point values, we need quantization parameters */
-        QuantParams quantParams = GetTensorQuantParams(tensor);
-
-        std::vector<float> dequantizedOutput(totalOutputSize);
-
-        for (size_t i = 0; i < totalOutputSize; ++i) {
-            dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
-        }
-
-        return dequantizedOutput;
-    }
-
-    void Softmax(std::vector<float>& inputVector) {
-        auto start = inputVector.begin();
-        auto end = inputVector.end();
-
-        /* Fix for numerical stability and apply exp. */
-        float maxValue = *std::max_element(start, end);
-        for (auto it = start; it!=end; ++it) {
-            *it = std::exp((*it) - maxValue);
-        }
-
-        float sumExp = std::accumulate(start, end, 0.0f);
-
-        for (auto it = start; it!=end; ++it) {
-            *it = (*it)/sumExp;
-        }
-    }
-
-    int8_t OutputIndexFromFileName(std::string wavFileName) {
-        /* Filename is assumed in the form machine_id_00.wav */
-        std::string delimiter = "_";  /* First character used to split the file name up. */
-        size_t delimiterStart;
-        std::string subString;
-        size_t machineIdxInString = 3;  /* Which part of the file name the machine id should be at. */
-
-        for (size_t i = 0; i < machineIdxInString; ++i) {
-            delimiterStart = wavFileName.find(delimiter);
-            subString = wavFileName.substr(0, delimiterStart);
-            wavFileName.erase(0, delimiterStart + delimiter.length());
-        }
-
-        /* At this point substring should be 00.wav */
-        delimiter = ".";  /* Second character used to split the file name up. */
-        delimiterStart = subString.find(delimiter);
-        subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
-
-        auto is_number = [](const std::string& str) ->  bool
-        {
-            std::string::const_iterator it = str.begin();
-            while (it != str.end() && std::isdigit(*it)) ++it;
-            return !str.empty() && it == str.end();
-        };
-
-        const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
-
-        /* Return corresponding index in the output vector. */
-        if (machineIdx == 0) {
-            return 0;
-        } else if (machineIdx == 2) {
-            return 1;
-        } else if (machineIdx == 4) {
-            return 2;
-        } else if (machineIdx == 6) {
-            return 3;
-        } else {
-            printf_err("%d is an invalid machine index \n", machineIdx);
-            return -1;
-        }
-    }
-
-    template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor);
-    template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor);
-} /* namespace app */
-} /* namespace arm */
\ No newline at end of file
diff --git a/source/use_case/ad/src/AdProcessing.cc b/source/use_case/ad/src/AdProcessing.cc
new file mode 100644
index 0000000..a33131c
--- /dev/null
+++ b/source/use_case/ad/src/AdProcessing.cc
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "AdProcessing.hpp"
+
+#include "AdModel.hpp"
+
+namespace arm {
+namespace app {
+
+AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor,
+                           uint32_t melSpectrogramFrameLen,
+                           uint32_t melSpectrogramFrameStride,
+                           float adModelTrainingMean):
+       m_validInstance{false},
+       m_melSpectrogramFrameLen{melSpectrogramFrameLen},
+       m_melSpectrogramFrameStride{melSpectrogramFrameStride},
+        /* Model is trained on features downsampled 2x. */
+       m_inputResizeScale{2},
+        /* We are choosing to move by 20 frames across the audio for each inference. */
+       m_numMelSpecVectorsInAudioStride{20},
+       m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride},
+       m_melSpec{melSpectrogramFrameLen}
+{
+    if (!inputTensor) {
+        printf_err("Invalid input tensor provided to pre-process\n");
+        return;
+    }
+
+    TfLiteIntArray* inputShape = inputTensor->dims;
+
+    if (!inputShape) {
+        printf_err("Invalid input tensor dims\n");
+        return;
+    }
+
+    const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx];
+    const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx];
+
+    /* Deduce the data length required for 1 inference from the network parameters. */
+    this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) *
+                                    melSpectrogramFrameStride) +
+                                    melSpectrogramFrameLen;
+    this->m_numReusedFeatureVectors = kNumRows -
+                                      (this->m_numMelSpecVectorsInAudioStride /
+                                       this->m_inputResizeScale);
+    this->m_melSpec.Init();
+
+    /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
+     * "resizing" done here by multiplying stride by resize scale. */
+    this->m_melWindowSlider = audio::SlidingWindow<const int16_t>(
+            nullptr, /* to be populated later. */
+            this->m_audioDataWindowSize,
+            melSpectrogramFrameLen,
+            melSpectrogramFrameStride * this->m_inputResizeScale);
+
+    /* Construct feature calculation function. */
+    this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor,
+                                               this->m_numReusedFeatureVectors,
+                                               adModelTrainingMean);
+    this->m_validInstance = true;
+}
+
+bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize)
+{
+    /* Check that we have a valid instance. */
+    if (!this->m_validInstance) {
+        printf_err("Invalid pre-processor instance\n");
+        return false;
+    }
+
+    /* We expect to be able to traverse at least the size with which the MEL spectrogram
+     * sliding window was initialised. */
+    if (!input || inputSize < this->m_audioDataWindowSize) {
+        printf_err("Invalid input provided for pre-processing\n");
+        return false;
+    }
+
+    /* We moved to the next window - set the features sliding to the new address. */
+    this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input));
+
+    /* The first window does not have cache ready. */
+    const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0;
+
+    /* Start calculating features inside one audio sliding window. */
+    while (this->m_melWindowSlider.HasNext()) {
+        const int16_t* melSpecWindow = this->m_melWindowSlider.Next();
+        std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(
+                melSpecWindow,
+                melSpecWindow + this->m_melSpectrogramFrameLen);
+
+        /* Compute features for this window and write them to input tensor. */
+        this->m_featureCalc(melSpecAudioData,
+                            this->m_melWindowSlider.Index(),
+                            useCache,
+                            this->m_numMelSpecVectorsInAudioStride,
+                            this->m_inputResizeScale);
+    }
+
+    return true;
+}
+
+uint32_t AdPreProcess::GetAudioWindowSize()
+{
+    return this->m_audioDataWindowSize;
+}
+
+uint32_t AdPreProcess::GetAudioDataStride()
+{
+    return this->m_audioDataStride;
+}
+
+void AdPreProcess::SetAudioWindowIndex(uint32_t idx)
+{
+    this->m_audioWindowIndex = idx;
+}
+
+AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) :
+    m_outputTensor {outputTensor}
+{}
+
+bool AdPostProcess::DoPostProcess()
+{
+    switch (this->m_outputTensor->type) {
+        case kTfLiteInt8:
+            this->Dequantize<int8_t>();
+            break;
+        default:
+            printf_err("Unsupported tensor type");
+            return false;
+    }
+
+    math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec);
+    return true;
+}
+
+float AdPostProcess::GetOutputValue(uint32_t index)
+{
+    if (index < this->m_dequantizedOutputVec.size()) {
+        return this->m_dequantizedOutputVec[index];
+    }
+    printf_err("Invalid index for output\n");
+    return 0.0f;
+}
+
+std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+                     TfLiteTensor* inputTensor,
+                     size_t cacheSize,
+                     float trainingMean)
+{
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
+
+    TfLiteQuantization quant = inputTensor->quantization;
+
+    if (kTfLiteAffineQuantization == quant.type) {
+
+        auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params);
+        const float quantScale = quantParams->scale->data[0];
+        const int quantOffset = quantParams->zero_point->data[0];
+
+        switch (inputTensor->type) {
+            case kTfLiteInt8: {
+                melSpecFeatureCalc = FeatureCalc<int8_t>(
+                        inputTensor,
+                        cacheSize,
+                        [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
+                            return melSpec.MelSpecComputeQuant<int8_t>(
+                                    audioDataWindow,
+                                    quantScale,
+                                    quantOffset,
+                                    trainingMean);
+                        }
+                );
+                break;
+            }
+            default:
+                printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
+        }
+    } else {
+        melSpecFeatureCalc = FeatureCalc<float>(
+                inputTensor,
+                cacheSize,
+                [=, &melSpec](
+                        std::vector<int16_t>& audioDataWindow) {
+                    return melSpec.ComputeMelSpec(
+                            audioDataWindow,
+                            trainingMean);
+                });
+    }
+    return melSpecFeatureCalc;
+}
+
+} /* namespace app */
+} /* namespace arm */
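
For reference, the post-processing math above (de-quantise with
scale * (q - offset), apply softmax, then use the negative score at the
machine's output index as the outlier score) can be exercised on its own.
The scale, offset and index below are made-up example values:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main()
    {
        const std::vector<int8_t> quantized{-20, 5, 17, -3}; /* toy int8 model output */
        const float scale  = 0.05f;                          /* example quant scale */
        const int   offset = -4;                             /* example zero point */

        /* De-quantise: scale * (q - offset), as in AdPostProcess::Dequantize. */
        std::vector<float> output(quantized.size());
        for (size_t i = 0; i < quantized.size(); ++i) {
            output[i] = scale * (quantized[i] - offset);
        }

        /* Softmax with max-subtraction for numerical stability. */
        const float maxVal = *std::max_element(output.begin(), output.end());
        for (auto& v : output) { v = std::exp(v - maxVal); }
        const float sum = std::accumulate(output.begin(), output.end(), 0.0f);
        for (auto& v : output) { v /= sum; }

        /* The handler accumulates the negative softmax score at the machine's index. */
        const uint32_t machineOutputIndex = 1;               /* e.g. machine ID 02 */
        std::printf("Outlier score: %f\n", -output[machineOutputIndex]);
        return 0;
    }
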
diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc
index 23d1e51..140359b 100644
--- a/source/use_case/ad/src/MainLoop.cc
+++ b/source/use_case/ad/src/MainLoop.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "hal.h"                    /* Brings in platform definitions */
 #include "InputFiles.hpp"           /* For input data */
 #include "AdModel.hpp"              /* Model class for running inference */
 #include "UseCaseCommonUtils.hpp"   /* Utils functions */
@@ -63,8 +62,8 @@
     caseContext.Set<arm::app::Profiler&>("profiler", profiler);
     caseContext.Set<arm::app::Model&>("model", model);
     caseContext.Set<uint32_t>("clipIndex", 0);
-    caseContext.Set<int>("frameLength", g_FrameLength);
-    caseContext.Set<int>("frameStride", g_FrameStride);
+    caseContext.Set<uint32_t>("frameLength", g_FrameLength);
+    caseContext.Set<uint32_t>("frameStride", g_FrameStride);
     caseContext.Set<float>("scoreThreshold", g_ScoreThreshold);
     caseContext.Set<float>("trainingMean", g_TrainingMean);
 
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc
index 5585f36..0179d6b 100644
--- a/source/use_case/ad/src/UseCaseHandler.cc
+++ b/source/use_case/ad/src/UseCaseHandler.cc
@@ -24,8 +24,8 @@
 #include "AudioUtils.hpp"
 #include "ImageUtils.hpp"
 #include "UseCaseCommonUtils.hpp"
-#include "AdPostProcessing.hpp"
 #include "log_macros.h"
+#include "AdProcessing.hpp"
 
 namespace arm {
 namespace app {
@@ -39,32 +39,17 @@
      **/
     static bool PresentInferenceResult(float result, float threshold);
 
-    /**
-     * @brief Returns a function to perform feature calculation and populates input tensor data with
-     * MelSpe data.
-     *
-     * Input tensor data type check is performed to choose correct MFCC feature data type.
-     * If tensor has an integer data type then original features are quantised.
-     *
-     * Warning: mfcc calculator provided as input must have the same life scope as returned function.
-     *
-     * @param[in]           melSpec         MFCC feature calculator.
-     * @param[in,out]       inputTensor     Input tensor pointer to store calculated features.
-     * @param[in]           cacheSize       Size of the feture vectors cache (number of feature vectors).
-     * @param[in]           trainingMean    Training mean.
-     * @return function     function to be called providing audio sample and sliding window index.
-     */
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
-    GetFeatureCalculator(audio::AdMelSpectrogram&  melSpec,
-                         TfLiteTensor*             inputTensor,
-                         size_t                    cacheSize,
-                         float                     trainingMean);
+    /** @brief      Given a WAV file name, returns the AD model output index.
+     *  @param[in]  wavFileName Audio WAV filename.
+     *                          The file name should be in the format anything_goes_XX_here.wav,
+     *                          where XX is the machine ID, e.g. 00, 02, 04 or 06.
+     *  @return     AD model output index as an 8-bit integer.
+    **/
+    static int8_t OutputIndexFromFileName(std::string wavFileName);
 
-    /* Vibration classification handler */
+    /* Anomaly Detection inference handler */
     bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
     {
-        auto& profiler = ctx.Get<Profiler&>("profiler");
-
         constexpr uint32_t dataPsnTxtInfStartX = 20;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
@@ -81,8 +66,9 @@
             return false;
         }
 
-        const auto frameLength = ctx.Get<int>("frameLength");
-        const auto frameStride = ctx.Get<int>("frameStride");
+        auto& profiler = ctx.Get<Profiler&>("profiler");
+        const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength");
+        const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride");
         const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
         const auto trainingMean = ctx.Get<float>("trainingMean");
         auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
@@ -95,21 +81,13 @@
             return false;
         }
 
-        TfLiteIntArray* inputShape = model.GetInputShape(0);
-        const uint32_t kNumRows = inputShape->data[1];
-        const uint32_t kNumCols = inputShape->data[2];
+        AdPreProcess preProcess{
+            inputTensor,
+            melSpecFrameLength,
+            melSpecFrameStride,
+            trainingMean};
 
-        audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength);
-        melSpec.Init();
-
-        /* Deduce the data length required for 1 inference from the network parameters. */
-        const uint8_t inputResizeScale = 2;
-        const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength;
-
-        /* We are choosing to move by 20 frames across the audio for each inference. */
-        const uint8_t nMelSpecVectorsInAudioStride = 20;
-
-        auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride;
+        AdPostProcess postProcess{outputTensor};
 
         do {
             hal_lcd_clear(COLOR_BLACK);
@@ -122,29 +100,12 @@
                 return false;
             }
 
-            /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
-             * "resizing" done here by multiplying stride by resize scale. */
-            auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>(
-                    get_audio_array(currentIndex),
-                    audioDataWindowSize, frameLength,
-                    frameStride * inputResizeScale);
-
             /* Creating a sliding window through the whole audio clip. */
             auto audioDataSlider = audio::SlidingWindow<const int16_t>(
-                    get_audio_array(currentIndex),
-                    get_audio_array_size(currentIndex),
-                    audioDataWindowSize, audioDataStride);
-
-            /* Calculate number of the feature vectors in the window overlap region taking into account resizing.
-             * These feature vectors will be reused.*/
-            auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale);
-
-            /* Construct feature calculation function. */
-            auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor,
-                                                           numberOfReusedFeatureVectors, trainingMean);
-            if (!melSpecFeatureCalc){
-                return false;
-            }
+                get_audio_array(currentIndex),
+                get_audio_array_size(currentIndex),
+                preProcess.GetAudioWindowSize(),
+                preProcess.GetAudioDataStride());
 
             /* Result is an averaged sum over inferences. */
             float result = 0;
@@ -152,30 +113,18 @@
             /* Display message on the LCD - inference running. */
             std::string str_inf{"Running inference... "};
             hal_lcd_display_text(
-                    str_inf.c_str(), str_inf.size(),
-                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
-            info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex));
+                str_inf.c_str(), str_inf.size(),
+                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+
+            info("Running inference on audio clip %" PRIu32 " => %s\n",
+                currentIndex, get_filename(currentIndex));
 
             /* Start sliding through audio clip. */
             while (audioDataSlider.HasNext()) {
-                const int16_t *inferenceWindow = audioDataSlider.Next();
+                const int16_t* inferenceWindow = audioDataSlider.Next();
 
-                /* We moved to the next window - set the features sliding to the new address. */
-                audioMelSpecWindowSlider.Reset(inferenceWindow);
-
-                /* The first window does not have cache ready. */
-                bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
-
-                /* Start calculating features inside one audio sliding window. */
-                while (audioMelSpecWindowSlider.HasNext()) {
-                    const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next();
-                    std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow,
-                                                                                 melSpecWindow + frameLength);
-
-                    /* Compute features for this window and write them to input tensor. */
-                    melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(),
-                                       useCache, nMelSpecVectorsInAudioStride, inputResizeScale);
-                }
+                preProcess.SetAudioWindowIndex(audioDataSlider.Index());
+                preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
 
                 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                      audioDataSlider.TotalStrides() + 1);
@@ -185,13 +134,11 @@
                     return false;
                 }
 
-                /* Use the negative softmax score of the corresponding index as the outlier score */
-                std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
-                Softmax(dequantOutput);
-                result += -dequantOutput[machineOutputIndex];
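+                /* Use the negative softmax score of the corresponding index as the outlier score. */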
+                postProcess.DoPostProcess();
+                result += -postProcess.GetOutputValue(machineOutputIndex);
 
 #if VERIFY_TEST_OUTPUT
-                arm::app::DumpTensor(outputTensor);
+                DumpTensor(outputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
             } /* while (audioDataSlider.HasNext()) */
 
@@ -218,7 +165,6 @@
         return true;
     }
 
-
     static bool PresentInferenceResult(float result, float threshold)
     {
         constexpr uint32_t dataPsnTxtStartX1 = 20;
@@ -251,148 +197,47 @@
         return true;
     }
 
-    /**
-     * @brief Generic feature calculator factory.
-     *
-     * Returns lambda function to compute features using features cache.
-     * Real features math is done by a lambda function provided as a parameter.
-     * Features are written to input tensor memory.
-     *
-     * @tparam T            feature vector type.
-     * @param inputTensor   model input tensor pointer.
-     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
-     * @param compute       features calculator function.
-     * @return              lambda function to compute features.
-     */
-    template<class T>
-    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
-    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
-                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    static int8_t OutputIndexFromFileName(std::string wavFileName)
     {
-        /* Feature cache to be captured by lambda function*/
-        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+        /* Filename is assumed in the form machine_id_00.wav */
+        std::string delimiter = "_";  /* First character used to split the file name up. */
+        size_t delimiterStart;
+        std::string subString;
+        size_t machineIdxInString = 3;  /* Which part of the file name the machine id should be at. */
 
-        return [=](std::vector<int16_t>& audioDataWindow,
-                   size_t index,
-                   bool useCache,
-                   size_t featuresOverlapIndex,
-                   size_t resizeScale)
-        {
-            T *tensorData = tflite::GetTensorData<T>(inputTensor);
-            std::vector<T> features;
-
-            /* Reuse features from cache if cache is ready and sliding windows overlap.
-             * Overlap is in the beginning of sliding window with a size of a feature cache. */
-            if (useCache && index < featureCache.size()) {
-                features = std::move(featureCache[index]);
-            } else {
-                features = std::move(compute(audioDataWindow));
-            }
-            auto size = features.size() / resizeScale;
-            auto sizeBytes = sizeof(T);
-
-            /* Input should be transposed and "resized" by skipping elements. */
-            for (size_t outIndex = 0; outIndex < size; outIndex++) {
-                std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
-            }
-
-            /* Start renewing cache as soon iteration goes out of the windows overlap. */
-            if (index >= featuresOverlapIndex / resizeScale) {
-                featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
-            }
-        };
-    }
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
-    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
-                        size_t cacheSize,
-                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
-    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);
-
-    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
-    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
-                         size_t cacheSize,
-                         std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);
-
-    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
-    FeatureCalc<float>(TfLiteTensor *inputTensor,
-                       size_t cacheSize,
-                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
-    static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
-    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean)
-    {
-        std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
-
-        TfLiteQuantization quant = inputTensor->quantization;
-
-        if (kTfLiteAffineQuantization == quant.type) {
-
-            auto *quantParams = (TfLiteAffineQuantization *) quant.params;
-            const float quantScale = quantParams->scale->data[0];
-            const int quantOffset = quantParams->zero_point->data[0];
-
-            switch (inputTensor->type) {
-                case kTfLiteInt8: {
-                    melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor,
-                                                             cacheSize,
-                                                             [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
-                                                                 return melSpec.MelSpecComputeQuant<int8_t>(
-                                                                         audioDataWindow,
-                                                                         quantScale,
-                                                                         quantOffset,
-                                                                         trainingMean);
-                                                             }
-                    );
-                    break;
-                }
-                case kTfLiteUInt8: {
-                    melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
-                                                              cacheSize,
-                                                              [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
-                                                                  return melSpec.MelSpecComputeQuant<uint8_t>(
-                                                                          audioDataWindow,
-                                                                          quantScale,
-                                                                          quantOffset,
-                                                                          trainingMean);
-                                                              }
-                    );
-                    break;
-                }
-                case kTfLiteInt16: {
-                    melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor,
-                                                              cacheSize,
-                                                              [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
-                                                                  return melSpec.MelSpecComputeQuant<int16_t>(
-                                                                          audioDataWindow,
-                                                                          quantScale,
-                                                                          quantOffset,
-                                                                          trainingMean);
-                                                              }
-                    );
-                    break;
-                }
-                default:
-                printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
-            }
-
-
-        } else {
-            melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor,
-                                                                         cacheSize,
-                                                                         [=, &melSpec](
-                                                                                 std::vector<int16_t>& audioDataWindow) {
-                                                                             return melSpec.ComputeMelSpec(
-                                                                                     audioDataWindow,
-                                                                                     trainingMean);
-                                                                         });
+        for (size_t i = 0; i < machineIdxInString; ++i) {
+            delimiterStart = wavFileName.find(delimiter);
+            subString = wavFileName.substr(0, delimiterStart);
+            wavFileName.erase(0, delimiterStart + delimiter.length());
         }
-        return melSpecFeatureCalc;
+
+        /* At this point substring should be 00.wav */
+        delimiter = ".";  /* Second character used to split the file name up. */
+        delimiterStart = subString.find(delimiter);
+        subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
+
+        auto is_number = [](const std::string& str) ->  bool
+        {
+            std::string::const_iterator it = str.begin();
+            while (it != str.end() && std::isdigit(*it)) ++it;
+            return !str.empty() && it == str.end();
+        };
+
+        const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
+
+        /* Return corresponding index in the output vector. */
+        if (machineIdx == 0) {
+            return 0;
+        } else if (machineIdx == 2) {
+            return 1;
+        } else if (machineIdx == 4) {
+            return 2;
+        } else if (machineIdx == 6) {
+            return 3;
+        } else {
+            printf_err("%d is an invalid machine index \n", machineIdx);
+            return -1;
+        }
     }
 
 } /* namespace app */