MLECO-3079: Implement image classification API

All ML-related work for image classification separated out and accessed via a new Runner.
Further work to improve profiling integration to be done in a follow-up ticket: MLECO-3154.
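
Intended call sequence for the new Runner (illustrative sketch only; it mirrors
the img_class handler changes below and assumes an initialised model, classifier
and labels):

    ImgClassPreProcess preprocess(&model);
    std::vector<ClassificationResult> results;
    ImgClassPostProcess postprocess(classifier, &model, labels, results);

    UseCaseRunner runner(&preprocess, &postprocess, &model);
    if (runner.PreProcess(imgSrc, imgSz) &&
        runner.RunInference() &&
        runner.PostProcess()) {
        /* 'results' now holds the classifications for this image. */
    }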

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I0fe0550c932241a2d335a560ecb7abc329c934e9
diff --git a/source/application/main/include/BaseProcessing.hpp b/source/application/main/include/BaseProcessing.hpp
new file mode 100644
index 0000000..c1c3255
--- /dev/null
+++ b/source/application/main/include/BaseProcessing.hpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef BASE_PROCESSING_HPP
+#define BASE_PROCESSING_HPP
+
+#include "Model.hpp"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Base class exposing pre-processing API.
+     *          Use cases should provide their own PreProcessing class that inherits from this one.
+     *          All steps required to take raw input data and populate tensors ready for inference
+     *          should be handled.
+     */
+    class BasePreProcess {
+
+    public:
+        virtual ~BasePreProcess() = default;
+
+        /**
+         * @brief       Should perform pre-processing of 'raw' input data and load it into
+         *              TFLite Micro input tensors ready for inference.
+         * @param[in]   input      Pointer to the data that pre-processing will work on.
+         * @param[in]   inputSize  Size of the input data.
+         * @return      true if successful, false otherwise.
+         **/
+        virtual bool DoPreProcess(const void* input, size_t inputSize) = 0;
+
+    protected:
+        Model* m_model = nullptr;
+    };
+
+    /**
+     * @brief   Base class exposing post-processing API.
+     *          Use cases should provide their own PostProcessing class that inherits from this one.
+     *          All steps required to take inference output and populate results vectors should be handled.
+     */
+    class BasePostProcess {
+
+    public:
+        virtual ~BasePostProcess() = default;
+
+        /**
+         * @brief       Should perform post-processing of the result of inference and
+         *              populate result data for any later use.
+         * @return      true if successful, false otherwise.
+         **/
+        virtual bool DoPostProcess() = 0;
+
+    protected:
+        Model* m_model = nullptr;
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* BASE_PROCESSING_HPP */
\ No newline at end of file
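
A use case plugs into this API by subclassing both bases; a minimal sketch of the
pre-processing side (hypothetical MyPreProcess name, not part of this change):

    class MyPreProcess : public arm::app::BasePreProcess {
    public:
        explicit MyPreProcess(arm::app::Model* model) { this->m_model = model; }

        bool DoPreProcess(const void* input, size_t inputSize) override {
            /* Copy up to inputSize bytes of 'input' into
             * this->m_model->GetInputTensor(0), then return true. */
            return true;
        }
    };

The ImgClassPreProcess/ImgClassPostProcess classes added below follow this pattern.
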
diff --git a/source/application/main/include/UseCaseCommonUtils.hpp b/source/application/main/include/UseCaseCommonUtils.hpp
index 9b6d550..f79f6ed 100644
--- a/source/application/main/include/UseCaseCommonUtils.hpp
+++ b/source/application/main/include/UseCaseCommonUtils.hpp
@@ -24,6 +24,7 @@
 #include "UseCaseHandler.hpp"       /* Handlers for different user options. */
 #include "Classifier.hpp"           /* Classifier. */
 #include "InputFiles.hpp"
+#include "BaseProcessing.hpp"
 
 
 void DisplayCommonMenu();
@@ -107,6 +108,67 @@
      **/
     bool ListFilesHandler(ApplicationContext& ctx);
 
+    /**
+     * @brief   Use case runner class that will handle calling pre-processing,
+     *          inference and post-processing.
+     *          After constructing an instance of this class the user can call
+     *          PreProcess(), RunInference() and PostProcess() to perform inference.
+     */
+    class UseCaseRunner {
+
+    private:
+        BasePreProcess* m_preProcess;
+        BasePostProcess* m_postProcess;
+        Model* m_model;
+
+    public:
+        explicit UseCaseRunner(BasePreProcess* preprocess, BasePostProcess* postprocess, Model* model)
+        : m_preProcess{preprocess},
+          m_postProcess{postprocess},
+          m_model{model}
+          {}
+
+        /**
+         * @brief       Runs pre-processing as defined by PreProcess object within the runner.
+         *              Templated for the input data type.
+         * @param[in]   inputData    Pointer to the data that inference will be performed on.
+         * @param[in]   inputSize    Size of the input data that inference will be performed on.
+         * @return      true if successful, false otherwise.
+         **/
+        template<typename T>
+        bool PreProcess(T* inputData, size_t inputSize) {
+            if (!this->m_preProcess->DoPreProcess(inputData, inputSize)) {
+                printf_err("Pre-processing failed.");
+                return false;
+            }
+            return true;
+        }
+
+        /**
+         * @brief       Runs inference with the Model object within the runner.
+         * @return      true if successful, false otherwise.
+         **/
+        bool RunInference() {
+            if (!this->m_model->RunInference()) {
+                printf_err("Inference failed.");
+                return false;
+            }
+            return true;
+        }
+
+        /**
+         * @brief       Runs post-processing as defined by PostProcess object within the runner.
+         * @return      true if successful, false otherwise.
+         **/
+        bool PostProcess() {
+            if (!this->m_postProcess->DoPostProcess()) {
+                printf_err("Post-processing failed.");
+                return false;
+            }
+            return true;
+        }
+    };
+
 } /* namespace app */
 } /* namespace arm */
 
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp
new file mode 100644
index 0000000..5a59b5f
--- /dev/null
+++ b/source/use_case/img_class/include/ImgClassProcessing.hpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef IMG_CLASS_PROCESSING_HPP
+#define IMG_CLASS_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "Classifier.hpp"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief   Pre-processing class for Image Classification use case.
+     *          Implements methods declared by BasePreProcess and anything else needed
+     *          to populate input tensors ready for inference.
+     */
+    class ImgClassPreProcess : public BasePreProcess {
+
+    public:
+        explicit ImgClassPreProcess(Model* model);
+
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+    };
+
+    /**
+     * @brief   Post-processing class for Image Classification use case.
+     *          Implements methods declared by BasePostProcess and anything else needed
+     *          to populate result vector.
+     */
+    class ImgClassPostProcess : public BasePostProcess {
+
+    private:
+        Classifier& m_imgClassifier;
+        const std::vector<std::string>& m_labels;
+        std::vector<ClassificationResult>& m_results;
+
+    public:
+        ImgClassPostProcess(Classifier& classifier, Model* model,
+                            const std::vector<std::string>& labels,
+                            std::vector<ClassificationResult>& results);
+
+        bool DoPostProcess() override;
+    };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* IMG_CLASS_PROCESSING_HPP */
\ No newline at end of file
diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc
new file mode 100644
index 0000000..e33e3c1
--- /dev/null
+++ b/source/use_case/img_class/src/ImgClassProcessing.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ImgClassProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    ImgClassPreProcess::ImgClassPreProcess(Model* model)
+    {
+        this->m_model = model;
+    }
+
+    bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize)
+    {
+        if (data == nullptr) {
+            printf_err("Data pointer is null\n");
+            return false;
+        }
+
+        auto input = static_cast<const uint8_t*>(data);
+        TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0);
+
+        memcpy(inputTensor->data.data, input, inputSize);
+        debug("Input tensor populated \n");
+
+        if (this->m_model->IsDataSigned()) {
+            image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
+        }
+
+        return true;
+    }
+
+    ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model,
+                                             const std::vector<std::string>& labels,
+                                             std::vector<ClassificationResult>& results)
+            :m_imgClassifier{classifier},
+             m_labels{labels},
+             m_results{results}
+    {
+        this->m_model = model;
+    }
+
+    bool ImgClassPostProcess::DoPostProcess()
+    {
+        return this->m_imgClassifier.GetClassificationResults(
+                this->m_model->GetOutputTensor(0), this->m_results,
+                this->m_labels, 5, false);
+    }
+
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file
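
Once DoPostProcess() has run, the results vector supplied at construction holds
the top 5 classifications; a caller can consume it along these lines (illustrative,
using ClassificationResult fields as used elsewhere in the codebase):

    for (const auto& result : results) {
        info("%s => %f\n", result.m_label.c_str(), result.m_normalisedVal);
    }
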
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index 9061282..98e2b59 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -23,6 +23,7 @@
 #include "UseCaseCommonUtils.hpp"
 #include "hal.h"
 #include "log_macros.h"
+#include "ImgClassProcessing.hpp"
 
 #include <cinttypes>
 
@@ -31,20 +32,11 @@
 namespace arm {
 namespace app {
 
-    /**
-    * @brief           Helper function to load the current image into the input
-    *                  tensor.
-    * @param[in]       imIdx         Image index (from the pool of images available
-    *                                to the application).
-    * @param[out]      inputTensor   Pointer to the input tensor to be populated.
-    * @return          true if tensor is loaded, false otherwise.
-    **/
-    static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor);
-
-    /* Image inference classification handler. */
+    /* Image classification inference handler. */
     bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
     {
         auto& profiler = ctx.Get<Profiler&>("profiler");
+        auto& model = ctx.Get<Model&>("model");
 
         constexpr uint32_t dataPsnImgDownscaleFactor = 2;
         constexpr uint32_t dataPsnImgStartX = 10;
@@ -53,8 +45,6 @@
         constexpr uint32_t dataPsnTxtInfStartX = 150;
         constexpr uint32_t dataPsnTxtInfStartY = 40;
 
-        auto& model = ctx.Get<Model&>("model");
-
         /* If the request has a valid size, set the image index. */
         if (imgIndex < NUMBER_OF_FILES) {
             if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) {
@@ -66,11 +56,9 @@
             return false;
         }
 
-        auto curImIdx = ctx.Get<uint32_t>("imgIndex");
-
-        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+        auto initialImIdx = ctx.Get<uint32_t>("imgIndex");
+
         TfLiteTensor* inputTensor = model.GetInputTensor(0);
-
         if (!inputTensor->dims) {
             printf_err("Invalid input tensor dims\n");
             return false;
@@ -79,13 +67,20 @@
             return false;
         }
 
+        /* Get input shape for displaying the image. */
         TfLiteIntArray* inputShape = model.GetInputShape(0);
-
         const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx];
         const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx];
         const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
 
+        /* Set up pre and post-processing. */
+        ImgClassPreProcess preprocess = ImgClassPreProcess(&model);
+
         std::vector<ClassificationResult> results;
+        ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model,
+                ctx.Get<std::vector<std::string>&>("labels"), results);
+
+        UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
 
         do {
             hal_lcd_clear(COLOR_BLACK);
@@ -93,29 +88,42 @@
             /* Strings for presentation/logging. */
             std::string str_inf{"Running inference... "};
 
-            /* Copy over the data. */
-            LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor);
+            const uint8_t* imgSrc = get_img_array(ctx.Get<uint32_t>("imgIndex"));
+            if (nullptr == imgSrc) {
+                printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get<uint32_t>("imgIndex"),
+                           NUMBER_OF_FILES - 1);
+                return false;
+            }
 
             /* Display this image on the LCD. */
             hal_lcd_display_image(
-                static_cast<uint8_t *>(inputTensor->data.data),
+                imgSrc,
                 nCols, nRows, nChannels,
                 dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
 
-            /* If the data is signed. */
-            if (model.IsDataSigned()) {
-                image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
-            }
-
             /* Display message on the LCD - inference running. */
             hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-            /* Run inference over this image. */
+            /* Select the image to run inference with. */
             info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
                 get_filename(ctx.Get<uint32_t>("imgIndex")));
 
-            if (!RunInference(model, profiler)) {
+            const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ?
+                                  inputTensor->bytes : IMAGE_DATA_SIZE;
+
+            /* Run the pre-processing, inference and post-processing. */
+            if (!runner.PreProcess(imgSrc, imgSz)) {
+                return false;
+            }
+
+            profiler.StartProfiling("Inference");
+            if (!runner.RunInference()) {
+                return false;
+            }
+            profiler.StopProfiling();
+
+            if (!runner.PostProcess()) {
                 return false;
             }
 
@@ -124,15 +132,11 @@
             hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                                     dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
 
-            auto& classifier = ctx.Get<ImgClassClassifier&>("classifier");
-            classifier.GetClassificationResults(outputTensor, results,
-                                                ctx.Get<std::vector <std::string>&>("labels"),
-                                                5, false);
-
             /* Add results to context for access outside handler. */
             ctx.Set<std::vector<ClassificationResult>>("results", results);
 
 #if VERIFY_TEST_OUTPUT
+            TfLiteTensor* outputTensor = model.GetOutputTensor(0);
             arm::app::DumpTensor(outputTensor);
 #endif /* VERIFY_TEST_OUTPUT */
 
@@ -144,27 +148,10 @@
 
             IncrementAppCtxIfmIdx(ctx,"imgIndex");
 
-        } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
+        } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImIdx);
 
         return true;
     }
 
-    static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor)
-    {
-        const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
-                              inputTensor->bytes : IMAGE_DATA_SIZE;
-        const uint8_t* imgSrc = get_img_array(imIdx);
-        if (nullptr == imgSrc) {
-            printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", imIdx,
-                       NUMBER_OF_FILES - 1);
-            return false;
-        }
-
-        memcpy(inputTensor->data.data, imgSrc, copySz);
-        debug("Image %" PRIu32 " loaded\n", imIdx);
-        return true;
-    }
-
-
 } /* namespace app */
 } /* namespace arm */