MLECO-2873: Object detection use case follow-up

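Follow-up changes to the object detection use case:

* Remove the static global image buffer; size a local std::vector
  from the input tensor dimensions instead.
* Run detector post-processing through the DetectorPostprocessing
  object held in the application context, rather than through the
  free-standing RunPostProcessing helper.
* Present detection results with a local PresentInferenceResult
  helper.

The handler now expects the main loop to register the
post-processing object in the application context under the
"postprocess" key. A minimal sketch, assuming the default
DetectorPostprocessing constructor arguments (variable names are
illustrative):

    arm::app::object_detection::DetectorPostprocessing postProcess;
    caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>(
        "postprocess", postProcess);
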
Change-Id: Ic14e93a50fb7b3f3cfd9497bac1280794cc0fc15
Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com>
diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc
index 45df4f8..ce3ef06 100644
--- a/source/use_case/object_detection/src/UseCaseHandler.cc
+++ b/source/use_case/object_detection/src/UseCaseHandler.cc
@@ -18,19 +18,23 @@
 #include "InputFiles.hpp"
 #include "YoloFastestModel.hpp"
 #include "UseCaseCommonUtils.hpp"
-#include "DetectionUseCaseUtils.hpp"
 #include "DetectorPostProcessing.hpp"
 #include "hal.h"
 
 #include <inttypes.h>
 
-
-/* used for presentation, original images are read-only"*/
-static uint8_t g_image_buffer[INPUT_IMAGE_WIDTH*INPUT_IMAGE_HEIGHT*FORMAT_MULTIPLY_FACTOR] IFM_BUF_ATTRIBUTE = {}; 
-
 namespace arm {
 namespace app {
 
+    /**
+     * @brief           Presents inference results using the data presentation
+     *                  object.
+     * @param[in]       platform           Reference to the hal platform object.
+     * @param[in]       results            Vector of detection results to be displayed.
+     * @return          true if successful, false otherwise.
+     **/
+    static bool PresentInferenceResult(hal_platform& platform,
+                                       const std::vector<arm::app::object_detection::DetectionResult>& results);
 
     /* Object detection classification handler. */
     bool ObjectDetectionHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
@@ -48,7 +52,7 @@
         platform.data_psn->clear(COLOR_BLACK);
 
         auto& model = ctx.Get<Model&>("model");
-        
+
         /* If the request has a valid size, set the image index. */
         if (imgIndex < NUMBER_OF_FILES) {
             if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) {
@@ -76,9 +80,10 @@
 
         const uint32_t nCols = inputShape->data[arm::app::YoloFastestModel::ms_inputColsIdx];
         const uint32_t nRows = inputShape->data[arm::app::YoloFastestModel::ms_inputRowsIdx];
-        const uint32_t nPresentationChannels = FORMAT_MULTIPLY_FACTOR;
+        const uint32_t nPresentationChannels = channelsImageDisplayed;
 
-        std::vector<DetectionResult> results;
+        /* Get post-processing object. */
+        auto& postp = ctx.Get<object_detection::DetectorPostprocessing&>("postprocess");
 
         do {
             /* Strings for presentation/logging. */
@@ -86,19 +91,25 @@
 
             const uint8_t* curr_image = get_img_array(ctx.Get<uint32_t>("imgIndex"));
 
-            /* Copy over the data  and convert to gryscale */
-#if DISPLAY_RGB_IMAGE
-            memcpy(g_image_buffer,curr_image, INPUT_IMAGE_WIDTH*INPUT_IMAGE_HEIGHT*FORMAT_MULTIPLY_FACTOR);
-#else 
-            RgbToGrayscale(curr_image,g_image_buffer,INPUT_IMAGE_WIDTH,INPUT_IMAGE_HEIGHT);
-#endif /*DISPLAY_RGB_IMAGE*/
-            
-            RgbToGrayscale(curr_image,inputTensor->data.uint8,INPUT_IMAGE_WIDTH,INPUT_IMAGE_HEIGHT);
+            /* Copy over the data and convert to grayscale. */
+            auto* dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8);
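+            /* Clamp the copy length to the smaller of the input tensor size and the source image size. */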
+            const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
+                                inputTensor->bytes : IMAGE_DATA_SIZE;
 
+            /* Copy of the image used for presentation; original images are read-only. */
+            std::vector<uint8_t> imageBuffer(nCols * nRows * channelsImageDisplayed);
+            if (nPresentationChannels == 3) {
+                memcpy(imageBuffer.data(), curr_image, nCols * nRows * channelsImageDisplayed);
+            } else {
+                image::RgbToGrayscale(curr_image, imageBuffer.data(), nCols * nRows);
+            }
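+            /* The input tensor always takes the grayscale version of the image. */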
+            image::RgbToGrayscale(curr_image, dstPtr, copySz);
 
             /* Display this image on the LCD. */
             platform.data_psn->present_data_image(
-                g_image_buffer,
+                imageBuffer.data(),
                 nCols, nRows, nPresentationChannels,
                 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
 
@@ -125,27 +134,29 @@
                                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
 
-            /* Detector post-processing*/
+            /* Detector post-processing. */
-            TfLiteTensor* output_arr[2] = {nullptr,nullptr};
-            output_arr[0] = model.GetOutputTensor(0);
-            output_arr[1] = model.GetOutputTensor(1);
-            RunPostProcessing(g_image_buffer,output_arr,results);
+            std::vector<object_detection::DetectionResult> results;
+            TfLiteTensor* modelOutput0 = model.GetOutputTensor(0);
+            TfLiteTensor* modelOutput1 = model.GetOutputTensor(1);
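+            /* Run post-processing; this also draws the detection boxes onto the image copy. */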
+            postp.RunPostProcessing(
+                imageBuffer.data(),
+                nRows,
+                nCols,
+                modelOutput0,
+                modelOutput1,
+                results);
 
             platform.data_psn->present_data_image(
-                g_image_buffer,
+                imageBuffer.data(),
                 nCols, nRows, nPresentationChannels,
                 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
 
-            /*Detector post-processing*/
-
-
-            /* Add results to context for access outside handler. */
-            ctx.Set<std::vector<DetectionResult>>("results", results);
-
-#if VERIFY_TEST_OUTPUT
-            arm::app::DumpTensor(outputTensor);
-#endif /* VERIFY_TEST_OUTPUT */
+#if VERIFY_TEST_OUTPUT
+            arm::app::DumpTensor(modelOutput0);
+            arm::app::DumpTensor(modelOutput1);
+#endif /* VERIFY_TEST_OUTPUT */
 
-            if (!image::PresentInferenceResult(platform, results)) {
+            if (!PresentInferenceResult(platform, results)) {
                 return false;
             }
 
@@ -158,5 +167,23 @@
         return true;
     }
 
+
+    static bool PresentInferenceResult(hal_platform& platform,
+                                       const std::vector<arm::app::object_detection::DetectionResult>& results)
+    {
+        platform.data_psn->set_text_color(COLOR_GREEN);
+
+        info("Final results:\n");
+        info("Total number of inferences: 1\n");
+
+        for (uint32_t i = 0; i < results.size(); ++i) {
+            info("%" PRIu32 ") (%f) -> %s {x=%d,y=%d,w=%d,h=%d}\n", i,
+                results[i].m_normalisedVal, "Detection box:",
+                results[i].m_x0, results[i].m_y0, results[i].m_w, results[i].m_h );
+        }
+
+        return true;
+    }
+
 } /* namespace app */
 } /* namespace arm */