MLECO-3076: Add use case API for object detection

* Removed unused prototype for box drawing

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I1b03b88e710a5efb1ff8e107859d2245b1fead26
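
For context, a minimal sketch of the call pattern this change introduces,
mirroring UseCaseHandler.cc below. The image buffer (imageData, imageSize)
and the inputImgRows/inputImgCols values are stand-ins taken from the
handler, and the error handling is illustrative only:

    /* Sketch only: assumes a loaded YoloFastestModel and the interfaces
     * declared in BaseProcessing.hpp that this patch builds on. */
    TfLiteTensor* inputTensor   = model.GetInputTensor(0);
    TfLiteTensor* outputTensor0 = model.GetOutputTensor(0);
    TfLiteTensor* outputTensor1 = model.GetOutputTensor(1);

    std::vector<arm::app::object_detection::DetectionResult> results;
    arm::app::DetectorPreProcess preProcess{inputTensor, true, model.IsDataSigned()};
    arm::app::DetectorPostProcess postProcess{outputTensor0, outputTensor1,
                                              results, inputImgRows, inputImgCols};

    if (preProcess.DoPreProcess(imageData, imageSize) &&
        arm::app::RunInference(model, profiler) &&
        postProcess.DoPostProcess()) {
        /* 'results' now holds the detections for this image. */
    }
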
diff --git a/source/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
index cdb14f5..b3ddb2c 100644
--- a/source/use_case/object_detection/include/DetectorPostProcessing.hpp
+++ b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
@@ -21,11 +21,13 @@
 #include "ImageUtils.hpp"
 #include "DetectionResult.hpp"
 #include "YoloFastestModel.hpp"
+#include "BaseProcessing.hpp"
 
 #include <forward_list>
 
 namespace arm {
 namespace app {
+
 namespace object_detection {
 
     struct Branch {
@@ -46,42 +48,55 @@
         int topN;
     };
 
+} /* namespace object_detection */
+
     /**
-     * @brief   Helper class to manage tensor post-processing for "object_detection"
-     *          output.
+     * @brief   Post-processing class for the object detection use case.
+     *          Implements the methods declared by BasePostProcess, plus anything
+     *          else needed to populate the results vector.
      */
-    class DetectorPostprocessing {
+    class DetectorPostProcess : public BasePostProcess {
     public:
         /**
-         * @brief       Constructor.
-         * @param[in]   threshold     Post-processing threshold.
-         * @param[in]   nms           Non-maximum Suppression threshold.
-         * @param[in]   numClasses    Number of classes.
-         * @param[in]   topN          Top N for each class.
+         * @brief        Constructor.
+         * @param[in]    outputTensor0   Pointer to the TFLite Micro output Tensor at index 0.
+         * @param[in]    outputTensor1   Pointer to the TFLite Micro output Tensor at index 1.
+         * @param[out]   results         Vector of detected results.
+         * @param[in]    inputImgRows    Number of rows in the input image.
+         * @param[in]    inputImgCols    Number of columns in the input image.
+         * @param[in]    threshold       Post-processing threshold.
+         * @param[in]    nms             Non-maximum Suppression threshold.
+         * @param[in]    numClasses      Number of classes.
+         * @param[in]    topN            Top N for each class.
          **/
-        explicit DetectorPostprocessing(float threshold = 0.5f,
-                                        float nms = 0.45f,
-                                        int numClasses = 1,
-                                        int topN = 0);
+        explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
+                                     TfLiteTensor* outputTensor1,
+                                     std::vector<object_detection::DetectionResult>& results,
+                                     int inputImgRows,
+                                     int inputImgCols,
+                                     float threshold = 0.5f,
+                                     float nms = 0.45f,
+                                     int numClasses = 1,
+                                     int topN = 0);
 
         /**
-         * @brief       Post processing part of YOLO object detection CNN.
-         * @param[in]   imgRows      Number of rows in the input image.
-         * @param[in]   imgCols      Number of columns in the input image.
-         * @param[in]   modelOutput  Output tensors after CNN invoked.
-         * @param[out]  resultsOut   Vector of detected results.
+         * @brief    Performs YOLO post-processing on the inference output and
+         *           populates the detection results vector for later use.
+         * @return   true if successful, false otherwise.
          **/
-        void RunPostProcessing(uint32_t imgRows,
-                               uint32_t imgCols,
-                               TfLiteTensor* modelOutput0,
-                               TfLiteTensor* modelOutput1,
-                               std::vector<DetectionResult>& resultsOut);
+        bool DoPostProcess() override;
 
     private:
-        float m_threshold;  /* Post-processing threshold */
-        float m_nms;        /* NMS threshold */
-        int   m_numClasses; /* Number of classes */
-        int   m_topN;       /* TopN */
+        TfLiteTensor* m_outputTensor0;     /* Output tensor index 0 */
+        TfLiteTensor* m_outputTensor1;     /* Output tensor index 1 */
+        std::vector<object_detection::DetectionResult>& m_results;  /* Results of a single inference. */
+        int m_inputImgRows;                /* Number of rows for model input. */
+        int m_inputImgCols;                /* Number of cols for model input. */
+        float m_threshold;                 /* Post-processing threshold. */
+        float m_nms;                       /* NMS threshold. */
+        int   m_numClasses;                /* Number of classes. */
+        int   m_topN;                      /* TopN. */
+        object_detection::Network m_net;   /* YOLO network object. */
 
         /**
          * @brief       Insert the given Detection in the list.
@@ -98,32 +113,13 @@
          * @param[in]    threshold     Detections threshold.
          * @param[out]   detections    Detection boxes.
          **/
-        void GetNetworkBoxes(Network& net,
+        void GetNetworkBoxes(object_detection::Network& net,
                              int imageWidth,
                              int imageHeight,
                              float threshold,
                              std::forward_list<image::Detection>& detections);
-
-        /**
-         * @brief       Draw on the given image a bounding box starting at (boxX, boxY).
-         * @param[in/out]   imgIn    Image.
-         * @param[in]       imWidth    Image width.
-         * @param[in]       imHeight   Image height.
-         * @param[in]       boxX       Axis X starting point.
-         * @param[in]       boxY       Axis Y starting point.
-         * @param[in]       boxWidth   Box width.
-         * @param[in]       boxHeight  Box height.
-         **/
-        void DrawBoxOnImage(uint8_t* imgIn,
-                            int imWidth,
-                            int imHeight,
-                            int boxX,
-                            int boxY,
-                            int boxWidth,
-                            int boxHeight);
     };
 
-} /* namespace object_detection */
 } /* namespace app */
 } /* namespace arm */
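
DetectorPostProcess now derives from BasePostProcess, declared in
BaseProcessing.hpp (not shown in this diff). For review context, a plausible
shape of that interface, inferred from the override above rather than taken
from the source tree:

    /* Assumed sketch of the base class in BaseProcessing.hpp;
     * not part of this patch. */
    class BasePostProcess {
    public:
        virtual ~BasePostProcess() = default;

        /** @brief  Should perform post-processing of the inference result.
         *  @return true if successful, false otherwise. */
        virtual bool DoPostProcess() = 0;
    };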
 
diff --git a/source/use_case/object_detection/include/DetectorPreProcessing.hpp b/source/use_case/object_detection/include/DetectorPreProcessing.hpp
new file mode 100644
index 0000000..4936048
--- /dev/null
+++ b/source/use_case/object_detection/include/DetectorPreProcessing.hpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DETECTOR_PRE_PROCESSING_HPP
+#define DETECTOR_PRE_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "Classifier.hpp"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for the object detection use case.
+     *          Implements the methods declared by BasePreProcess, plus anything
+     *          else needed to populate the input tensor ready for inference.
+     */
+    class DetectorPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief       Constructor.
+         * @param[in]   inputTensor     Pointer to the TFLite Micro input Tensor.
+         * @param[in]   rgb2Gray        Convert image from 3-channel RGB to 1-channel grayscale.
+         * @param[in]   convertToInt8   Convert the image from uint8 to int8 range.
+         **/
+        explicit DetectorPreProcess(TfLiteTensor* inputTensor, bool rgb2Gray, bool convertToInt8);
+
+        /**
+         * @brief       Performs pre-processing of the 'raw' input image data and loads
+         *              it into the TFLite Micro input tensor, ready for inference.
+         * @param[in]   input      Pointer to the data that pre-processing will work on.
+         * @param[in]   inputSize  Size of the input data.
+         * @return      true if successful, false otherwise.
+         **/
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+    private:
+        TfLiteTensor* m_inputTensor;   /* Model input tensor. */
+        bool m_rgb2Gray;               /* RGB to grayscale conversion flag. */
+        bool m_convertToInt8;          /* uint8 to int8 conversion flag. */
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* DETECTOR_PRE_PROCESSING_HPP */
\ No newline at end of file
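
DetectorPreProcess overrides DoPreProcess(const void*, size_t), so the
pre-processing base class is assumed to follow the same pattern (again
inferred, not taken from the source tree):

    /* Assumed sketch of the base class in BaseProcessing.hpp;
     * not part of this patch. */
    class BasePreProcess {
    public:
        virtual ~BasePreProcess() = default;

        /** @brief  Should perform pre-processing of the raw input data.
         *  @return true if successful, false otherwise. */
        virtual bool DoPreProcess(const void* input, size_t inputSize) = 0;
    };
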
diff --git a/source/use_case/object_detection/src/DetectorPostProcessing.cc b/source/use_case/object_detection/src/DetectorPostProcessing.cc
index a890c9e..fb1606a 100644
--- a/source/use_case/object_detection/src/DetectorPostProcessing.cc
+++ b/source/use_case/object_detection/src/DetectorPostProcessing.cc
@@ -21,64 +21,73 @@
 
 namespace arm {
 namespace app {
-namespace object_detection {
 
-DetectorPostprocessing::DetectorPostprocessing(
-    const float threshold,
-    const float nms,
-    int numClasses,
-    int topN)
-    :   m_threshold(threshold),
-        m_nms(nms),
-        m_numClasses(numClasses),
-        m_topN(topN)
-{}
-
-void DetectorPostprocessing::RunPostProcessing(
-    uint32_t imgRows,
-    uint32_t imgCols,
-    TfLiteTensor* modelOutput0,
-    TfLiteTensor* modelOutput1,
-    std::vector<DetectionResult>& resultsOut)
+DetectorPostProcess::DetectorPostProcess(
+    TfLiteTensor* modelOutput0,
+    TfLiteTensor* modelOutput1,
+    std::vector<object_detection::DetectionResult>& results,
+    int inputImgRows,
+    int inputImgCols,
+    const float threshold,
+    const float nms,
+    int numClasses,
+    int topN)
+    :   m_outputTensor0{modelOutput0},
+        m_outputTensor1{modelOutput1},
+        m_results{results},
+        m_inputImgRows{inputImgRows},
+        m_inputImgCols{inputImgCols},
+        m_threshold(threshold),
+        m_nms(nms),
+        m_numClasses(numClasses),
+        m_topN(topN)
 {
-    /* init postprocessing */
-    Network net {
-        .inputWidth = static_cast<int>(imgCols),
-        .inputHeight = static_cast<int>(imgRows),
-        .numClasses = m_numClasses,
+    /* Initialise the post-processing network description. */
+    this->m_net =
+    object_detection::Network {
+        .inputWidth = inputImgCols,
+        .inputHeight = inputImgRows,
+        .numClasses = numClasses,
         .branches = {
-            Branch {
-                .resolution = static_cast<int>(imgCols/32),
-                .numBox = 3,
-                .anchor = anchor1,
-                .modelOutput = modelOutput0->data.int8,
-                .scale = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->scale->data[0],
-                .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->zero_point->data[0],
-                .size = modelOutput0->bytes
+            object_detection::Branch {
+                .resolution = inputImgCols / 32,
+                .numBox = 3,
+                .anchor = anchor1,
+                .modelOutput = this->m_outputTensor0->data.int8,
+                .scale = (static_cast<TfLiteAffineQuantization*>(
+                        this->m_outputTensor0->quantization.params))->scale->data[0],
+                .zeroPoint = (static_cast<TfLiteAffineQuantization*>(
+                        this->m_outputTensor0->quantization.params))->zero_point->data[0],
+                .size = this->m_outputTensor0->bytes
             },
-            Branch {
-                .resolution = static_cast<int>(imgCols/16),
-                .numBox = 3,
-                .anchor = anchor2,
-                .modelOutput = modelOutput1->data.int8,
-                .scale = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->scale->data[0],
-                .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->zero_point->data[0],
-                .size = modelOutput1->bytes
+            object_detection::Branch {
+                .resolution = inputImgCols / 16,
+                .numBox = 3,
+                .anchor = anchor2,
+                .modelOutput = this->m_outputTensor1->data.int8,
+                .scale = (static_cast<TfLiteAffineQuantization*>(
+                        this->m_outputTensor1->quantization.params))->scale->data[0],
+                .zeroPoint = (static_cast<TfLiteAffineQuantization*>(
+                        this->m_outputTensor1->quantization.params))->zero_point->data[0],
+                .size = this->m_outputTensor1->bytes
             }
         },
         .topN = m_topN
     };
     /* End init */
+}
 
+bool DetectorPostProcess::DoPostProcess()
+{
     /* Start postprocessing */
     int originalImageWidth = originalImageSize;
     int originalImageHeight = originalImageSize;
 
     std::forward_list<image::Detection> detections;
-    GetNetworkBoxes(net, originalImageWidth, originalImageHeight, m_threshold, detections);
+    GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections);
 
     /* Do nms */
-    CalculateNMS(detections, net.numClasses, m_nms);
+    CalculateNMS(detections, this->m_net.numClasses, m_nms);
 
     for (auto& it: detections) {
         float xMin = it.bbox.x - it.bbox.w / 2.0f;
@@ -104,24 +113,24 @@
         float boxWidth = xMax - xMin;
         float boxHeight = yMax - yMin;
 
-        for (int j = 0; j < net.numClasses; ++j) {
+        for (int j = 0; j < this->m_net.numClasses; ++j) {
             if (it.prob[j] > 0) {
 
-                DetectionResult tmpResult = {};
+                object_detection::DetectionResult tmpResult = {};
                 tmpResult.m_normalisedVal = it.prob[j];
                 tmpResult.m_x0 = boxX;
                 tmpResult.m_y0 = boxY;
                 tmpResult.m_w = boxWidth;
                 tmpResult.m_h = boxHeight;
 
-                resultsOut.push_back(tmpResult);
+                this->m_results.push_back(tmpResult);
             }
         }
     }
+    return true;
 }
 
-
-void DetectorPostprocessing::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det)
+void DetectorPostProcess::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det)
 {
     std::forward_list<image::Detection>::iterator it;
     std::forward_list<image::Detection>::iterator last_it;
@@ -136,7 +145,12 @@
     }
 }
 
-void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int imageHeight, float threshold, std::forward_list<image::Detection>& detections)
+void DetectorPostProcess::GetNetworkBoxes(
+        object_detection::Network& net,
+        int imageWidth,
+        int imageHeight,
+        float threshold,
+        std::forward_list<image::Detection>& detections)
 {
     int numClasses = net.numClasses;
     int num = 0;
@@ -169,10 +183,14 @@
                         int bbox_h_offset = bbox_x_offset + 3;
                         int bbox_scores_offset = bbox_x_offset + 5;
 
-                        det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
-                        det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
-                        det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
-                        det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
+                        det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset])
+                                - net.branches[i].zeroPoint) * net.branches[i].scale;
+                        det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset])
+                                - net.branches[i].zeroPoint) * net.branches[i].scale;
+                        det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset])
+                                - net.branches[i].zeroPoint) * net.branches[i].scale;
+                        det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset])
+                                - net.branches[i].zeroPoint) * net.branches[i].scale;
 
                         float bbox_x, bbox_y;
 
@@ -218,6 +236,5 @@
         num -=1;
 }
 
-} /* namespace object_detection */
 } /* namespace app */
 } /* namespace arm */
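
The box decoding in GetNetworkBoxes above dequantizes each int8 output value
with the usual affine mapping, real = (quantized - zeroPoint) * scale. The
repeated expression could be factored into a helper; the function below is a
hypothetical sketch, not part of this patch:

    /* Hypothetical helper: affine dequantization as used in GetNetworkBoxes. */
    static inline float Dequantize(int8_t quantized, int zeroPoint, float scale)
    {
        return (static_cast<float>(quantized) - zeroPoint) * scale;
    }

    /* e.g. det.bbox.x = Dequantize(net.branches[i].modelOutput[bbox_x_offset],
     *                              net.branches[i].zeroPoint,
     *                              net.branches[i].scale); */
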
diff --git a/source/use_case/object_detection/src/DetectorPreProcessing.cc b/source/use_case/object_detection/src/DetectorPreProcessing.cc
new file mode 100644
index 0000000..7212046
--- /dev/null
+++ b/source/use_case/object_detection/src/DetectorPreProcessing.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "DetectorPreProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    DetectorPreProcess::DetectorPreProcess(TfLiteTensor* inputTensor, bool rgb2Gray, bool convertToInt8)
+    :   m_inputTensor{inputTensor},
+        m_rgb2Gray{rgb2Gray},
+        m_convertToInt8{convertToInt8}
+    {}
+
+    bool DetectorPreProcess::DoPreProcess(const void* data, size_t inputSize) {
+        if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+            return false;
+        }
+
+        auto input = static_cast<const uint8_t*>(data);
+
+        if (this->m_rgb2Gray) {
+            image::RgbToGrayscale(input, this->m_inputTensor->data.uint8, this->m_inputTensor->bytes);
+        } else {
+            std::memcpy(this->m_inputTensor->data.data, input, inputSize);
+        }
+        debug("Input tensor populated\n");
+
+        if (this->m_convertToInt8) {
+            image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes);
+        }
+
+        return true;
+    }
+
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file
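
DoPreProcess delegates to image::RgbToGrayscale and image::ConvertImgToInt8
from ImageUtils. The int8 conversion is assumed to shift the unsigned range
down by 128 so that uint8 [0, 255] maps onto int8 [-128, 127]; a sketch of
that assumed behaviour (not the actual ImageUtils implementation):

    /* Sketch of the assumed in-place uint8 -> int8 conversion. */
    static void ConvertImgToInt8Sketch(void* data, size_t numBytes)
    {
        auto* bytes = static_cast<uint8_t*>(data);
        for (size_t i = 0; i < numBytes; ++i) {
            /* Unsigned arithmetic wraps, giving the int8 bit pattern. */
            bytes[i] = static_cast<uint8_t>(bytes[i] - 128);
        }
    }
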
diff --git a/source/use_case/object_detection/src/MainLoop.cc b/source/use_case/object_detection/src/MainLoop.cc
index acfc195..4291164 100644
--- a/source/use_case/object_detection/src/MainLoop.cc
+++ b/source/use_case/object_detection/src/MainLoop.cc
@@ -19,7 +19,6 @@
 #include "YoloFastestModel.hpp"       /* Model class for running inference. */
 #include "UseCaseHandler.hpp"         /* Handlers for different user options. */
 #include "UseCaseCommonUtils.hpp"     /* Utils functions. */
-#include "DetectorPostProcessing.hpp" /* Post-processing class. */
 #include "log_macros.h"
 
 static void DisplayDetectionMenu()
@@ -53,9 +52,6 @@
     caseContext.Set<arm::app::Profiler&>("profiler", profiler);
     caseContext.Set<arm::app::Model&>("model", model);
     caseContext.Set<uint32_t>("imgIndex", 0);
-    arm::app::object_detection::DetectorPostprocessing postp;
-    caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>("postprocess", postp);
-
 
     /* Loop. */
     bool executionSuccessful = true;
diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc
index f3b317e..332d199 100644
--- a/source/use_case/object_detection/src/UseCaseHandler.cc
+++ b/source/use_case/object_detection/src/UseCaseHandler.cc
@@ -19,6 +19,7 @@
 #include "YoloFastestModel.hpp"
 #include "UseCaseCommonUtils.hpp"
 #include "DetectorPostProcessing.hpp"
+#include "DetectorPreProcessing.hpp"
 #include "hal.h"
 #include "log_macros.h"
 
@@ -33,7 +34,7 @@
      * @param[in]       results            Vector of detection results to be displayed.
      * @return          true if successful, false otherwise.
      **/
-    static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results);
+    static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results);
 
     /**
      * @brief           Draw boxes directly on the LCD for all detected objects.
@@ -43,12 +44,12 @@
      * @param[in]       imgDownscaleFactor How much image has been downscaled on LCD.
      **/
     static void DrawDetectionBoxes(
-            const std::vector<arm::app::object_detection::DetectionResult>& results,
+           const std::vector<object_detection::DetectionResult>& results,
            uint32_t imgStartX,
            uint32_t imgStartY,
            uint32_t imgDownscaleFactor);
 
-    /* Object detection classification handler. */
+    /* Object detection inference handler. */
     bool ObjectDetectionHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
     {
         auto& profiler = ctx.Get<Profiler&>("profiler");
@@ -75,9 +76,11 @@
             return false;
         }
 
-        auto curImIdx = ctx.Get<uint32_t>("imgIndex");
+        auto initialImgIdx = ctx.Get<uint32_t>("imgIndex");
 
         TfLiteTensor* inputTensor = model.GetInputTensor(0);
+        TfLiteTensor* outputTensor0 = model.GetOutputTensor(0);
+        TfLiteTensor* outputTensor1 = model.GetOutputTensor(1);
 
         if (!inputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
@@ -89,71 +92,69 @@
 
         TfLiteIntArray* inputShape = model.GetInputShape(0);
 
-        const uint32_t nCols = inputShape->data[arm::app::YoloFastestModel::ms_inputColsIdx];
-        const uint32_t nRows = inputShape->data[arm::app::YoloFastestModel::ms_inputRowsIdx];
+        const int inputImgCols = inputShape->data[YoloFastestModel::ms_inputColsIdx];
+        const int inputImgRows = inputShape->data[YoloFastestModel::ms_inputRowsIdx];
 
-        /* Get pre/post-processing objects. */
-        auto& postp = ctx.Get<object_detection::DetectorPostprocessing&>("postprocess");
+        /* Set up pre and post-processing. */
+        DetectorPreProcess preProcess{inputTensor, true, model.IsDataSigned()};
 
+        std::vector<object_detection::DetectionResult> results;
+        DetectorPostProcess postProcess{outputTensor0, outputTensor1,
+                results, inputImgRows, inputImgCols};
         do {
             /* Strings for presentation/logging. */
             std::string str_inf{"Running inference... "};
+            /* Clear results accumulated on a previous iteration, since
+             * postProcess appends to this vector by reference. */
+            results.clear();
 
-            const uint8_t* curr_image = get_img_array(ctx.Get<uint32_t>("imgIndex"));
+            const uint8_t* currImage = get_img_array(ctx.Get<uint32_t>("imgIndex"));
 
-            /* Copy over the data and convert to grayscale */
-            auto* dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8);
+            auto* dstPtr = inputTensor->data.uint8;
             const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
                                 inputTensor->bytes : IMAGE_DATA_SIZE;
 
-            /* Convert to gray scale and populate input tensor. */
-            image::RgbToGrayscale(curr_image, dstPtr, copySz);
+            /* Run the pre-processing, inference and post-processing. */
+            if (!preProcess.DoPreProcess(currImage, copySz)) {
+                printf_err("Pre-processing failed.");
+                return false;
+            }
 
             /* Display image on the LCD. */
             hal_lcd_display_image(
-                (channelsImageDisplayed == 3) ? curr_image : dstPtr,
-                nCols, nRows, channelsImageDisplayed,
+                (channelsImageDisplayed == 3) ? currImage : dstPtr,
+                inputImgCols, inputImgRows, channelsImageDisplayed,
                 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
 
-            /* If the data is signed. */
-            if (model.IsDataSigned()) {
-                image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
-            }
-
             /* Display message on the LCD - inference running. */
             hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
-                                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
             /* Run inference over this image. */
             info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
                 get_filename(ctx.Get<uint32_t>("imgIndex")));
 
             if (!RunInference(model, profiler)) {
+                printf_err("Inference failed.");
+                return false;
+            }
+
+            if (!postProcess.DoPostProcess()) {
+                printf_err("Post-processing failed.");
                 return false;
             }
 
             /* Erase. */
             str_inf = std::string(str_inf.size(), ' ');
             hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
-                                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
-
-            /* Detector post-processing*/
-            std::vector<object_detection::DetectionResult> results;
-            TfLiteTensor* modelOutput0 = model.GetOutputTensor(0);
-            TfLiteTensor* modelOutput1 = model.GetOutputTensor(1);
-            postp.RunPostProcessing(
-                nRows,
-                nCols,
-                modelOutput0,
-                modelOutput1,
-                results);
+                                 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
             /* Draw boxes. */
             DrawDetectionBoxes(results, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
 
 #if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(modelOutput0);
-            arm::app::DumpTensor(modelOutput1);
+            DumpTensor(outputTensor0);
+            DumpTensor(outputTensor1);
 #endif /* VERIFY_TEST_OUTPUT */
 
             if (!PresentInferenceResult(results)) {
@@ -164,12 +162,12 @@
 
             IncrementAppCtxIfmIdx(ctx,"imgIndex");
 
-        } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
+        } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx);
 
         return true;
     }
 
-    static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results)
+    static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results)
     {
         hal_lcd_set_text_color(COLOR_GREEN);
 
@@ -186,7 +184,7 @@
         return true;
     }
 
-    static void DrawDetectionBoxes(const std::vector<arm::app::object_detection::DetectionResult>& results,
+    static void DrawDetectionBoxes(const std::vector<object_detection::DetectionResult>& results,
                                    uint32_t imgStartX,
                                    uint32_t imgStartY,
                                    uint32_t imgDownscaleFactor)
diff --git a/tests/use_case/object_detection/InferenceTestYoloFastest.cc b/tests/use_case/object_detection/InferenceTestYoloFastest.cc
index 8ef012d..2c035e7 100644
--- a/tests/use_case/object_detection/InferenceTestYoloFastest.cc
+++ b/tests/use_case/object_detection/InferenceTestYoloFastest.cc
@@ -94,13 +94,8 @@
         REQUIRE(tflite::GetTensorData<T>(output_arr[i]));
     }
 
-    arm::app::object_detection::DetectorPostprocessing postp;
-    postp.RunPostProcessing(
-        nRows,
-        nCols,
-        output_arr[0],
-        output_arr[1],
-        results);
+    arm::app::DetectorPostProcess postp{output_arr[0], output_arr[1], results, nRows, nCols};
+    REQUIRE(postp.DoPostProcess());
 
     std::vector<std::vector<arm::app::object_detection::DetectionResult>> expected_results;
     GetExpectedResults(expected_results);
diff --git a/tests/use_case/object_detection/ObjectDetectionUCTest.cc b/tests/use_case/object_detection/ObjectDetectionUCTest.cc
index a7e4f33..023b893 100644
--- a/tests/use_case/object_detection/ObjectDetectionUCTest.cc
+++ b/tests/use_case/object_detection/ObjectDetectionUCTest.cc
@@ -58,8 +58,6 @@
     caseContext.Set<arm::app::Profiler&>("profiler", profiler);
     caseContext.Set<arm::app::Model&>("model", model);
     caseContext.Set<uint32_t>("imgIndex", 0);
-    arm::app::object_detection::DetectorPostprocessing postp;
-    caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>("postprocess", postp);
 
     REQUIRE(arm::app::ObjectDetectionHandler(caseContext, 0, false));
 }
@@ -83,8 +81,6 @@
     caseContext.Set<arm::app::Profiler&>("profiler", profiler);
     caseContext.Set<arm::app::Model&>("model", model);
     caseContext.Set<uint32_t>("imgIndex", 0);
-    arm::app::object_detection::DetectorPostprocessing postp;
-    caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>("postprocess", postp);
 
     REQUIRE(arm::app::ObjectDetectionHandler(caseContext, 0, true));
 }