MLECO-3183: Refactoring application sources

Platform agnostic application sources are moved into application
api module with their own independent CMake projects.

Changes for MLECO-3080 are also included - they create CMake projects
for individual APIs (again, platform agnostic) that depend on the
common logic. The KWS_API "joint" API has been removed and
now the use case relies on the individual KWS and ASR API libraries.

Change-Id: I1f7748dc767abb3904634a04e0991b74ac7b756d
Signed-off-by: Kshitij Sisodia <kshitij.sisodia@arm.com>
diff --git a/source/application/api/use_case/noise_reduction/CMakeLists.txt b/source/application/api/use_case/noise_reduction/CMakeLists.txt
new file mode 100644
index 0000000..5fa9a73
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/CMakeLists.txt
@@ -0,0 +1,40 @@
+#----------------------------------------------------------------------------
+#  Copyright (c) 2022 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#----------------------------------------------------------------------------
+#########################################################
+#            NOISE REDUCTION API library                #
+#########################################################
+cmake_minimum_required(VERSION 3.15.6)
+
+set(NOISE_REDUCTION_API_TARGET noise_reduction_api)
+project(${NOISE_REDUCTION_API_TARGET}
+        DESCRIPTION     "Noise reduction use case API library"
+        LANGUAGES       C CXX)
+
+# Create static library
+add_library(${NOISE_REDUCTION_API_TARGET} STATIC
+        src/RNNoiseProcessing.cc
+        src/RNNoiseFeatureProcessor.cc
+        src/RNNoiseModel.cc)
+
+target_include_directories(${NOISE_REDUCTION_API_TARGET} PUBLIC include)
+
+target_link_libraries(${NOISE_REDUCTION_API_TARGET} PUBLIC common_api)
+
+message(STATUS "*******************************************************")
+message(STATUS "Library                                : " ${NOISE_REDUCTION_API_TARGET})
+message(STATUS "CMAKE_SYSTEM_PROCESSOR                 : " ${CMAKE_SYSTEM_PROCESSOR})
+message(STATUS "*******************************************************")
diff --git a/source/application/api/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp b/source/application/api/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
new file mode 100644
index 0000000..cbf0e4e
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/include/RNNoiseFeatureProcessor.hpp
@@ -0,0 +1,341 @@
+/*
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef RNNOISE_FEATURE_PROCESSOR_HPP
+#define RNNOISE_FEATURE_PROCESSOR_HPP
+
+#include "PlatformMath.hpp"
+#include <cstdint>
+#include <vector>
+#include <array>
+#include <tuple>
+
+namespace arm {
+namespace app {
+namespace rnn {
+
+    using vec1D32F = std::vector<float>;
+    using vec2D32F = std::vector<vec1D32F>;
+    using arrHp = std::array<float, 2>;
+    using math::FftInstance;
+    using math::FftType;
+
+    class FrameFeatures {
+    public:
+        bool m_silence{false};        /* If frame contains silence or not. */
+        vec1D32F m_featuresVec{};     /* Calculated feature vector to feed to model. */
+        vec1D32F m_fftX{};            /* Vector of floats arranged to represent complex numbers. */
+        vec1D32F m_fftP{};            /* Vector of floats arranged to represent complex numbers. */
+        vec1D32F m_Ex{};              /* Spectral band energy for audio x. */
+        vec1D32F m_Ep{};              /* Spectral band energy for pitch p. */
+        vec1D32F m_Exp{};             /* Correlated spectral energy between x and p. */
+    };
+
+    /**
+     * @brief   RNNoise pre and post processing class based on the 2018 paper from
+     *          Jean-Marc Valin. Recommended reading:
+     *          - https://jmvalin.ca/demo/rnnoise/
+     *          - https://arxiv.org/abs/1709.08243
+     **/
+    class RNNoiseFeatureProcessor {
+    /* Public interface */
+    public:
+        RNNoiseFeatureProcessor();
+        ~RNNoiseFeatureProcessor() = default;
+
+        /**
+         * @brief        Calculates the features from a given audio buffer ready to be sent to RNNoise model.
+         * @param[in]    audioData   Pointer to the floating point vector
+         *                           with audio data (within the numerical
+         *                           limits of int16_t type).
+         * @param[in]    audioLen    Number of elements in the audio window.
+         * @param[out]   features    FrameFeatures object reference.
+         **/
+        void PreprocessFrame(const float*   audioData,
+                             size_t   audioLen,
+                             FrameFeatures& features);
+
+        /**
+         * @brief        Use the RNNoise model output gain values with pre-processing features
+         *               to generate audio with noise suppressed.
+         * @param[in]    modelOutput   Output gain values from model.
+         * @param[in]    features      Calculated features from pre-processing step.
+         * @param[out]   outFrame      Output frame to be populated.
+         **/
+        void PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features,  vec1D32F& outFrame);
+
+
+    /* Public constants */
+    public:
+        static constexpr uint32_t FRAME_SIZE_SHIFT{2};
+        static constexpr uint32_t FRAME_SIZE{512};
+        static constexpr uint32_t WINDOW_SIZE{2 * FRAME_SIZE};
+        static constexpr uint32_t FREQ_SIZE{FRAME_SIZE + 1};
+
+        static constexpr uint32_t PITCH_MIN_PERIOD{64};
+        static constexpr uint32_t PITCH_MAX_PERIOD{820};
+        static constexpr uint32_t PITCH_FRAME_SIZE{1024};
+        static constexpr uint32_t PITCH_BUF_SIZE{PITCH_MAX_PERIOD + PITCH_FRAME_SIZE};
+
+        static constexpr uint32_t NB_BANDS{22};
+        static constexpr uint32_t CEPS_MEM{8};
+        static constexpr uint32_t NB_DELTA_CEPS{6};
+
+        static constexpr uint32_t NB_FEATURES{NB_BANDS + 3*NB_DELTA_CEPS + 2};
+
+    /* Private functions */
+    private:
+
+        /**
+         * @brief   Initialises the half window and DCT tables.
+         */
+        void InitTables();
+
+        /**
+         * @brief           Applies a bi-quadratic filter over the audio window.
+         * @param[in]       bHp           Constant coefficient set b (arrHp type).
+         * @param[in]       aHp           Constant coefficient set a (arrHp type).
+         * @param[in,out]   memHpX        Coefficients populated by this function.
+         * @param[in,out]   audioWindow   Floating point vector with audio data.
+         **/
+        void BiQuad(
+            const arrHp& bHp,
+            const arrHp& aHp,
+            arrHp& memHpX,
+            vec1D32F& audioWindow);
+
+        /**
+         * @brief        Computes features from the "filtered" audio window.
+         * @param[in]    audioWindow   Floating point vector with audio data.
+         * @param[out]   features      FrameFeatures object reference.
+         **/
+        void ComputeFrameFeatures(vec1D32F& audioWindow, FrameFeatures& features);
+
+        /**
+         * @brief        Runs analysis on the audio buffer.
+         * @param[in]    audioWindow   Floating point vector with audio data.
+         * @param[out]   fft           Floating point FFT vector containing real and
+         *                             imaginary pairs of elements. NOTE: this vector
+         *                             does not contain the mirror image (conjugates)
+         *                             part of the spectrum.
+         * @param[out]   energy        Computed energy for each band in the Bark scale.
+         * @param[out]   analysisMem   Buffer sequentially, but partially,
+         *                             populated with new audio data.
+         **/
+        void FrameAnalysis(
+            const vec1D32F& audioWindow,
+            vec1D32F& fft,
+            vec1D32F& energy,
+            vec1D32F& analysisMem);
+
+        /**
+         * @brief               Applies the window function, in-place, over the given
+         *                      floating point buffer.
+         * @param[in,out]   x   Buffer the window will be applied to.
+         **/
+        void ApplyWindow(vec1D32F& x);
+
+        /**
+         * @brief        Computes the FFT for a given vector.
+         * @param[in]    x     Vector to compute the FFT from.
+         * @param[out]   fft   Floating point FFT vector containing real and
+         *                     imaginary pairs of elements. NOTE: this vector
+         *                     does not contain the mirror image (conjugates)
+         *                     part of the spectrum.
+         **/
+        void ForwardTransform(
+            vec1D32F& x,
+            vec1D32F& fft);
+
+        /**
+         * @brief        Computes band energy for each of the 22 Bark scale bands.
+         * @param[in]    fft_X   FFT spectrum (as computed by ForwardTransform).
+         * @param[out]   bandE   Vector with 22 elements populated with energy for
+         *                       each band.
+         **/
+        void ComputeBandEnergy(const vec1D32F& fft_X, vec1D32F& bandE);
+
+        /**
+         * @brief        Computes band energy correlation.
+         * @param[in]    X       FFT vector X.
+         * @param[in]    P       FFT vector P.
+         * @param[out]   bandC   Vector with 22 elements populated with band energy
+         *                       correlation for the two input FFT vectors.
+         **/
+        void ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC);
+
+        /**
+         * @brief        Performs pitch auto-correlation for a given vector for
+         *               given lag.
+         * @param[in]    x     Input vector.
+         * @param[out]   ac    Auto-correlation output vector.
+         * @param[in]    lag   Lag value.
+         * @param[in]    n     Number of elements to consider for correlation
+         *                     computation.
+         **/
+        void AutoCorr(const vec1D32F &x,
+                     vec1D32F &ac,
+                     size_t lag,
+                     size_t n);
+
+        /**
+         * @brief       Computes pitch cross-correlation.
+         * @param[in]   x          Input vector 1.
+         * @param[in]   y          Input vector 2.
+         * @param[out]  xCorr      Cross-correlation output vector.
+         * @param[in]   len        Number of elements to consider for correlation
+         *                         computation.
+         * @param[in]   maxPitch   Maximum pitch.
+         **/
+        void PitchXCorr(
+            const vec1D32F& x,
+            const vec1D32F& y,
+            vec1D32F& xCorr,
+            size_t len,
+            size_t maxPitch);
+
+        /**
+         * @brief        Computes "Linear Predictor Coefficients".
+         * @param[in]    ac    Correlation vector.
+         * @param[in]    p     Number of elements of input vector to consider.
+         * @param[out]   lpc   Output coefficients vector.
+         **/
+        void LPC(const vec1D32F& ac, int32_t p, vec1D32F& lpc);
+
+        /**
+         * @brief        Custom FIR implementation.
+         * @param[in]    num   FIR coefficient vector.
+         * @param[in]    N     Number of elements.
+         * @param[out]   x     Vector to be processed.
+         **/
+        void Fir5(const vec1D32F& num, uint32_t N, vec1D32F& x);
+
+        /**
+         * @brief           Down-sample the pitch buffer.
+         * @param[in,out]   pitchBuf     Pitch buffer.
+         * @param[in]       pitchBufSz   Buffer size.
+         **/
+        void PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz);
+
+        /**
+         * @brief       Pitch search function.
+         * @param[in]   xLp        Shifted pitch buffer input.
+         * @param[in]   y          Pitch buffer input.
+         * @param[in]   len        Length to search for.
+         * @param[in]   maxPitch   Maximum pitch.
+         * @return      pitch index.
+         **/
+        int PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch);
+
+        /**
+         * @brief       Finds the "best" pitch from the buffer.
+         * @param[in]   xCorr      Pitch correlation vector.
+         * @param[in]   y          Pitch buffer input.
+         * @param[in]   len        Length to search for.
+         * @param[in]   maxPitch   Maximum pitch.
+         * @return      pitch array (2 elements).
+         **/
+        arrHp FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch);
+
+        /**
+         * @brief           Remove pitch period doubling errors.
+         * @param[in,out]   pitchBuf     Pitch buffer vector.
+         * @param[in]       maxPeriod    Maximum period.
+         * @param[in]       minPeriod    Minimum period.
+         * @param[in]       frameSize    Frame size.
+         * @param[in]       pitchIdx0_   Pitch index 0.
+         * @return          pitch index.
+         **/
+        int RemoveDoubling(
+                vec1D32F& pitchBuf,
+                uint32_t maxPeriod,
+                uint32_t minPeriod,
+                uint32_t frameSize,
+                size_t pitchIdx0_);
+
+        /**
+         * @brief       Computes pitch gain.
+         * @param[in]   xy   Single xy cross correlation value.
+         * @param[in]   xx   Single xx auto correlation value.
+         * @param[in]   yy   Single yy auto correlation value.
+         * @return      Calculated pitch gain.
+         **/
+        float ComputePitchGain(float xy, float xx, float yy);
+
+        /**
+         * @brief        Computes DCT vector from the given input.
+         * @param[in]    input    Input vector.
+         * @param[out]   output   Output vector with DCT coefficients.
+         **/
+        void DCT(vec1D32F& input, vec1D32F& output);
+
+        /**
+         * @brief        Perform inverse fourier transform on complex spectral vector.
+         * @param[out]   out      Output vector.
+         * @param[in]    fftXIn   Vector of floats arranged to represent complex numbers interleaved.
+         **/
+        void InverseTransform(vec1D32F& out, vec1D32F& fftXIn);
+
+        /**
+         * @brief       Perform pitch filtering.
+         * @param[in]   features   Object with pre-processing calculated frame features.
+         * @param[in]   g          Gain values.
+         **/
+        void PitchFilter(FrameFeatures& features, vec1D32F& g);
+
+        /**
+         * @brief        Interpolate the band gain values.
+         * @param[out]   g       Gain values.
+         * @param[in]    bandE   Vector with 22 elements populated with energy for
+         *                       each band.
+         **/
+        void InterpBandGain(vec1D32F& g, vec1D32F& bandE);
+
+        /**
+         * @brief        Create de-noised frame.
+         * @param[out]   outFrame   Output vector for storing the created audio frame.
+         * @param[in]    fftY       Gain adjusted complex spectral vector.
+         */
+        void FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY);
+
+    /* Private objects */
+    private:
+        FftInstance m_fftInstReal;  /* FFT instance for real numbers */
+        FftInstance m_fftInstCmplx; /* FFT instance for complex numbers */
+        vec1D32F m_halfWindow;      /* Window coefficients */
+        vec1D32F m_dctTable;        /* DCT table */
+        vec1D32F m_analysisMem;     /* Buffer used for frame analysis */
+        vec2D32F m_cepstralMem;     /* Cepstral coefficients */
+        size_t m_memId;             /* memory ID */
+        vec1D32F m_synthesisMem;    /* Synthesis mem (used by post-processing) */
+        vec1D32F m_pitchBuf;        /* Pitch buffer */
+        float m_lastGain;           /* Last gain calculated */
+        int m_lastPeriod;           /* Last period calculated */
+        arrHp m_memHpX;             /* HpX coefficients. */
+        vec1D32F m_lastGVec;        /* Last gain vector (used by post-processing) */
+
+        /* Constants */
+        const std::array <uint32_t, NB_BANDS> m_eband5ms {
+            0,  1,  2,  3,  4,  5,  6,  7,  8, 10,  12,
+            14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100};
+    };
+
+
+} /* namespace rnn */
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* RNNOISE_FEATURE_PROCESSOR_HPP */
diff --git a/source/application/api/use_case/noise_reduction/include/RNNoiseModel.hpp b/source/application/api/use_case/noise_reduction/include/RNNoiseModel.hpp
new file mode 100644
index 0000000..3d2f23c
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/include/RNNoiseModel.hpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef RNNOISE_MODEL_HPP
+#define RNNOISE_MODEL_HPP
+
+#include "Model.hpp"
+
+extern const uint32_t g_NumInputFeatures;
+extern const uint32_t g_FrameLength;
+extern const uint32_t g_FrameStride;
+
+namespace arm {
+namespace app {
+
+    class RNNoiseModel : public Model {
+    public:
+        /**
+         * @brief Runs inference for RNNoise model.
+         *
+         * Call CopyGruStates so GRU state outputs are copied to GRU state inputs before the inference run.
+         * Run ResetGruState() method to set states to zero before starting processing logically related data.
+         * @return True if inference succeeded, False - otherwise
+         */
+        bool RunInference() override;
+
+        /**
+         * @brief Sets GRU input states to zeros.
+         * Call this method before starting processing the new sequence of logically related data.
+         */
+        void ResetGruState();
+
+        /**
+         * @brief Copy current GRU output states to input states.
+         * Call this method before starting processing the next sequence of logically related data.
+         */
+        bool CopyGruStates();
+
+        /* Which index of model outputs does the main output (gains) come from. */
+        const size_t m_indexForModelOutput = 1;
+
+    protected:
+        /** @brief   Gets the reference to op resolver interface class. */
+        const tflite::MicroOpResolver& GetOpResolver() override;
+
+        /** @brief   Adds operations to the op resolver instance. */
+        bool EnlistOperations() override;
+
+        /*
+        Each inference after the first needs to copy 3 GRU states from an output index to an input index (model dependent):
+        0 -> 3, 2 -> 2, 3 -> 1
+        */
+        const std::vector<std::pair<size_t, size_t>> m_gruStateMap = {{0,3}, {2, 2}, {3, 1}};
+    private:
+        /* Maximum number of individual operations that can be enlisted. */
+        static constexpr int ms_maxOpCnt = 15;
+
+        /* A mutable op resolver instance. */
+        tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver;
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* RNNOISE_MODEL_HPP */
diff --git a/source/application/api/use_case/noise_reduction/include/RNNoiseProcessing.hpp b/source/application/api/use_case/noise_reduction/include/RNNoiseProcessing.hpp
new file mode 100644
index 0000000..15e62d9
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/include/RNNoiseProcessing.hpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef RNNOISE_PROCESSING_HPP
+#define RNNOISE_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "RNNoiseFeatureProcessor.hpp"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for Noise Reduction use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class RNNoisePreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief           Constructor
+         * @param[in]       inputTensor        Pointer to the TFLite Micro input Tensor.
+         * @param[in,out]   featureProcessor   RNNoise specific feature extractor object.
+         * @param[in,out]   frameFeatures      RNNoise specific features shared between pre & post-processing.
+         *
+         **/
+        explicit RNNoisePreProcess(TfLiteTensor* inputTensor,
+                                   std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+                                   std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+        /**
+         * @brief       Should perform pre-processing of 'raw' input audio data and load it into
+         *              TFLite Micro input tensors ready for inference
+         * @param[in]   input      Pointer to the data that pre-processing will work on.
+         * @param[in]   inputSize  Size of the input data.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+    private:
+        TfLiteTensor* m_inputTensor;                        /* Model input tensor. */
+        std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor;   /* RNNoise feature processor shared between pre & post-processing. */
+        std::shared_ptr<rnn::FrameFeatures> m_frameFeatures;                /* RNNoise features shared between pre & post-processing. */
+        rnn::vec1D32F m_audioFrame;                         /* Audio frame cast to FP32 */
+
+        /**
+         * @brief            Quantize the given features and populate the input Tensor.
+         * @param[in]        inputFeatures   Vector of floating point features to quantize.
+         * @param[in]        quantScale      Quantization scale for the inputTensor.
+         * @param[in]        quantOffset     Quantization offset for the inputTensor.
+         * @param[in,out]    inputTensor     TFLite micro tensor to populate.
+         **/
+        static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+                float quantScale, int quantOffset,
+                TfLiteTensor* inputTensor);
+    };
+
+    /**
+     * @brief   Post-processing class for Noise Reduction use case.
+     *          Implements methods declared by BasePostProcess and anything else needed
+     *          to populate result vector.
+     */
+    class RNNoisePostProcess : public BasePostProcess {
+
+    public:
+        /**
+         * @brief           Constructor
+         * @param[in]       outputTensor         Pointer to the TFLite Micro output Tensor.
+         * @param[out]      denoisedAudioFrame   Vector to store the final denoised audio frame.
+         * @param[in,out]   featureProcessor     RNNoise specific feature extractor object.
+         * @param[in,out]   frameFeatures        RNNoise specific features shared between pre & post-processing.
+         **/
+        RNNoisePostProcess(TfLiteTensor* outputTensor,
+                           std::vector<int16_t>& denoisedAudioFrame,
+                           std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+                           std::shared_ptr<rnn::FrameFeatures> frameFeatures);
+
+        /**
+         * @brief       Should perform post-processing of the result of inference then
+         *              populate result data for any later use.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+
+    private:
+        TfLiteTensor* m_outputTensor;                       /* Model output tensor. */
+        std::vector<int16_t>& m_denoisedAudioFrame;         /* Vector to store the final denoised frame. */
+        rnn::vec1D32F m_denoisedAudioFrameFloat;            /* Internal vector to store the final denoised frame (FP32). */
+        std::shared_ptr<rnn::RNNoiseFeatureProcessor> m_featureProcessor;   /* RNNoise feature processor shared between pre & post-processing. */
+        std::shared_ptr<rnn::FrameFeatures> m_frameFeatures;                /* RNNoise features shared between pre & post-processing. */
+        std::vector<float> m_modelOutputFloat;              /* Internal vector to store de-quantized model output. */
+
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* RNNOISE_PROCESSING_HPP */
\ No newline at end of file
diff --git a/source/application/api/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc b/source/application/api/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
new file mode 100644
index 0000000..036894c
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/src/RNNoiseFeatureProcessor.cc
@@ -0,0 +1,892 @@
+/*
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "RNNoiseFeatureProcessor.hpp"
+#include "log_macros.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+
+namespace arm {
+namespace app {
+namespace rnn {
+
+/* Runtime invariant check: when the condition `x` evaluates to false, the
+ * stringified expression is reported through printf_err and the program is
+ * terminated with exit code 1. The do/while(0) wrapper lets the macro be
+ * used like a single statement. */
+#define VERIFY(x)                                   \
+do {                                                \
+    if (!(x)) {                                     \
+        printf_err("Assert failed:" #x "\n");       \
+        exit(1);                                    \
+    }                                               \
+} while(0)
+
+/**
+ * @brief   Builds a feature processor whose history buffers start out zeroed,
+ *          then prepares the real and complex FFT instances and fills the
+ *          window and DCT lookup tables.
+ **/
+RNNoiseFeatureProcessor::RNNoiseFeatureProcessor() :
+        m_halfWindow(FRAME_SIZE, 0),
+        m_dctTable(NB_BANDS * NB_BANDS),
+        m_analysisMem(FRAME_SIZE, 0),
+        m_cepstralMem(CEPS_MEM, vec1D32F(NB_BANDS, 0)),
+        m_memId{0},
+        m_synthesisMem(FRAME_SIZE, 0),
+        m_pitchBuf(PITCH_BUF_SIZE, 0),
+        m_lastGain{0.0},
+        m_lastPeriod{0},
+        m_memHpX{},
+        m_lastGVec(NB_BANDS, 0)
+{
+    /* Both FFT instances are set up for a length of two frames. */
+    constexpr uint32_t fftLength = 2 * FRAME_SIZE;
+    static_assert(fftLength != 0, "Num FFT can't be 0");
+
+    math::MathUtils::FftInitF32(fftLength, m_fftInstReal, FftType::real);
+    math::MathUtils::FftInitF32(fftLength, m_fftInstCmplx, FftType::complex);
+
+    InitTables();
+}
+
+/**
+ * @brief       Filters a private copy of the incoming samples with BiQuad and
+ *              extracts the frame features from that filtered copy.
+ * @param[in]   audioData   Pointer to the first sample of the frame.
+ * @param[in]   audioLen    Number of samples available at audioData.
+ * @param[out]  features    Receives the features computed for this frame.
+ **/
+void RNNoiseFeatureProcessor::PreprocessFrame(const float*   audioData,
+                                              const size_t   audioLen,
+                                              FrameFeatures& features)
+{
+    /* Coefficient pairs handed to BiQuad. */
+    const arrHp aHp {-1.99599, 0.99600 };
+    const arrHp bHp {-2.00000, 1.00000 };
+
+    /* The filter overwrites its input, hence the local copy. */
+    vec1D32F filtered(audioData, audioData + audioLen);
+
+    BiQuad(bHp, aHp, m_memHpX, filtered);
+    ComputeFrameFeatures(filtered, features);
+}
+
+/**
+ * @brief           Applies the model's band gains to the frame spectrum held
+ *                  in `features` and synthesises the output audio frame.
+ *                  Frames flagged as silent skip the gain stage and are
+ *                  synthesised from the unmodified spectrum.
+ * @param[in]       modelOutput  Per-band gains produced by the model.
+ * @param[in,out]   features     Frame features; m_fftX is scaled in place.
+ * @param[out]      outFrame     Receives the synthesised samples.
+ **/
+void RNNoiseFeatureProcessor::PostProcessFrame(vec1D32F& modelOutput, FrameFeatures& features, vec1D32F& outFrame)
+{
+    /* Work on a copy so the caller's model output is left untouched. */
+    std::vector<float> bandGains(modelOutput);
+    std::vector<float> binGains(FREQ_SIZE, 0);
+
+    const bool frameHasAudio = !features.m_silence;
+    if (frameHasAudio) {
+        PitchFilter(features, bandGains);
+
+        /* A band gain may not fall below 0.6x of the gain used last frame. */
+        constexpr float alpha = .6f;
+        for (size_t band = 0; band < NB_BANDS; ++band) {
+            const float floorGain = alpha * m_lastGVec[band];
+            bandGains[band] = std::max(bandGains[band], floorGain);
+            m_lastGVec[band] = bandGains[band];
+        }
+
+        /* Spread the band gains over the individual bins and apply them. */
+        InterpBandGain(binGains, bandGains);
+        for (size_t bin = 0; bin < FREQ_SIZE; ++bin) {
+            const float scale = binGains[bin];
+            features.m_fftX[2 * bin] *= scale;      /* Real. */
+            features.m_fftX[2 * bin + 1] *= scale;  /* Imaginary. */
+        }
+    }
+
+    FrameSynthesis(outFrame, features.m_fftX);
+}
+
+/**
+ * @brief   Fills the two lookup tables used by the processor: the half
+ *          window (FRAME_SIZE entries) and the NB_BANDS x NB_BANDS cosine
+ *          table read by DCT().
+ **/
+void RNNoiseFeatureProcessor::InitTables()
+{
+    constexpr float pi = M_PI;
+    constexpr float halfPi = M_PI / 2;
+    constexpr float phaseStep = halfPi/FRAME_SIZE;
+
+    /* Window: sin(pi/2 * sin^2(phase)) sampled at half-sample offsets. */
+    for (uint32_t n = 0; n < FRAME_SIZE; ++n) {
+        const float s = math::MathUtils::SineF32(phaseStep * (n + 0.5f));
+        m_halfWindow[n] = math::MathUtils::SineF32(halfPi * s * s);
+    }
+
+    /* Cosine table, stored row by row; the first column is scaled by
+     * sqrt(0.5). */
+    for (uint32_t row = 0; row < NB_BANDS; ++row) {
+        const auto rowStart = row * NB_BANDS;
+        for (uint32_t col = 0; col < NB_BANDS; ++col) {
+            m_dctTable[rowStart + col] =
+                math::MathUtils::CosineF32((row + 0.5f) * col * pi / NB_BANDS);
+        }
+        m_dctTable[rowStart] *= math::MathUtils::SqrtF32(0.5f);
+    }
+}
+
+/**
+ * @brief           Runs a second-order recursive filter over the samples in
+ *                  place, carrying its two state values across calls.
+ * @param[in]       bHp          Coefficients applied to the input sample.
+ * @param[in]       aHp          Coefficients applied to the output sample.
+ * @param[in,out]   memHpX       Filter state, updated for every sample.
+ * @param[in,out]   audioWindow  Samples to filter; overwritten with the output.
+ **/
+void RNNoiseFeatureProcessor::BiQuad(
+        const arrHp& bHp,
+        const arrHp& aHp,
+        arrHp& memHpX,
+        vec1D32F& audioWindow)
+{
+    const size_t numSamples = audioWindow.size();
+    for (size_t n = 0; n < numSamples; ++n) {
+        const auto in = audioWindow[n];
+        const auto out = audioWindow[n] + memHpX[0];
+
+        memHpX[0] = memHpX[1] + (bHp[0] * in - aHp[0] * out);
+        memHpX[1] = (bHp[1] * in - aHp[1] * out);
+
+        audioWindow[n] = out;
+    }
+}
+
+/**
+ * @brief           Computes every per-frame quantity stored in `features`
+ *                  from one audio frame. In order: spectrum and band energy
+ *                  of the frame, update of the pitch history, pitch
+ *                  estimation, spectrum/energy/correlation of the
+ *                  pitch-lagged segment, the feature vector, the silence
+ *                  flag, and - for non-silent frames only - the cepstral
+ *                  history based entries of the feature vector.
+ * @param[in]       audioWindow  Frame samples; at least FRAME_SIZE are read.
+ * @param[in,out]   features     Receives spectra, band energies, the silence
+ *                               flag and the NB_FEATURES long feature vector.
+ **/
+void RNNoiseFeatureProcessor::ComputeFrameFeatures(vec1D32F& audioWindow,
+                                                   FrameFeatures& features)
+{
+    this->FrameAnalysis(audioWindow,
+                        features.m_fftX,
+                        features.m_Ex,
+                        this->m_analysisMem);
+
+    float energy = 0.0;
+
+    vec1D32F Ly(NB_BANDS, 0);
+    vec1D32F p(WINDOW_SIZE, 0);
+    vec1D32F pitchBuf(PITCH_BUF_SIZE >> 1, 0);
+
+    /* The history must hold at least PITCH_BUF_SIZE samples for the copies
+     * below to stay in bounds, and PITCH_BUF_SIZE - FRAME_SIZE must not
+     * underflow. Both are checked before the first subtraction/copy. */
+    VERIFY(this->m_pitchBuf.size() >= PITCH_BUF_SIZE && PITCH_BUF_SIZE > FRAME_SIZE);
+
+    /* Drop the oldest frame from the pitch history... */
+    std::copy_n(this->m_pitchBuf.begin() + FRAME_SIZE,
+                PITCH_BUF_SIZE - FRAME_SIZE,
+                this->m_pitchBuf.begin());
+
+    /* ...and append the new one at the end. */
+    VERIFY(FRAME_SIZE <= audioWindow.size());
+    std::copy_n(audioWindow.begin(),
+                FRAME_SIZE,
+                this->m_pitchBuf.begin() + PITCH_BUF_SIZE - FRAME_SIZE);
+
+    this->PitchDownsample(pitchBuf, PITCH_BUF_SIZE);
+
+    VERIFY(pitchBuf.size() > PITCH_MAX_PERIOD/2);
+    vec1D32F xLp(pitchBuf.size() - PITCH_MAX_PERIOD/2, 0);
+    std::copy_n(pitchBuf.begin() + PITCH_MAX_PERIOD/2, xLp.size(), xLp.begin());
+
+    int pitchIdx = this->PitchSearch(xLp, pitchBuf,
+            PITCH_FRAME_SIZE, (PITCH_MAX_PERIOD - (3*PITCH_MIN_PERIOD)));
+
+    pitchIdx = this->RemoveDoubling(
+                pitchBuf,
+                PITCH_MAX_PERIOD,
+                PITCH_MIN_PERIOD,
+                PITCH_FRAME_SIZE,
+                PITCH_MAX_PERIOD - pitchIdx);
+
+    /* Window of the history that starts pitchIdx samples before the latest
+     * WINDOW_SIZE samples; the signed check guards the unsigned start index. */
+    VERIFY((static_cast<int>(PITCH_BUF_SIZE) - static_cast<int>(WINDOW_SIZE) - pitchIdx) >= 0);
+    const size_t stIdx = PITCH_BUF_SIZE - WINDOW_SIZE - pitchIdx;
+    std::copy_n(this->m_pitchBuf.begin() + stIdx, WINDOW_SIZE, p.begin());
+
+    this->ApplyWindow(p);
+    this->ForwardTransform(p, features.m_fftP);
+    this->ComputeBandEnergy(features.m_fftP, features.m_Ep);
+    this->ComputeBandCorr(features.m_fftX, features.m_fftP, features.m_Exp);
+
+    /* Normalise the band correlation by the two band energies. */
+    for (uint32_t i = 0 ; i < NB_BANDS; ++i) {
+        features.m_Exp[i] /= math::MathUtils::SqrtF32(
+            0.001f + features.m_Ex[i] * features.m_Ep[i]);
+    }
+
+    vec1D32F dctVec(NB_BANDS, 0);
+    this->DCT(features.m_Exp, dctVec);
+
+    features.m_featuresVec = vec1D32F (NB_FEATURES, 0);
+    for (uint32_t i = 0; i < NB_DELTA_CEPS; ++i) {
+        features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS + i] = dctVec[i];
+    }
+
+    features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS] -= 1.3;
+    features.m_featuresVec[NB_BANDS + 2*NB_DELTA_CEPS + 1] -= 0.9;
+    features.m_featuresVec[NB_BANDS + 3*NB_DELTA_CEPS] = 0.01 * (static_cast<int>(pitchIdx) - 300);
+
+    /* Log band energies, limited from below relative to the running maximum
+     * (logMax - 7) and to the previous band (follow - 1.5). */
+    float logMax = -2.f;
+    float follow = -2.f;
+    for (uint32_t i = 0; i < NB_BANDS; ++i) {
+        Ly[i] = log10f(1e-2f + features.m_Ex[i]);
+        Ly[i] = std::max<float>(logMax - 7, std::max<float>(follow - 1.5, Ly[i]));
+        logMax = std::max<float>(logMax, Ly[i]);
+        follow = std::max<float>(follow - 1.5, Ly[i]);
+        energy += features.m_Ex[i];
+    }
+
+    /* If there's no audio avoid messing up the state. */
+    features.m_silence = true;
+    if (energy < 0.04) {
+        return;
+    } else {
+        features.m_silence = false;
+    }
+
+    this->DCT(Ly, features.m_featuresVec);
+    features.m_featuresVec[0] -= 12.0;
+    features.m_featuresVec[1] -= 4.0;
+
+    /* Rows of the cepstral history written one and two frames ago. With
+     * CEPS_MEM > 2 neither of them is the row written below, so they can be
+     * referenced instead of copied. */
+    VERIFY(CEPS_MEM > 2);
+    uint32_t stIdx1 = this->m_memId < 1 ? CEPS_MEM + this->m_memId - 1 : this->m_memId - 1;
+    uint32_t stIdx2 = this->m_memId < 2 ? CEPS_MEM + this->m_memId - 2 : this->m_memId - 2;
+    VERIFY(stIdx1 < this->m_cepstralMem.size());
+    VERIFY(stIdx2 < this->m_cepstralMem.size());
+    const auto& ceps1 = this->m_cepstralMem[stIdx1];
+    const auto& ceps2 = this->m_cepstralMem[stIdx2];
+
+    /* Ceps 0 */
+    for (uint32_t i = 0; i < NB_BANDS; ++i) {
+        this->m_cepstralMem[this->m_memId][i] = features.m_featuresVec[i];
+    }
+
+    /* Sum, first difference and second difference over the last three
+     * frames for the first NB_DELTA_CEPS coefficients. */
+    for (uint32_t i = 0; i < NB_DELTA_CEPS; ++i) {
+        features.m_featuresVec[i] = this->m_cepstralMem[this->m_memId][i] + ceps1[i] + ceps2[i];
+        features.m_featuresVec[NB_BANDS + i] = this->m_cepstralMem[this->m_memId][i] - ceps2[i];
+        features.m_featuresVec[NB_BANDS + NB_DELTA_CEPS + i] =
+                this->m_cepstralMem[this->m_memId][i] - 2 * ceps1[i] + ceps2[i];
+    }
+
+    /* Spectral variability features. */
+    this->m_memId += 1;
+    if (this->m_memId == CEPS_MEM) {
+        this->m_memId = 0;
+    }
+
+    float specVariability = 0.f;
+
+    /* For every stored frame, add the squared distance to its nearest other
+     * stored frame. */
+    VERIFY(this->m_cepstralMem.size() >= CEPS_MEM);
+    for (size_t i = 0; i < CEPS_MEM; ++i) {
+        float minDist = 1e15;
+        for (size_t j = 0; j < CEPS_MEM; ++j) {
+            float dist = 0.f;
+            for (size_t k = 0; k < NB_BANDS; ++k) {
+                VERIFY(this->m_cepstralMem[i].size() >= NB_BANDS);
+                auto tmp = this->m_cepstralMem[i][k] - this->m_cepstralMem[j][k];
+                dist += tmp * tmp;
+            }
+
+            if (j != i) {
+                minDist = std::min<float>(minDist, dist);
+            }
+        }
+        specVariability += minDist;
+    }
+
+    VERIFY(features.m_featuresVec.size() >= NB_BANDS + 3 * NB_DELTA_CEPS + 1);
+    features.m_featuresVec[NB_BANDS + 3 * NB_DELTA_CEPS + 1] = specVariability / CEPS_MEM - 2.1;
+}
+
+/**
+ * @brief           Builds the analysis window from the previous frame kept
+ *                  in `analysisMem` followed by the new frame, windows it,
+ *                  and computes its spectrum and band energies.
+ * @param[in]       audioWindow  New frame; at least FRAME_SIZE samples.
+ * @param[out]      fft          Receives the spectrum of the window.
+ * @param[out]      energy       Receives the band energies of the spectrum.
+ * @param[in,out]   analysisMem  Previous frame on entry, new frame on exit.
+ **/
+void RNNoiseFeatureProcessor::FrameAnalysis(
+    const vec1D32F& audioWindow,
+    vec1D32F& fft,
+    vec1D32F& energy,
+    vec1D32F& analysisMem)
+{
+    vec1D32F x(WINDOW_SIZE, 0);
+
+    VERIFY(x.size() >= FRAME_SIZE && analysisMem.size() >= FRAME_SIZE);
+    VERIFY(audioWindow.size() >= FRAME_SIZE);
+
+    /* First part of the window: the frame remembered from the last call. */
+    std::copy(analysisMem.begin(), analysisMem.begin() + FRAME_SIZE, x.begin());
+
+    /* Remainder of the window: the latest audio. */
+    std::copy(audioWindow.begin(),
+              audioWindow.begin() + (x.size() - FRAME_SIZE),
+              x.begin() + FRAME_SIZE);
+
+    /* Remember the latest frame for the next call. */
+    std::copy(audioWindow.begin(), audioWindow.begin() + FRAME_SIZE, analysisMem.begin());
+
+    this->ApplyWindow(x);
+    ForwardTransform(x, fft);       /* Spectrum. */
+    ComputeBandEnergy(fft, energy); /* Per-band energy. */
+}
+
+/**
+ * @brief           Multiplies a WINDOW_SIZE long vector by the symmetric
+ *                  window whose first half is stored in m_halfWindow. A
+ *                  vector of any other size is reported and left unchanged.
+ * @param[in,out]   x   Vector to window in place.
+ **/
+void RNNoiseFeatureProcessor::ApplyWindow(vec1D32F& x)
+{
+    if (x.size() != WINDOW_SIZE) {
+        printf_err("Invalid size for vector to be windowed\n");
+        return;
+    }
+
+    VERIFY(this->m_halfWindow.size() >= FRAME_SIZE);
+
+    /* Walk inwards from both ends, applying the same weight to the sample
+     * at the front and to its mirror image at the back. */
+    for (size_t head = 0, tail = WINDOW_SIZE - 1; head < FRAME_SIZE; ++head, --tail) {
+        x[head] *= this->m_halfWindow[head];
+        x[tail] *= this->m_halfWindow[head];
+    }
+}
+
+/**
+ * @brief           Computes the real-input FFT of `x` into `fft`, divides
+ *                  the result by the FFT length and moves the value found at
+ *                  index 1 to the real slot of the last bin.
+ * @param[in,out]   x    Input samples; may be modified by the FFT routine.
+ * @param[out]      fft  Resized to x.size() + 2 values and filled.
+ **/
+void RNNoiseFeatureProcessor::ForwardTransform(
+    vec1D32F& x,
+    vec1D32F& fft)
+{
+    const auto numValues = x.size() + 2;
+    fft.reserve(numValues);
+    fft.resize(numValues, 0);
+
+    /* Note: the FFT routine is allowed to modify its input vector. */
+    math::MathUtils::FftF32(x, fft, this->m_fftInstReal);
+
+    /* Normalise. */
+    for (size_t n = 0; n < fft.size(); ++n) {
+        fft[n] /= this->m_fftInstReal.m_fftLen;
+    }
+
+    /* Place the last frequency element correctly. */
+    fft[fft.size()-2] = fft[1];
+    fft[1] = 0;
+
+    /* NOTE: The FFT vector is not truncated, as it already contains only
+     * the first half of the spectrum. The conjugates are not present. */
+}
+
+/**
+ * @brief       Accumulates the power of a spectrum stored as interleaved
+ *              real/imaginary values into NB_BANDS bands. The power of each
+ *              bin is split between a band and the following one with
+ *              weights that cross-fade linearly across the band; the first
+ *              and last band are doubled at the end.
+ * @param[in]   fftX    Spectrum, interleaved real/imaginary.
+ * @param[out]  bandE   Replaced by the NB_BANDS band energies.
+ **/
+void RNNoiseFeatureProcessor::ComputeBandEnergy(const vec1D32F& fftX, vec1D32F& bandE)
+{
+    bandE = vec1D32F(NB_BANDS, 0);
+
+    VERIFY(this->m_eband5ms.size() >= NB_BANDS);
+    for (uint32_t band = 0; band < NB_BANDS - 1; ++band) {
+        const auto firstBin = this->m_eband5ms[band] << FRAME_SIZE_SHIFT;
+        const auto numBins = (this->m_eband5ms[band + 1] - this->m_eband5ms[band])
+                             << FRAME_SIZE_SHIFT;
+
+        for (uint32_t bin = 0; bin < numBins; ++bin) {
+            const auto weightNext = static_cast<float>(bin) / numBins;
+            const auto idx = firstBin + bin;
+
+            auto power = fftX[2 * idx] * fftX[2 * idx];       /* Real part */
+            power += fftX[2 * idx + 1] * fftX[2 * idx + 1];   /* Imaginary part */
+
+            bandE[band] += (1 - weightNext) * power;
+            bandE[band + 1] += weightNext * power;
+        }
+    }
+
+    /* The two outermost bands only received one half of a cross-fade. */
+    bandE[0] *= 2;
+    bandE[NB_BANDS - 1] *= 2;
+}
+
+/**
+ * @brief       Accumulates the bin-wise product of two spectra (real*real +
+ *              imaginary*imaginary) into NB_BANDS bands, using the same
+ *              linearly cross-fading band weights as ComputeBandEnergy.
+ * @param[in]   X       First spectrum, interleaved real/imaginary.
+ * @param[in]   P       Second spectrum, interleaved real/imaginary.
+ * @param[out]  bandC   Replaced by the NB_BANDS band correlations.
+ **/
+void RNNoiseFeatureProcessor::ComputeBandCorr(const vec1D32F& X, const vec1D32F& P, vec1D32F& bandC)
+{
+    bandC = vec1D32F(NB_BANDS, 0);
+    VERIFY(this->m_eband5ms.size() >= NB_BANDS);
+
+    for (uint32_t band = 0; band < NB_BANDS - 1; ++band) {
+        const auto firstBin = this->m_eband5ms[band] << FRAME_SIZE_SHIFT;
+        const auto numBins = (this->m_eband5ms[band + 1] - this->m_eband5ms[band]) << FRAME_SIZE_SHIFT;
+
+        for (uint32_t bin = 0; bin < numBins; ++bin) {
+            const auto weightNext = static_cast<float>(bin) / numBins;
+            const auto idx = firstBin + bin;
+
+            auto product = X[2 * idx] * P[2 * idx];         /* Real part */
+            product += X[2 * idx + 1] * P[2 * idx + 1];     /* Imaginary part */
+
+            bandC[band] += (1 - weightNext) * product;
+            bandC[band + 1] += weightNext * product;
+        }
+    }
+
+    /* The two outermost bands only received one half of a cross-fade. */
+    bandC[0] *= 2;
+    bandC[NB_BANDS - 1] *= 2;
+}
+
+/**
+ * @brief       Multiplies the first NB_BANDS input values with the cosine
+ *              table built by InitTables and writes the NB_BANDS scaled
+ *              sums to the front of `output`.
+ * @param[in]   input   At least NB_BANDS values.
+ * @param[out]  output  At least NB_BANDS values; the first NB_BANDS are
+ *                      overwritten.
+ **/
+void RNNoiseFeatureProcessor::DCT(vec1D32F& input, vec1D32F& output)
+{
+    VERIFY(this->m_dctTable.size() >= NB_BANDS * NB_BANDS);
+
+    /* Output scaling; the literal 22 presumably mirrors NB_BANDS - TODO
+     * confirm against the header. */
+    const auto scale = math::MathUtils::SqrtF32(2.0/22);
+
+    for (uint32_t col = 0; col < NB_BANDS; ++col) {
+        float sum = 0;
+
+        /* Walk down column `col` of the row-major table. */
+        for (uint32_t row = 0; row < NB_BANDS; ++row) {
+            sum += input[row] * this->m_dctTable[row * NB_BANDS + col];
+        }
+        output[col] = sum * scale;
+    }
+}
+
+/**
+ * @brief       Produces a half-rate copy of the member pitch history
+ *              (m_pitchBuf) in `pitchBuf` and then filters that copy in
+ *              place with a 5-tap FIR whose coefficients are derived from
+ *              the copy's own auto-correlation (lags 0..4) via LPC().
+ * @param[out]  pitchBuf    Receives pitchBufSz/2 samples.
+ * @param[in]   pitchBufSz  Number of samples of m_pitchBuf to consume.
+ **/
+void RNNoiseFeatureProcessor::PitchDownsample(vec1D32F& pitchBuf, size_t pitchBufSz) {
+    /* Decimate by two with weights 0.25 / 0.5 / 0.25 on the previous,
+     * current and next full-rate sample. */
+    for (size_t i = 1; i < (pitchBufSz >> 1); ++i) {
+        pitchBuf[i] = 0.5 * (
+                        0.5 * (this->m_pitchBuf[2 * i - 1] + this->m_pitchBuf[2 * i + 1])
+                            + this->m_pitchBuf[2 * i]);
+    }
+
+    /* First output sample: there is no previous full-rate sample. */
+    pitchBuf[0] = 0.5*(0.5*(this->m_pitchBuf[1]) + this->m_pitchBuf[0]);
+
+    vec1D32F ac(5, 0);
+    size_t numLags = 4;
+
+    this->AutoCorr(pitchBuf, ac, numLags, pitchBufSz >> 1);
+
+    /* Noise floor -40db */
+    ac[0] *= 1.0001;
+
+    /* Lag windowing. */
+    for (size_t i = 1; i < numLags + 1; ++i) {
+        ac[i] -= ac[i] * (0.008 * i) * (0.008 * i);
+    }
+
+    vec1D32F lpc(numLags, 0);
+    this->LPC(ac, numLags, lpc);
+
+    /* Scale coefficient i by 0.9^(i+1). */
+    float tmp = 1.0;
+    for (size_t i = 0; i < numLags; ++i) {
+        tmp = 0.9f * tmp;
+        lpc[i] = lpc[i] * tmp;
+    }
+
+    vec1D32F lpc2(numLags + 1, 0);
+    float c1 = 0.8;
+
+    /* Add a zero. */
+    /* NOTE(review): the 0.8 literal below has the same value as c1. */
+    lpc2[0] = lpc[0] + 0.8;
+    lpc2[1] = lpc[1] + (c1 * lpc[0]);
+    lpc2[2] = lpc[2] + (c1 * lpc[1]);
+    lpc2[3] = lpc[3] + (c1 * lpc[2]);
+    lpc2[4] = (c1 * lpc[3]);
+
+    this->Fir5(lpc2, pitchBufSz >> 1, pitchBuf);
+}
+
+/**
+ * @brief       Two-stage search for the lag at which `xLp` correlates best
+ *              with `y`: a coarse pass on copies of both signals decimated
+ *              by a further factor of two, then a finer pass restricted to
+ *              the neighbourhood of the two coarse candidates, followed by
+ *              a +/-1 refinement.
+ * @param[in]   xLp       Reference segment.
+ * @param[in]   y         Signal that is searched.
+ * @param[in]   len       Length used to size the correlation windows.
+ * @param[in]   maxPitch  Upper bound used to size the lag range.
+ * @return      Twice the best lag index of the finer pass, minus the
+ *              refinement offset.
+ **/
+int RNNoiseFeatureProcessor::PitchSearch(vec1D32F& xLp, vec1D32F& y, uint32_t len, uint32_t maxPitch) {
+    uint32_t lag = len + maxPitch;
+    vec1D32F xLp4(len >> 2, 0);
+    vec1D32F yLp4(lag >> 2, 0);
+    vec1D32F xCorr(maxPitch >> 1, 0);
+
+    /* Downsample by 2 again. */
+    for (size_t j = 0; j < (len >> 2); ++j) {
+        xLp4[j] = xLp[2*j];
+    }
+    for (size_t j = 0; j < (lag >> 2); ++j) {
+        yLp4[j] = y[2*j];
+    }
+
+    this->PitchXCorr(xLp4, yLp4, xCorr, len >> 2, maxPitch >> 2);
+
+    /* Coarse search with 4x decimation. */
+    /* bestPitch holds lag indices in an arrHp; its element type is declared
+     * outside this file section. */
+    arrHp bestPitch = this->FindBestPitch(xCorr, yLp4, len >> 2, maxPitch >> 2);
+
+    /* Finer search with 2x decimation. */
+    const int maxIdx = (maxPitch >> 1);
+    for (int i = 0; i < maxIdx; ++i) {
+        xCorr[i] = 0;
+        /* Only lags within 2 of either (doubled) coarse candidate are
+         * evaluated; all other entries stay 0. */
+        if (std::abs(i - 2*bestPitch[0]) > 2 and std::abs(i - 2*bestPitch[1]) > 2) {
+            continue;
+        }
+        float sum = 0;
+        for (size_t j = 0; j < len >> 1; ++j) {
+            sum += xLp[j] * y[i+j];
+        }
+
+        xCorr[i] = std::max(-1.0f, sum);
+    }
+
+    bestPitch = this->FindBestPitch(xCorr, y, len >> 1, maxPitch >> 1);
+
+    int offset;
+    /* Refine by pseudo-interpolation. */
+    /* Skipped (offset 0) when the best lag is the first or last entry, as
+     * one of its neighbours would not exist. */
+    if ( 0 < bestPitch[0] && bestPitch[0] < ((maxPitch >> 1) - 1)) {
+        float a = xCorr[bestPitch[0] - 1];
+        float b = xCorr[bestPitch[0]];
+        float c = xCorr[bestPitch[0] + 1];
+
+        if ( (c-a) > 0.7*(b-a) ) {
+            offset = 1;
+        } else if ( (a-c) > 0.7*(b-c) ) {
+            offset = -1;
+        } else {
+            offset = 0;
+        }
+    } else {
+        offset = 0;
+    }
+
+    return 2*bestPitch[0] - offset;
+}
+
+/**
+ * @brief       Scans the first `maxPitch` correlation values and returns the
+ *              two lags with the largest ratio of squared correlation to
+ *              the energy of the `len` samples of `y` starting at that lag.
+ *              Only positive correlations are considered.
+ * @param[in]   xCorr     Correlation per lag.
+ * @param[in]   y         Signal the correlations were computed against.
+ * @param[in]   len       Length of the energy window.
+ * @param[in]   maxPitch  Number of lags to scan.
+ * @return      {best lag, second best lag}; {0, 1} when no lag qualifies.
+ **/
+arrHp RNNoiseFeatureProcessor::FindBestPitch(vec1D32F& xCorr, vec1D32F& y, uint32_t len, uint32_t maxPitch)
+{
+    float Syy = 1;
+    arrHp bestNum {-1, -1};
+    arrHp bestDen {0, 0};
+    arrHp bestPitch {0, 1};
+
+    /* Energy of the window at lag 0 (on top of the initial 1). */
+    for (size_t j = 0; j < len; ++j) {
+        Syy += (y[j] * y[j]);
+    }
+
+    for (size_t i = 0; i < maxPitch; ++i ) {
+        if (xCorr[i] > 0) {
+            float xCorr16 = xCorr[i] * 1e-12f;  /* Avoid problems when squaring. */
+
+            float num = xCorr16 * xCorr16;
+            /* Ratios num/Syy are compared by cross-multiplication, first
+             * against the runner-up and then against the current best. */
+            if (num*bestDen[1] > bestNum[1]*Syy) {
+                if (num*bestDen[0] > bestNum[0]*Syy) {
+                    bestNum[1] = bestNum[0];
+                    bestDen[1] = bestDen[0];
+                    bestPitch[1] = bestPitch[0];
+                    bestNum[0] = num;
+                    bestDen[0] = Syy;
+                    bestPitch[0] = i;
+                } else {
+                    bestNum[1] = num;
+                    bestDen[1] = Syy;
+                    bestPitch[1] = i;
+                }
+            }
+        }
+
+        /* Slide the energy window by one sample; never let it drop below 1. */
+        Syy += (y[i+len]*y[i+len]) - (y[i]*y[i]);
+        Syy = std::max(1.0f, Syy);
+    }
+
+    return bestPitch;
+}
+
+/**
+ * @brief       Re-examines an initial pitch period estimate by testing
+ *              shorter candidate periods (initial period / k, k = 2..15) and
+ *              keeping a candidate whose gain exceeds a threshold derived
+ *              from the initial gain and from the period/gain remembered
+ *              from the previous call. Updates m_lastPeriod and m_lastGain.
+ * @param[in]   pitchBuf    Signal in which the period is evaluated.
+ * @param[in]   maxPeriod   Largest period; halved internally.
+ * @param[in]   minPeriod   Smallest period; halved internally.
+ * @param[in]   frameSize   Frame length; halved internally.
+ * @param[in]   pitchIdx0_  Initial period estimate; halved internally.
+ * @return      The period stored in m_lastPeriod (never below the
+ *              un-halved minPeriod).
+ **/
+int RNNoiseFeatureProcessor::RemoveDoubling(
+    vec1D32F& pitchBuf,
+    uint32_t maxPeriod,
+    uint32_t minPeriod,
+    uint32_t frameSize,
+    size_t pitchIdx0_)
+{
+    /* Multiplier used to pick the second candidate lag for each divisor k. */
+    constexpr std::array<size_t, 16> secondCheck {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
+    uint32_t minPeriod0 = minPeriod;
+    float lastPeriod = static_cast<float>(this->m_lastPeriod)/2;
+    float lastGain = static_cast<float>(this->m_lastGain);
+
+    /* Everything below works at half resolution. */
+    maxPeriod /= 2;
+    minPeriod /= 2;
+    pitchIdx0_ /= 2;
+    frameSize /= 2;
+    uint32_t xStart = maxPeriod;
+
+    if (pitchIdx0_ >= maxPeriod) {
+        pitchIdx0_ = maxPeriod - 1;
+    }
+
+    size_t pitchIdx  = pitchIdx0_;
+    const size_t pitchIdx0 = pitchIdx0_;
+
+    /* Energy of the frame that starts at xStart. */
+    float xx = 0;
+    for ( size_t i = xStart; i < xStart+frameSize; ++i) {
+        xx += (pitchBuf[i] * pitchBuf[i]);
+    }
+
+    /* Correlation of that frame with itself delayed by the initial period. */
+    float xy = 0;
+    for ( size_t i = xStart; i < xStart+frameSize; ++i) {
+        xy += (pitchBuf[i] * pitchBuf[i-pitchIdx0]);
+    }
+
+    /* yyLookup[d]: energy of the frame delayed by d samples, obtained by
+     * sliding the window one sample at a time; clamped at 0. */
+    vec1D32F yyLookup (maxPeriod+1, 0);
+    yyLookup[0] = xx;
+    float yy = xx;
+
+    for ( size_t i = 1; i < yyLookup.size(); ++i) {
+        yy = yy + (pitchBuf[xStart-i] * pitchBuf[xStart-i]) -
+                (pitchBuf[xStart+frameSize-i] * pitchBuf[xStart+frameSize-i]);
+        yyLookup[i] = std::max(0.0f, yy);
+    }
+
+    yy = yyLookup[pitchIdx0];
+    float bestXy = xy;
+    float bestYy = yy;
+
+    float g = this->ComputePitchGain(xy, xx, yy);
+    float g0 = g;
+
+    /* Look for any pitch at pitchIndex/k. */
+    for ( size_t k = 2; k < 16; ++k) {
+        /* Rounded division of the initial period by k. */
+        size_t pitchIdx1 = (2*pitchIdx0+k) / (2*k);
+        if (pitchIdx1 < minPeriod) {
+            break;
+        }
+
+        size_t pitchIdx1b;
+        /* Look for another strong correlation at T1b. */
+        if (k == 2) {
+            if ((pitchIdx1 + pitchIdx0) > maxPeriod) {
+                pitchIdx1b = pitchIdx0;
+            } else {
+                pitchIdx1b = pitchIdx0 + pitchIdx1;
+            }
+        } else {
+            pitchIdx1b = (2*(secondCheck[k])*pitchIdx0 + k) / (2*k);
+        }
+
+        xy = 0;
+        for ( size_t i = xStart; i < xStart+frameSize; ++i) {
+            xy += (pitchBuf[i] * pitchBuf[i-pitchIdx1]);
+        }
+
+        float xy2 = 0;
+        for ( size_t i = xStart; i < xStart+frameSize; ++i) {
+            xy2 += (pitchBuf[i] * pitchBuf[i-pitchIdx1b]);
+        }
+        /* Average the two candidate lags. */
+        xy = 0.5f * (xy + xy2);
+        /* Guards the yyLookup access on the next line. */
+        VERIFY(pitchIdx1b < maxPeriod+1);
+        yy = 0.5f * (yyLookup[pitchIdx1] + yyLookup[pitchIdx1b]);
+
+        float g1 = this->ComputePitchGain(xy, xx, yy);
+
+        /* Bonus for candidates close to the period found in the last call. */
+        float cont;
+        if (std::abs(pitchIdx1-lastPeriod) <= 1) {
+            cont = lastGain;
+        } else if (std::abs(pitchIdx1-lastPeriod) <= 2 and 5*k*k < pitchIdx0) {
+            cont = 0.5f*lastGain;
+        } else {
+            cont = 0.0f;
+        }
+
+        float thresh = std::max(0.3, 0.7*g0-cont);
+
+        /* Bias against very high pitch (very short period) to avoid false-positives
+         * due to short-term correlation */
+        /* NOTE(review): the else-if below can never be taken, because any
+         * value below 2*minPeriod is also below 3*minPeriod. This appears to
+         * match the upstream RNNoise code - confirm before changing, as it
+         * would alter the computed features. */
+        if (pitchIdx1 < 3*minPeriod) {
+            thresh = std::max(0.4, 0.85*g0-cont);
+        } else if (pitchIdx1 < 2*minPeriod) {
+            thresh = std::max(0.5, 0.9*g0-cont);
+        }
+        if (g1 > thresh) {
+            bestXy = xy;
+            bestYy = yy;
+            pitchIdx = pitchIdx1;
+            g = g1;
+        }
+    }
+
+    bestXy = std::max(0.0f, bestXy);
+    float pg;
+    if (bestYy <= bestXy) {
+        pg = 1.0;
+    } else {
+        pg = bestXy/(bestYy+1);
+    }
+
+    /* Correlation at the chosen period - 1, the chosen period, and + 1. */
+    std::array<float, 3> xCorr {0};
+    for ( size_t k = 0; k < 3; ++k ) {
+        for ( size_t i = xStart; i < xStart+frameSize; ++i) {
+            xCorr[k] += (pitchBuf[i] * pitchBuf[i-(pitchIdx+k-1)]);
+        }
+    }
+
+    /* offset is unsigned: the -1 case wraps around and yields
+     * 2*pitchIdx - 1 through modular arithmetic in the sum further down. */
+    size_t offset;
+    if ((xCorr[2]-xCorr[0]) > 0.7*(xCorr[1]-xCorr[0])) {
+        offset = 1;
+    } else if ((xCorr[0]-xCorr[2]) > 0.7*(xCorr[1]-xCorr[2])) {
+        offset = -1;
+    } else {
+        offset = 0;
+    }
+
+    if (pg > g) {
+        pg = g;
+    }
+
+    /* Back to full resolution. */
+    pitchIdx0_ = 2*pitchIdx + offset;
+
+    if (pitchIdx0_ < minPeriod0) {
+        pitchIdx0_ = minPeriod0;
+    }
+
+    /* Remembered for the continuity bonus of the next call. */
+    this->m_lastPeriod = pitchIdx0_;
+    this->m_lastGain = pg;
+
+    return this->m_lastPeriod;
+}
+
+/**
+ * @brief       Normalised correlation: xy / sqrt(1 + xx * yy).
+ * @param[in]   xy  Cross term.
+ * @param[in]   xx  Energy of the first signal.
+ * @param[in]   yy  Energy of the second signal.
+ * @return      The gain value.
+ **/
+float RNNoiseFeatureProcessor::ComputePitchGain(float xy, float xx, float yy)
+{
+    /* The added 1 keeps the divisor away from zero for silent input. */
+    const auto norm = math::MathUtils::SqrtF32(1+xx*yy);
+    return xy / norm;
+}
+
+/**
+ * @brief       Auto-correlation of the first `n` samples of `x` for lags
+ *              0..lag. The bulk is computed by PitchXCorr over n - lag
+ *              samples; the remaining products per lag are added afterwards.
+ *              Reports an error and leaves `ac` untouched when n < lag.
+ * @param[in]   x    Input samples.
+ * @param[out]  ac   Receives lag + 1 correlation values.
+ * @param[in]   lag  Largest lag to evaluate.
+ * @param[in]   n    Number of samples of `x` to use.
+ **/
+void RNNoiseFeatureProcessor::AutoCorr(
+    const vec1D32F& x,
+    vec1D32F& ac,
+    size_t lag,
+    size_t n)
+{
+    if (n < lag) {
+        printf_err("Invalid parameters for AutoCorr\n");
+        return;
+    }
+
+    const auto bulkLen = n - lag;
+    const size_t numLags = lag + 1;
+
+    /* Auto-correlation - can be done by PlatformMath functions */
+    this->PitchXCorr(x, x, ac, bulkLen, numLags);
+
+    /* Add the products that fall outside the common bulk length. */
+    for (size_t shift = 0; shift < numLags; ++shift) {
+        float tail = 0;
+        for (size_t pos = shift + bulkLen; pos < n; ++pos) {
+            tail += x[pos] * x[pos - shift];
+        }
+        ac[shift] += tail;
+    }
+}
+
+
+/**
+ * @brief       Cross-correlation of `x` with `y` for shifts 0..maxPitch-1:
+ *              xCorr[s] = sum over j < len of x[j] * y[s + j].
+ * @param[in]   x         First signal; `len` samples are read.
+ * @param[in]   y         Second signal; len + maxPitch - 1 samples are read.
+ * @param[out]  xCorr     Receives `maxPitch` values.
+ * @param[in]   len       Number of products per shift.
+ * @param[in]   maxPitch  Number of shifts to evaluate.
+ **/
+void RNNoiseFeatureProcessor::PitchXCorr(
+    const vec1D32F& x,
+    const vec1D32F& y,
+    vec1D32F& xCorr,
+    size_t len,
+    size_t maxPitch)
+{
+    for (size_t shift = 0; shift < maxPitch; ++shift) {
+        float acc = 0;
+        for (size_t pos = 0; pos < len; ++pos) {
+            acc += x[pos] * y[shift + pos];
+        }
+        xCorr[shift] = acc;
+    }
+}
+
+/* Linear predictor coefficients */
+/**
+ * @brief       Derives `p` predictor coefficients from auto-correlation
+ *              values. The recursion appears to follow the Levinson-Durbin
+ *              scheme - confirm against the reference implementation.
+ *              Nothing is written when correlation[0] is zero, so `lpc` is
+ *              expected to arrive zero-initialised (it is read before being
+ *              fully written).
+ * @param[in]   correlation  At least p + 1 auto-correlation values.
+ * @param[in]   p            Number of coefficients to compute.
+ * @param[out]  lpc          At least `p` values; updated in place.
+ **/
+void RNNoiseFeatureProcessor::LPC(
+    const vec1D32F& correlation,
+    int32_t p,
+    vec1D32F& lpc)
+{
+    auto error = correlation[0];
+
+    if (error != 0) {
+        for (int i = 0; i < p; i++) {
+
+            /* Sum up this iteration's reflection coefficient */
+            float rr = 0;
+            for (int j = 0; j < i; j++) {
+                rr += lpc[j] * correlation[i - j];
+            }
+
+            rr += correlation[i + 1];
+            auto r = -rr / error;
+
+            /* Update LP coefficients and total error */
+            lpc[i] = r;
+            /* Coefficients are updated pairwise from both ends towards the
+             * middle, so each pair uses the values from before this step. */
+            for (int j = 0; j < ((i + 1) >> 1); j++) {
+                auto tmp1 = lpc[j];
+                auto tmp2 = lpc[i - 1 - j];
+                lpc[j] = tmp1 + (r * tmp2);
+                lpc[i - 1 - j] = tmp2 + (r * tmp1);
+            }
+
+            error = error - (r * r * error);
+
+            /* Bail out once we get 30dB gain */
+            if (error < (0.001 * correlation[0])) {
+                break;
+            }
+        }
+    }
+}
+
+/**
+ * @brief           Filters the first `N` samples of `x` in place with a
+ *                  5-tap FIR: every output is the input sample plus the
+ *                  five previous input samples weighted by num[0..4]. The
+ *                  filter history starts at zero on every call.
+ * @param[in]       num  Five filter coefficients.
+ * @param[in]       N    Number of samples to filter.
+ * @param[in,out]   x    Samples; the first `N` are overwritten.
+ **/
+void RNNoiseFeatureProcessor::Fir5(
+    const vec1D32F &num,
+    uint32_t N,
+    vec1D32F &x)
+{
+    VERIFY(num.size() >= 5);
+    VERIFY(x.size() >= N);
+
+    const float num0 = num[0];
+    const float num1 = num[1];
+    const float num2 = num[2];
+    const float num3 = num[3];
+    const float num4 = num[4];
+
+    /* The history must be floating point: declaring it with `auto mem = 0`
+     * deduces int and silently truncates every stored sample. */
+    float mem0 = 0.0f;
+    float mem1 = 0.0f;
+    float mem2 = 0.0f;
+    float mem3 = 0.0f;
+    float mem4 = 0.0f;
+
+    for (uint32_t i = 0; i < N; i++)
+    {
+        const float sum_ = x[i] +  (num0 * mem0) + (num1 * mem1) +
+                           (num2 * mem2) + (num3 * mem3) + (num4 * mem4);
+        /* Shift the history and store the unfiltered input sample. */
+        mem4 = mem3;
+        mem3 = mem2;
+        mem2 = mem1;
+        mem1 = mem0;
+        mem0 = x[i];
+        x[i] = sum_;
+    }
+}
+
+/**
+ * @brief           Adds a per-band weighted amount of the pitch spectrum
+ *                  (m_fftP) to the frame spectrum (m_fftX) and then rescales
+ *                  the result so that every band regains the energy stored
+ *                  in m_Ex.
+ * @param[in,out]   features  Frame features; m_fftX is modified in place.
+ * @param[in]       gain      Per-band gains used to derive the weights.
+ **/
+void RNNoiseFeatureProcessor::PitchFilter(FrameFeatures &features, vec1D32F &gain) {
+    std::vector<float> r(NB_BANDS, 0);
+    std::vector<float> rf(FREQ_SIZE, 0);
+    std::vector<float> newE(NB_BANDS);
+
+    /* Per-band mixing weight. */
+    for (size_t band = 0; band < NB_BANDS; ++band) {
+        float weight;
+        if (features.m_Exp[band] > gain[band]) {
+            weight = 1;
+        } else {
+            const auto corrSq = std::pow(features.m_Exp[band], 2);
+            const auto gainSq = std::pow(gain[band], 2);
+            weight = corrSq * (1 - gainSq) /
+                     (.001 + gainSq * (1 - corrSq));
+        }
+
+        /* Clamp to [0, 1] before taking the root. */
+        weight = math::MathUtils::SqrtF32(std::min(1.0f, std::max(0.0f, weight)));
+        r[band] = weight;
+        r[band] *= math::MathUtils::SqrtF32(features.m_Ex[band] / (1e-8f + features.m_Ep[band]));
+    }
+
+    /* Mix the pitch spectrum into the frame spectrum bin by bin. */
+    InterpBandGain(rf, r);
+    for (size_t bin = 0; bin < FREQ_SIZE - 1; ++bin) {
+        const size_t re = 2 * bin;
+        const size_t im = 2 * bin + 1;
+        features.m_fftX[re] += rf[bin] * features.m_fftP[re];  /* Real. */
+        features.m_fftX[im] += rf[bin] * features.m_fftP[im];  /* Imaginary. */
+    }
+
+    /* Restore the original band energies. */
+    ComputeBandEnergy(features.m_fftX, newE);
+    std::vector<float> norm(NB_BANDS);
+    std::vector<float> normf(FRAME_SIZE, 0);
+    for (size_t band = 0; band < NB_BANDS; ++band) {
+        norm[band] = math::MathUtils::SqrtF32(features.m_Ex[band] / (1e-8f + newE[band]));
+    }
+
+    InterpBandGain(normf, norm);
+    for (size_t bin = 0; bin < FREQ_SIZE - 1; ++bin) {
+        features.m_fftX[2 * bin] *= normf[bin];      /* Real. */
+        features.m_fftX[2 * bin + 1] *= normf[bin];  /* Imaginary. */
+    }
+}
+
+/**
+ * @brief       Turns a spectrum back into audio: inverse transform, window,
+ *              then overlap-add of the first half with the second half that
+ *              was saved by the previous call.
+ * @param[out]  outFrame  Receives FRAME_SIZE samples.
+ * @param[in]   fftY      Spectrum to synthesise.
+ **/
+void RNNoiseFeatureProcessor::FrameSynthesis(vec1D32F& outFrame, vec1D32F& fftY) {
+    std::vector<float> timeDomain(WINDOW_SIZE, 0);
+    InverseTransform(timeDomain, fftY);
+    ApplyWindow(timeDomain);
+
+    /* First half of the window plus the tail kept from the last call. */
+    for (size_t n = 0; n < FRAME_SIZE; ++n) {
+        outFrame[n] = timeDomain[n] + m_synthesisMem[n];
+    }
+
+    /* Keep the second half for the next call. */
+    memcpy(m_synthesisMem.data(), timeDomain.data() + FRAME_SIZE, FRAME_SIZE*sizeof(float));
+}
+
+/**
+ * @brief       Expands NB_BANDS per-band values into per-bin values by
+ *              linear interpolation between neighbouring bands. Only the
+ *              bins covered by the band table are written; every other
+ *              entry of `g` keeps the value it had on entry.
+ * @param[out]  g      Per-bin values.
+ * @param[in]   bandE  NB_BANDS per-band values.
+ **/
+void RNNoiseFeatureProcessor::InterpBandGain(vec1D32F& g, vec1D32F& bandE) {
+    for (size_t band = 0; band < NB_BANDS - 1; ++band) {
+        const auto firstBin = m_eband5ms[band] << FRAME_SIZE_SHIFT;
+        const int numBins = (m_eband5ms[band + 1] - m_eband5ms[band]) << FRAME_SIZE_SHIFT;
+
+        for (int bin = 0; bin < numBins; ++bin) {
+            const float weightNext = static_cast<float>(bin) / numBins;
+            g[firstBin + bin] = (1 - weightNext) * bandE[band] + weightNext * bandE[band + 1];
+        }
+    }
+}
+
+/**
+ * @brief       Reconstructs WINDOW_SIZE real samples from a half spectrum.
+ *              The missing upper bins are filled with the complex conjugate
+ *              of the mirrored lower bins, the complex FFT instance is run
+ *              over the full spectrum, and the real parts of its output are
+ *              read back in reversed index order.
+ * @param[out]  out     Receives WINDOW_SIZE samples.
+ * @param[in]   fftXIn  Half spectrum; FREQ_SIZE interleaved real/imaginary
+ *                      pairs are read.
+ **/
+void RNNoiseFeatureProcessor::InverseTransform(vec1D32F& out, vec1D32F& fftXIn) {
+
+    /* Fail loudly instead of indexing past the end of either vector. */
+    VERIFY(fftXIn.size() >= FREQ_SIZE * 2);
+    VERIFY(out.size() >= WINDOW_SIZE);
+
+    std::vector<float> x(WINDOW_SIZE * 2);  /* This is complex. */
+
+    /* Lower half: copied straight from the input. */
+    size_t i;
+    for (i = 0; i < FREQ_SIZE * 2; i++) {
+        x[i] = fftXIn[i];
+    }
+    /* Upper half: conjugate of the mirrored bin. */
+    for (i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
+        x[2 * i] = x[2 * (WINDOW_SIZE - i)];  /* Real. */
+        x[2 * i + 1] = -x[2 * (WINDOW_SIZE - i) + 1];  /* Imaginary. */
+    }
+
+    constexpr uint32_t numFFt = 2 * FRAME_SIZE;
+    static_assert(numFFt != 0, "numFFt cannot be 0!");
+
+    vec1D32F fftOut = vec1D32F(x.size(), 0);
+    math::MathUtils::FftF32(x,fftOut, m_fftInstCmplx);
+
+    /* Normalize. */
+    for (auto &f: fftOut) {
+        f /= numFFt;
+    }
+
+    /* Sample 0 maps onto itself; sample i is read from complex element
+     * WINDOW_SIZE - i of the transform output. */
+    out[0] = WINDOW_SIZE * fftOut[0];  /* Real. */
+    for (i = 1; i < WINDOW_SIZE; i++) {
+        out[i] = WINDOW_SIZE * fftOut[(WINDOW_SIZE * 2) - (2 * i)];  /* Real. */
+    }
+}
+
+
+} /* namespace rnn */
+} /* namespace app */
+} /* namespace arm */
diff --git a/source/application/api/use_case/noise_reduction/src/RNNoiseModel.cc b/source/application/api/use_case/noise_reduction/src/RNNoiseModel.cc
new file mode 100644
index 0000000..457cda9
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/src/RNNoiseModel.cc
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "RNNoiseModel.hpp"
+#include "log_macros.h"
+
+/* Returns a reference to the op resolver member that EnlistOperations()
+ * populates. */
+const tflite::MicroOpResolver& arm::app::RNNoiseModel::GetOpResolver()
+{
+    return this->m_opResolver;
+}
+
+/**
+ * @brief   Registers the operators of this model with the op resolver,
+ *          followed by the Ethos-U operator.
+ * @return  true when the Ethos-U operator was added, false otherwise. The
+ *          results of the other registrations are not inspected.
+ **/
+bool arm::app::RNNoiseModel::EnlistOperations()
+{
+    auto& resolver = this->m_opResolver;
+
+    resolver.AddUnpack();
+    resolver.AddFullyConnected();
+    resolver.AddSplit();
+    resolver.AddSplitV();
+    resolver.AddAdd();
+    resolver.AddLogistic();
+    resolver.AddMul();
+    resolver.AddSub();
+    resolver.AddTanh();
+    resolver.AddPack();
+    resolver.AddReshape();
+    resolver.AddQuantize();
+    resolver.AddConcatenation();
+    resolver.AddRelu();
+
+    if (kTfLiteOk != resolver.AddEthosU()) {
+        printf_err("Failed to add Arm NPU support to op resolver.");
+        return false;
+    }
+
+    info("Added %s support to op resolver\n",
+        tflite::GetString_ETHOSU());
+    return true;
+}
+
+/* Forwards to the base class implementation and returns its result. */
+bool arm::app::RNNoiseModel::RunInference()
+{
+    return Model::RunInference();
+}
+
+/**
+ * @brief   Fills every GRU state input tensor listed in m_gruStateMap with
+ *          the tensor's quantisation offset, byte by byte.
+ **/
+void arm::app::RNNoiseModel::ResetGruState()
+{
+    for (auto& mapping : this->m_gruStateMap) {
+        /* The second entry of each mapping is the input tensor index. */
+        TfLiteTensor* stateTensor = this->GetInputTensor(mapping.second);
+        auto* stateData = tflite::GetTensorData<int8_t>(stateTensor);
+
+        /* A state of 0 is represented by the zero point once quantised. */
+        const auto quant = arm::app::GetTensorQuantParams(stateTensor);
+        memset(stateData, quant.offset, stateTensor->bytes);
+    }
+}
+
+/**
+ * @brief   Copies every GRU state output tensor into the input tensor it
+ *          is mapped to in m_gruStateMap.
+ * @return  true when all states were copied; false as soon as an
+ *          output/input pair differs in byte size (inputs handled before
+ *          that pair have already been updated).
+ **/
+bool arm::app::RNNoiseModel::CopyGruStates()
+{
+    std::vector<std::pair<size_t, std::vector<int8_t>>> tempOutGruStates;
+    /* Saving output states before copying them to input states to avoid output states modification in the tensor.
+     * tflu shares input and output tensors memory, thus writing to input tensor can change output tensor values. */
+    for (auto& stateMapping: this->m_gruStateMap) {
+        TfLiteTensor* outputGruStateTensor = this->GetOutputTensor(stateMapping.first);
+        std::vector<int8_t> tempOutGruState(outputGruStateTensor->bytes);
+        auto* outGruState = tflite::GetTensorData<int8_t>(outputGruStateTensor);
+        memcpy(tempOutGruState.data(), outGruState, outputGruStateTensor->bytes);
+        /* Index of the input tensor and the data to copy. */
+        tempOutGruStates.emplace_back(stateMapping.second, std::move(tempOutGruState));
+    }
+    /* Updating input GRU states with saved GRU output states. */
+    for (auto& stateMapping: tempOutGruStates) {
+        /* Reference only: the saved bytes do not need to be copied again. */
+        const auto& outputGruStateTensorData = stateMapping.second;
+        TfLiteTensor* inputGruStateTensor = this->GetInputTensor(stateMapping.first);
+        if (outputGruStateTensorData.size() != inputGruStateTensor->bytes) {
+            printf_err("Unexpected number of bytes for GRU state mapping. Input = %zu, output = %zu.\n",
+                       inputGruStateTensor->bytes,
+                       outputGruStateTensorData.size());
+            return false;
+        }
+        auto* inputGruState = tflite::GetTensorData<int8_t>(inputGruStateTensor);
+        const auto* outGruState = outputGruStateTensorData.data();
+        memcpy(inputGruState, outGruState, inputGruStateTensor->bytes);
+    }
+    return true;
+}
diff --git a/source/application/api/use_case/noise_reduction/src/RNNoiseProcessing.cc b/source/application/api/use_case/noise_reduction/src/RNNoiseProcessing.cc
new file mode 100644
index 0000000..f6a3ec4
--- /dev/null
+++ b/source/application/api/use_case/noise_reduction/src/RNNoiseProcessing.cc
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "RNNoiseProcessing.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    /* Stores the input tensor pointer and the shared feature processor / frame feature objects. */
+    RNNoisePreProcess::RNNoisePreProcess(TfLiteTensor* inputTensor,
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+    :   m_inputTensor(inputTensor), m_featureProcessor(featureProcessor), m_frameFeatures(frameFeatures)
+    {}
+
+    /* Converts one frame of int16 samples to float, runs the feature processor over it and quantises
+     * m_featuresVec into the input tensor. Returns false (after logging) if data is null, else true. */
+    bool RNNoisePreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+            return false;
+        }
+
+        /* inputSize is used as an element count of int16_t samples, not a byte count. */
+        const auto* input = static_cast<const int16_t*>(data);
+        this->m_audioFrame = rnn::vec1D32F(input, input + inputSize);
+        this->m_featureProcessor->PreprocessFrame(this->m_audioFrame.data(), inputSize, *this->m_frameFeatures);
+
+        QuantizeAndPopulateInput(this->m_frameFeatures->m_featuresVec, this->m_inputTensor->params.scale,
+                this->m_inputTensor->params.zero_point, this->m_inputTensor);
+        debug("Input tensor populated \n");
+        return true;
+    }
+
+    /* Quantises features to int8 (value / scale + offset, clamped) into inputTensor, never writing past its bytes. */
+    void RNNoisePreProcess::QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
+            const float quantScale, const int quantOffset, TfLiteTensor* inputTensor)
+    {
+        const float minVal = std::numeric_limits<int8_t>::min();
+        const float maxVal = std::numeric_limits<int8_t>::max();
+        auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
+        const size_t numValues = std::min<size_t>(inputFeatures.size(), inputTensor->bytes);
+        if (numValues < inputFeatures.size()) { printf_err("Feature vector is larger than the input tensor\n"); }
+        for (size_t i = 0; i < numValues; ++i) {
+            const float quantValue = (inputFeatures[i] / quantScale) + quantOffset;
+            inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
+        }
+    }
+
+    /* Keeps the output tensor, the int16 denoised frame and the shared feature objects, and prepares the
+     * float buffers that DoPostProcess works on. */
+    RNNoisePostProcess::RNNoisePostProcess(TfLiteTensor* outputTensor,
+            std::vector<int16_t>& denoisedAudioFrame,
+            std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor,
+            std::shared_ptr<rnn::FrameFeatures> frameFeatures)
+    :   m_outputTensor(outputTensor), m_denoisedAudioFrame(denoisedAudioFrame),
+        m_featureProcessor(featureProcessor), m_frameFeatures(frameFeatures)
+    {
+        m_modelOutputFloat.resize(outputTensor->bytes); /* One float per output tensor byte. */
+        m_denoisedAudioFrameFloat.reserve(denoisedAudioFrame.size()); /* NOTE(review): capacity only - confirm PostProcessFrame sizes it. */
+    }
+
+    /* Dequantises the int8 model output, runs the feature processor's post-processing on it and stores the
+     * rounded result as int16 samples in m_denoisedAudioFrame. Always returns true. */
+    bool RNNoisePostProcess::DoPostProcess()
+    {
+        const auto* outputData = tflite::GetTensorData<int8_t>(this->m_outputTensor);
+        const auto outputQuantParams = GetTensorQuantParams(this->m_outputTensor);
+        for (size_t i = 0; i < this->m_outputTensor->bytes; ++i) {
+            this->m_modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset) * outputQuantParams.scale;
+        }
+        this->m_featureProcessor->PostProcessFrame(this->m_modelOutputFloat,
+                *this->m_frameFeatures, this->m_denoisedAudioFrameFloat);
+        /* Saturate before narrowing: converting an out-of-range float to int16_t is undefined behaviour. */
+        const float minVal = std::numeric_limits<int16_t>::min();
+        const float maxVal = std::numeric_limits<int16_t>::max();
+        for (size_t i = 0; i < this->m_denoisedAudioFrame.size(); ++i) {
+            const float sample = std::roundf(this->m_denoisedAudioFrameFloat[i]);
+            this->m_denoisedAudioFrame[i] = static_cast<int16_t>(std::min<float>(std::max<float>(sample, minVal), maxVal));
+        }
+        return true;
+    }
+
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file