MLECO-3183: Refactoring application sources

Platform agnostic application sources are moved into application
api module with their own independent CMake projects.

Changes for MLECO-3080 are also included - they create CMake projects
for the individual APIs (again, platform agnostic) that depend on the
common logic. The KWS_API "joint" API has been removed and the use
case now relies on the individual KWS and ASR API libraries.

Change-Id: I1f7748dc767abb3904634a04e0991b74ac7b756d
Signed-off-by: Kshitij Sisodia <kshitij.sisodia@arm.com>
diff --git a/source/application/api/use_case/ad/include/AdMelSpectrogram.hpp b/source/application/api/use_case/ad/include/AdMelSpectrogram.hpp
new file mode 100644
index 0000000..05c5bfc
--- /dev/null
+++ b/source/application/api/use_case/ad/include/AdMelSpectrogram.hpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ADMELSPECTROGRAM_HPP
+#define ADMELSPECTROGRAM_HPP
+
+#include "MelSpectrogram.hpp"
+
+namespace arm {
+namespace app {
+namespace audio {
+
+    /* Mel Spectrogram specialisation carrying the anomaly detection use case defaults */
+    class AdMelSpectrogram : public MelSpectrogram {
+
+    public:
+        static constexpr uint32_t  ms_defaultSamplingFreq = 16000;
+        static constexpr uint32_t  ms_defaultNumFbankBins =    64;
+        static constexpr uint32_t  ms_defaultMelLoFreq    =     0;
+        static constexpr uint32_t  ms_defaultMelHiFreq    =  8000;
+        static constexpr bool      ms_defaultUseHtkMethod = false;
+
+        /* Only the frame length is caller supplied; the rest come from the ms_default* constants. */
+        explicit AdMelSpectrogram(const size_t frameLen) :
+            MelSpectrogram(MelSpecParams(ms_defaultSamplingFreq,
+                                         ms_defaultNumFbankBins,
+                                         ms_defaultMelLoFreq, ms_defaultMelHiFreq,
+                                         frameLen, ms_defaultUseHtkMethod)) {}
+
+        ~AdMelSpectrogram() = default;
+        AdMelSpectrogram()  = delete;
+
+    protected:
+
+        /**
+         * @brief       Overrides the base class routine of the same name.
+         * @param[in]   fftVec                  FFT magnitudes.
+         * @param[in]   melFilterBank           Filter bank weights (2D vector).
+         * @param[in]   filterBankFilterFirst   First filter bank index to be used for
+         *                                      each bin.
+         * @param[in]   filterBankFilterLast    Last filter bank index to be used for
+         *                                      each bin.
+         * @param[out]  melEnergies             Pre-allocated vector that receives the
+         *                                      MEL energies.
+         * @return      true if successful, false otherwise.
+         */
+        bool ApplyMelFilterBank(
+                std::vector<float>&              fftVec,
+                std::vector<std::vector<float>>& melFilterBank,
+                std::vector<uint32_t>&           filterBankFilterFirst,
+                std::vector<uint32_t>&           filterBankFilterLast,
+                std::vector<float>&              melEnergies) override;
+
+        /**
+         * @brief       Replaces the base class conversion of mel energies to
+         *              logarithmic scale. Unlike the default behaviour, the
+         *              power is converted to dB and subsequently clamped.
+         *              (The implementation is outside this header.)
+         * @param[in,out]   melEnergies     1D vector of Mel energies.
+         **/
+        void ConvertToLogarithmicScale(std::vector<float>& melEnergies) override;
+
+        /**
+         * @brief       Replaces the base class implementation. Given the low
+         *              and high Mel values, provides the normaliser for the
+         *              weights applied when the filter bank is populated.
+         * @param[in]   leftMel         Low Mel frequency value.
+         * @param[in]   rightMel        High Mel frequency value.
+         * @param[in]   useHTKMethod    Signals whether the HTK method is to be
+         *                              used for the calculation.
+         * @return      Float value to be applied when populating the filter
+         *              bank.
+         */
+        float GetMelFilterBankNormaliser(
+                const float&   leftMel,
+                const float&   rightMel,
+                const bool     useHTKMethod) override;
+    };
+
+} /* namespace audio */
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* ADMELSPECTROGRAM_HPP */
diff --git a/source/application/api/use_case/ad/include/AdModel.hpp b/source/application/api/use_case/ad/include/AdModel.hpp
new file mode 100644
index 0000000..0436a89
--- /dev/null
+++ b/source/application/api/use_case/ad/include/AdModel.hpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD_MODEL_HPP
+#define AD_MODEL_HPP
+
+#include "Model.hpp"
+/* Use-case constants: only declared here (extern); their definitions are outside this header. */
+extern const int g_FrameLength;
+extern const int g_FrameStride;
+extern const float g_ScoreThreshold;
+extern const float g_TrainingMean;
+
+namespace arm {
+namespace app {
+
+    class AdModel : public Model {
+        /* Model specialisation for the anomaly detection use case: supplies its own op resolver. */
+    public:
+        /* Indices into the input tensor shape expected for this model (rows, then columns). */
+        static constexpr uint32_t ms_inputRowsIdx = 1;
+        static constexpr uint32_t ms_inputColsIdx = 2;
+
+    protected:
+        /** @brief   Gets the reference to the op resolver interface class (base class override). */
+        const tflite::MicroOpResolver& GetOpResolver() override;
+
+        /** @brief   Adds operations to the op resolver instance (base class override). */
+        bool EnlistOperations() override;
+
+    private:
+        /* Maximum number of individual operations that can be enlisted; sizes m_opResolver below. */
+        static constexpr int ms_maxOpCnt = 6;
+
+        /* A mutable op resolver instance with room for ms_maxOpCnt operations. */
+        tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver;
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* AD_MODEL_HPP */
diff --git a/source/application/api/use_case/ad/include/AdProcessing.hpp b/source/application/api/use_case/ad/include/AdProcessing.hpp
new file mode 100644
index 0000000..abee75e
--- /dev/null
+++ b/source/application/api/use_case/ad/include/AdProcessing.hpp
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD_PROCESSING_HPP
+#define AD_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "TensorFlowLiteMicro.hpp"
+#include "AudioUtils.hpp"
+#include "AdMelSpectrogram.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for the anomaly detection use case.
+     *          Implements the methods declared by BasePreProcess and holds the MEL
+     *          spectrogram state needed to populate input tensors ready for inference.
+     */
+    class AdPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief     Constructor for AdPreProcess class objects.
+         * @param[in] inputTensor               Input tensor pointer from the tensor arena.
+         * @param[in] melSpectrogramFrameLen    MEL spectrogram's frame length.
+         * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride.
+         * @param[in] adModelTrainingMean       Training mean for the anomaly detection model being used.
+         */
+        explicit AdPreProcess(TfLiteTensor* inputTensor,
+                              uint32_t melSpectrogramFrameLen,
+                              uint32_t melSpectrogramFrameStride,
+                              float adModelTrainingMean);
+
+        ~AdPreProcess() = default;
+
+        /**
+         * @brief     Function to invoke pre-processing and populate the input vector.
+         * @param[in] input     Pointer to input data. For anomaly detection, this is the pointer
+         *                      to the audio data.
+         * @param[in] inputSize Size of the data being passed in for pre-processing.
+         * @return    True if successful, false otherwise.
+         */
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+        /**
+         * @brief   Getter function for the audio window size computed when constructing
+         *          the class object.
+         * @return  Audio window size as a 32 bit unsigned integer.
+         */
+        uint32_t GetAudioWindowSize();
+
+        /**
+         * @brief   Getter function for the audio window stride computed when constructing
+         *          the class object.
+         * @return  Audio window stride as a 32 bit unsigned integer.
+         */
+        uint32_t GetAudioDataStride();
+
+        /**
+         * @brief   Sets the current audio window index to `idx`. This is only used for evaluating
+         *          if previously computed features can be re-used from cache.
+         */
+        void SetAudioWindowIndex(uint32_t idx);
+
+    private:
+        bool        m_validInstance{false}; /**< Indicates the current object is valid. */
+        uint32_t    m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */
+        uint32_t    m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */
+        uint8_t     m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */
+        uint32_t    m_numMelSpecVectorsInAudioStride{};  /**< Number of frames to move across the audio. */
+        uint32_t    m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */
+        uint32_t    m_audioDataStride{}; /**< Audio window stride computed. */
+        uint32_t    m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */
+        uint32_t    m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */
+
+        audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */
+        audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */
+        std::function<void
+            (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator callable (same signature as GetFeatureCalculator's return type) */
+    };
+
+    class AdPostProcess : public BasePostProcess {
+    public:
+        /**
+         * @brief     Constructs an AdPostProcess object.
+         * @param[in] outputTensor  Pointer to the output tensor.
+         */
+        explicit AdPostProcess(TfLiteTensor* outputTensor);
+
+        ~AdPostProcess() = default;
+
+        /**
+         * @brief   Runs the post-processing on the output tensor.
+         * @return  True if successful, false otherwise.
+         */
+        bool DoPostProcess() override;
+
+        /**
+         * @brief     Retrieves one element of the de-quantised output vector.
+         * @param[in] index  Position of the element to be retrieved.
+         * @return    The element as a 32 bit floating point number.
+         */
+        float GetOutputValue(uint32_t index);
+
+    private:
+        TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */
+        std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */
+
+        /**
+         * @brief   Flattens the output tensor into m_dequantizedOutputVec, de-quantizing each element.
+         * @tparam  T   Data type of the elements held by the output tensor.
+         * @return  True if successful, false if the output tensor pointer is null.
+         */
+        template<typename T>
+        bool Dequantize()
+        {
+            TfLiteTensor* const outTensor = m_outputTensor;
+            if (nullptr == outTensor) {
+                printf_err("Invalid output tensor.\n");
+                return false;
+            }
+
+            /* Element count is the product of all the tensor dimensions. */
+            uint32_t numElements = 1;
+            for (int dimIdx = 0; dimIdx < outTensor->dims->size; ++dimIdx) {
+                numElements *= outTensor->dims->data[dimIdx];
+            }
+
+            /* Quantization parameters are required to recover floating point values. */
+            const QuantParams params = GetTensorQuantParams(outTensor);
+            T* const quantised = tflite::GetTensorData<T>(outTensor);
+
+            m_dequantizedOutputVec = std::vector<float>(numElements, 0);
+            for (size_t elem = 0; elem < numElements; ++elem) {
+                m_dequantizedOutputVec[elem] = params.scale * (quantised[elem] - params.offset);
+            }
+
+            return true;
+        }
+    };
+
+    /* Explicit instantiation for int8_t. NOTE(review): this is an instantiation definition in a header, so it is repeated in every including translation unit - confirm intended. */
+    template bool AdPostProcess::Dequantize<int8_t>();
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns lambda function to compute features using a features cache that the
+     * returned callable owns (copies of the callable do not share cached features).
+     * Real features math is done by `compute`; features are written to input tensor memory.
+     *
+     * @tparam T            feature vector type.
+     * @param inputTensor   model input tensor pointer.
+     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param compute       features calculator function.
+     * @return              lambda function to compute features; its `resizeScale` argument must be non-zero.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    {
+        /* Per-closure feature cache: copied into the (mutable) lambda below by its by-value capture. */
+        std::vector<std::vector<T>> featureCache(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex,
+                   size_t resizeScale) mutable
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from cache if cache is ready and sliding windows overlap.
+             * Overlap is in the beginning of sliding window with a size of a feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = compute(audioDataWindow);
+            }
+            const size_t size = features.size() / resizeScale;
+            const size_t cacheStart = featuresOverlapIndex / resizeScale;
+
+            /* Input should be transposed and "resized" by skipping elements. */
+            for (size_t outIndex = 0; outIndex < size; outIndex++) {
+                std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeof(T));
+            }
+
+            /* Renew the cache once past the windows overlap; the bounds check keeps the write inside the cache. */
+            if (index >= cacheStart && (index - cacheStart) < featureCache.size()) {
+                featureCache[index - cacheStart] = std::move(features);
+            }
+        };
+    }
+    /* Explicit instantiations of FeatureCalc for int8_t and float feature vectors. */
+    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
+    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+                        size_t cacheSize,
+                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc<float>(TfLiteTensor *inputTensor,
+                       size_t cacheSize,
+                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+    /* Declared here, defined outside this header. The returned callable takes an `int` second argument, unlike FeatureCalc's `size_t`. */
+    std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+                         TfLiteTensor* inputTensor,
+                         size_t cacheSize,
+                         float trainingMean);
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* AD_PROCESSING_HPP */
diff --git a/source/application/api/use_case/ad/include/MelSpectrogram.hpp b/source/application/api/use_case/ad/include/MelSpectrogram.hpp
new file mode 100644
index 0000000..d3ea3f7
--- /dev/null
+++ b/source/application/api/use_case/ad/include/MelSpectrogram.hpp
@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MELSPECTROGRAM_HPP
+#define MELSPECTROGRAM_HPP
+
+#include "PlatformMath.hpp"
+
+#include <vector>
+#include <cstdint>
+#include <cmath>
+#include <limits>
+#include <string>
+
+namespace arm {
+namespace app {
+namespace audio {
+
+    /* Mel Spectrogram consolidated parameters; all members are public and set through the constructor. */
+    class MelSpecParams {
+    public:
+        float       m_samplingFreq;     /**< Sampling frequency */
+        uint32_t    m_numFbankBins;     /**< Number of filter bank bins */
+        float       m_melLoFreq;        /**< Mel low frequency */
+        float       m_melHiFreq;        /**< Mel high frequency */
+        uint32_t    m_frameLen;         /**< Frame length */
+        uint32_t    m_frameLenPadded;   /**< Padded frame length; presumably derived from m_frameLen - TODO confirm */
+        bool        m_useHtkMethod;     /**< Signals if the HTK method is to be used for calculation */
+
+        /** @brief  Constructor; note that there is no argument for m_frameLenPadded. */
+        MelSpecParams(const float samplingFreq, const uint32_t numFbankBins,
+                      const float melLoFreq, const float melHiFreq,
+                      const uint32_t frameLen, const bool useHtkMethod);
+
+        MelSpecParams()  = delete;
+        ~MelSpecParams() = default;
+
+        /** @brief  String representation of the parameters */
+        std::string Str() const;
+    };
+
+    /**
+     * @brief   Class for Mel Spectrogram feature extraction.
+     *          Based on https://github.com/ARM-software/ML-KWS-for-MCU/blob/master/Deployment/Source/MFCC/mfcc.cpp
+     *          This class is designed to be generic and self-sufficient but
+     *          certain calculation routines (the protected virtual functions)
+     *          can be overridden to accommodate use-case specific requirements.
+     */
+    class MelSpectrogram {
+
+    public:
+        /**
+        * @brief        Extract Mel Spectrogram for one single small frame of
+        *               audio data e.g. 640 samples.
+        * @param[in]    audioData       Vector of audio samples to calculate
+        *               features for.
+        * @param[in]    trainingMean    Value to subtract from the computed mel spectrogram, default 0.
+        * @return       Vector of extracted Mel Spectrogram features.
+        **/
+        std::vector<float> ComputeMelSpec(const std::vector<int16_t>& audioData, float trainingMean = 0);
+
+        /**
+         * @brief       Constructor
+         * @param[in]   params   Mel Spectrogram parameters
+        */
+        explicit MelSpectrogram(const MelSpecParams& params);
+
+        MelSpectrogram() = delete;
+        virtual ~MelSpectrogram() = default; /* Virtual: this class is a polymorphic base. */
+
+        /** @brief  Initialise */
+        void Init();
+
+        /**
+         * @brief        Extract Mel Spectrogram features and quantise for one single small
+         *               frame of audio data e.g. 640 samples.
+         * @param[in]    audioData      Vector of audio samples to calculate
+         *               features for.
+         * @param[in]    quantScale     quantisation scale.
+         * @param[in]    quantOffset    quantisation offset.
+         * @param[in]    trainingMean   training mean.
+         * @return       Vector of quantised Mel Spectrogram features, clamped to the range of T.
+         **/
+        template<typename T>
+        std::vector<T> MelSpecComputeQuant(const std::vector<int16_t>& audioData,
+                                           const float quantScale,
+                                           const int quantOffset,
+                                           float trainingMean = 0)
+        {
+            this->ComputeMelSpec(audioData, trainingMean);
+            const float minVal = std::numeric_limits<T>::lowest(); /* min() is positive for floating point T */
+            const float maxVal = std::numeric_limits<T>::max();
+
+            std::vector<T> melSpecOut(this->m_params.m_numFbankBins);
+            const size_t numFbankBins = this->m_params.m_numFbankBins;
+
+            /* Quantize to T: scale, offset, round and then clamp to [minVal, maxVal]. */
+            for (size_t k = 0; k < numFbankBins; ++k) {
+                auto quantizedEnergy = std::round(((this->m_melEnergies[k]) / quantScale) + quantOffset);
+                melSpecOut[k] = static_cast<T>(std::min<float>(std::max<float>(quantizedEnergy, minVal), maxVal));
+            }
+
+            return melSpecOut;
+        }
+
+        /* Constants */
+        static constexpr float ms_logStep = /*logf(6.4)*/ 1.8562979903656 / 27.0;
+        static constexpr float ms_freqStep = 200.0 / 3;
+        static constexpr float ms_minLogHz = 1000.0;
+        static constexpr float ms_minLogMel = ms_minLogHz / ms_freqStep;
+
+    protected:
+        /**
+         * @brief       Project input frequency to Mel Scale.
+         * @param[in]   freq          input frequency in floating point
+         * @param[in]   useHTKMethod  bool to signal if HTK method is to be
+         *                            used for calculation
+         * @return      Mel transformed frequency in floating point
+         **/
+        static float MelScale(const float    freq,
+                              const bool     useHTKMethod = true);
+
+        /**
+         * @brief       Inverse Mel transform - convert MEL warped frequency
+         *              back to normal frequency
+         * @param[in]   melFreq       Mel frequency in floating point
+         * @param[in]   useHTKMethod  bool to signal if HTK method is to be
+         *                            used for calculation
+         * @return      Real world frequency in floating point
+         **/
+        static float InverseMelScale(const float melFreq,
+                                     const bool  useHTKMethod = true);
+
+        /**
+         * @brief       Populates MEL energies after applying the MEL filter
+         *              bank weights and adding them up to be placed into
+         *              bins, according to the filter bank's first and last
+         *              indices (pre-computed for each filter bank element
+         *              by CreateMelFilterBank function).
+         * @param[in]   fftVec                  Vector populated with FFT magnitudes
+         * @param[in]   melFilterBank           2D Vector with filter bank weights
+         * @param[in]   filterBankFilterFirst   Vector containing the first indices of filter bank
+         *                                      to be used for each bin.
+         * @param[in]   filterBankFilterLast    Vector containing the last indices of filter bank
+         *                                      to be used for each bin.
+         * @param[out]  melEnergies             Pre-allocated vector of MEL energies to be
+         *                                      populated.
+         * @return      true if successful, false otherwise
+         */
+        virtual bool ApplyMelFilterBank(
+                std::vector<float>&                 fftVec,
+                std::vector<std::vector<float>>&    melFilterBank,
+                std::vector<uint32_t>&               filterBankFilterFirst,
+                std::vector<uint32_t>&               filterBankFilterLast,
+                std::vector<float>&                 melEnergies);
+
+        /**
+         * @brief           Converts the Mel energies for logarithmic scale
+         * @param[in,out]   melEnergies 1D vector of Mel energies
+         **/
+        virtual void ConvertToLogarithmicScale(std::vector<float>& melEnergies);
+
+        /**
+         * @brief       Given the low and high Mel values, get the normaliser
+         *              for weights to be applied when populating the filter
+         *              bank.
+         * @param[in]   leftMel      low Mel frequency value
+         * @param[in]   rightMel     high Mel frequency value
+         * @param[in]   useHTKMethod bool to signal if HTK method is to be
+         *                           used for calculation
+         * @return      Return float value to be applied
+         *              when populating the filter bank.
+         */
+        virtual float GetMelFilterBankNormaliser(
+                const float&   leftMel,
+                const float&   rightMel,
+                const bool     useHTKMethod);
+
+    private:
+        MelSpecParams                   m_params;
+        std::vector<float>              m_frame;
+        std::vector<float>              m_buffer;
+        std::vector<float>              m_melEnergies;
+        std::vector<float>              m_windowFunc;
+        std::vector<std::vector<float>> m_melFilterBank;
+        std::vector<uint32_t>            m_filterBankFilterFirst;
+        std::vector<uint32_t>            m_filterBankFilterLast;
+        bool                            m_filterBankInitialised;
+        arm::app::math::FftInstance     m_fftInstance;
+
+        /**
+         * @brief       Initialises the filter banks.
+         **/
+        void InitMelFilterBank();
+
+        /**
+         * @brief       Signals whether the instance of MelSpectrogram has had its
+         *              required buffers initialised
+         * @return      True if initialised, false otherwise
+         **/
+        bool IsMelFilterBankInited() const;
+
+        /**
+         * @brief       Create mel filter banks for Mel Spectrogram calculation.
+         * @return      2D vector of floats
+         **/
+        std::vector<std::vector<float>> CreateMelFilterBank();
+
+        /**
+         * @brief       Computes the magnitude from an interleaved complex array
+         **/
+        void ConvertToPowerSpectrum();
+
+    };
+
+} /* namespace audio */
+} /* namespace app */
+} /* namespace arm */
+
+
+#endif /* MELSPECTROGRAM_HPP */