MLECO-3079: Implement image classification API

All ML-related work for image classification has been separated out and is now accessed via the new UseCaseRunner.
Further work to improve profiling integration will be done in a follow-up ticket: MLECO-3154

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I0fe0550c932241a2d335a560ecb7abc329c934e9
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index 9061282..98e2b59 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -23,6 +23,7 @@
 #include "UseCaseCommonUtils.hpp"
 #include "hal.h"
 #include "log_macros.h"
+#include "ImgClassProcessing.hpp"
 
 #include <cinttypes>
 
@@ -31,20 +32,12 @@
 namespace arm {
 namespace app {
 
-    /**
-    * @brief           Helper function to load the current image into the input
-    *                  tensor.
-    * @param[in]       imIdx         Image index (from the pool of images available
-    *                                to the application).
-    * @param[out]      inputTensor   Pointer to the input tensor to be populated.
-    * @return          true if tensor is loaded, false otherwise.
-    **/
-    static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor);
-
-    /* Image inference classification handler. */
+    /* Image classification inference handler. */
     bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
     {
         auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto& model = ctx.Get<Model&>("model");
+        auto initialImIdx = ctx.Get<uint32_t>("imgIndex");
 
         constexpr uint32_t dataPsnImgDownscaleFactor = 2;
         constexpr uint32_t dataPsnImgStartX = 10;
@@ -53,8 +46,6 @@
         constexpr uint32_t dataPsnTxtInfStartX = 150;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
-        auto& model = ctx.Get<Model&>("model");
-
         /* If the request has a valid size, set the image index. */
         if (imgIndex < NUMBER_OF_FILES) {
             if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) {
@@ -66,11 +57,7 @@
             return false;
         }
 
-        auto curImIdx = ctx.Get<uint32_t>("imgIndex");
-
-        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
         TfLiteTensor* inputTensor = model.GetInputTensor(0);
-
         if (!inputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
             return false;
@@ -79,13 +66,20 @@
             return false;
         }
 
+        /* Get input shape for displaying the image. */
         TfLiteIntArray* inputShape = model.GetInputShape(0);
-
         const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx];
         const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx];
         const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
 
+        /* Set up pre and post-processing. */
+        ImgClassPreProcess preprocess = ImgClassPreProcess(&model);
+
         std::vector<ClassificationResult> results;
+        ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model,
+                ctx.Get<std::vector<std::string>&>("labels"), results);
+
+        UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
 
         do {
             hal_lcd_clear(COLOR_BLACK);
@@ -93,29 +87,42 @@
             /* Strings for presentation/logging. */
             std::string str_inf{"Running inference... "};
 
-            /* Copy over the data. */
-            LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
+            const uint8_t* imgSrc = get_img_array(ctx.Get<uint32_t>("imgIndex"));
+            if (nullptr == imgSrc) {
+                printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get<uint32_t>("imgIndex"),
+                           NUMBER_OF_FILES - 1);
+                return false;
+            }
 
             /* Display this image on the LCD. */
             hal_lcd_display_image(
-                static_cast<uint8_t *>(inputTensor->data.data),
+                imgSrc,
                 nCols, nRows, nChannels,
                 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
 
-            /* If the data is signed. */
-            if (model.IsDataSigned()) {
-                image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
-            }
-
             /* Display message on the LCD - inference running. */
             hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-            /* Run inference over this image. */
+            /* Select the image to run inference with. */
             info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
                 get_filename(ctx.Get<uint32_t>("imgIndex")));
 
-            if (!RunInference(model, profiler)) {
+            const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ?
+                                  inputTensor->bytes : IMAGE_DATA_SIZE;
+
+            /* Run the pre-processing, inference and post-processing. */
+            if (!runner.PreProcess(imgSrc, imgSz)) {
+                return false;
+            }
+
+            profiler.StartProfiling("Inference");
+            if (!runner.RunInference()) {
+                return false;
+            }
+            profiler.StopProfiling();
+
+            if (!runner.PostProcess()) {
                 return false;
             }
 
@@ -124,15 +131,11 @@
             hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-            auto& classifier = ctx.Get<ImgClassClassifier&>("classifier");
-            classifier.GetClassificationResults(outputTensor, results,
-                                                ctx.Get<std::vector <std::string>&>("labels"),
-                                                5, false);
-
             /* Add results to context for access outside handler. */
             ctx.Set<std::vector<ClassificationResult>>("results", results);
 
 #if VERIFY_TEST_OUTPUT
+            TfLiteTensor* outputTensor = model.GetOutputTensor(0);
             arm::app::DumpTensor(outputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
 
@@ -144,27 +147,10 @@
 
             IncrementAppCtxIfmIdx(ctx,"imgIndex");
 
-        } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
+        } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImIdx);
 
         return true;
     }
 
-    static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor)
-    {
-        const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
-                              inputTensor->bytes : IMAGE_DATA_SIZE;
-        const uint8_t* imgSrc = get_img_array(imIdx);
-        if (nullptr == imgSrc) {
-            printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", imIdx,
-                       NUMBER_OF_FILES - 1);
-            return false;
-        }
-
-        memcpy(inputTensor->data.data, imgSrc, copySz);
-        debug("Image %" PRIu32 " loaded\n", imIdx);
-        return true;
-    }
-
-
 } /* namespace app */
 } /* namespace arm */