Add support for cancelling enqueued inferences

If an enqueued inference is cancelled it is simply removed from the
queue.  Both a cancel inference response with status done and a
inference response with status cancelled are sent back.

Also override the new operators to call in the FreeRTOS allocator
instead of malloc/free.

Change-Id: I243e678aa6b996084c9b9be1d1b00ffcecc75bc9
diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp
index dde5dc5..b527840 100644
--- a/applications/message_handler/main.cpp
+++ b/applications/message_handler/main.cpp
@@ -71,10 +71,6 @@
 
 namespace {
 
-SemaphoreHandle_t messageNotify;
-QueueHandle_t inferenceInputQueue;
-QueueHandle_t inferenceOutputQueue;
-
 // Mailbox driver
 #ifdef MHU_V2
 Mailbox::MHUv2 mailbox(MHU_TX_BASE_ADDRESS, MHU_RX_BASE_ADDRESS); // txBase, rxBase
@@ -87,6 +83,26 @@
 } // namespace
 
 /****************************************************************************
+ * Override new operators to call in FreeRTOS allocator
+ ****************************************************************************/
+
+void *operator new(size_t size) {
+    return pvPortMalloc(size);
+}
+
+void *operator new[](size_t size) {
+    return pvPortMalloc(size);
+}
+
+void operator delete(void *ptr) {
+    vPortFree(ptr);
+}
+
+void operator delete[](void *ptr) {
+    vPortFree(ptr);
+}
+
+/****************************************************************************
  * Mutex & Semaphore
  ****************************************************************************/
 
@@ -150,6 +166,24 @@
  * Application
  ****************************************************************************/
 
+struct TaskParams {
+    TaskParams() :
+        messageNotify(xSemaphoreCreateBinary()),
+        inferenceInputQueue(std::make_shared<Queue<ethosu_core_inference_req>>()),
+        inferenceOutputQueue(xQueueCreate(10, sizeof(ethosu_core_inference_rsp))) {}
+
+    SemaphoreHandle_t messageNotify;
+    // Used to pass inference requests to the inference runner task
+    std::shared_ptr<Queue<ethosu_core_inference_req>> inferenceInputQueue;
+    // Queue for message responses to the remote host
+    QueueHandle_t inferenceOutputQueue;
+};
+
+struct InferenceTaskParams {
+    TaskParams *taskParams;
+    uint8_t *arena;
+};
+
 namespace {
 
 #ifdef MHU_IRQ
@@ -160,21 +194,27 @@
 
 void inferenceTask(void *pvParameters) {
     printf("Starting inference task\n");
+    InferenceTaskParams *params = reinterpret_cast<InferenceTaskParams *>(pvParameters);
 
-    uint8_t *arena = reinterpret_cast<uint8_t *>(pvParameters);
-    InferenceHandler process(arena, arenaSize, inferenceInputQueue, inferenceOutputQueue, messageNotify);
+    InferenceHandler process(params->arena,
+                             arenaSize,
+                             params->taskParams->inferenceInputQueue,
+                             params->taskParams->inferenceOutputQueue,
+                             params->taskParams->messageNotify);
+
     process.run();
 }
 
-void messageTask(void *) {
-    printf("Starting input message task\n");
+void messageTask(void *pvParameters) {
+    printf("Starting message task\n");
+    TaskParams *params = reinterpret_cast<TaskParams *>(pvParameters);
 
     IncomingMessageHandler process(*inputMessageQueue.toQueue(),
                                    *outputMessageQueue.toQueue(),
                                    mailbox,
-                                   inferenceInputQueue,
-                                   inferenceOutputQueue,
-                                   messageNotify);
+                                   params->inferenceInputQueue,
+                                   params->inferenceOutputQueue,
+                                   params->messageNotify);
 
 #ifdef MHU_IRQ
     // Register mailbox interrupt handler
@@ -196,21 +236,22 @@
         return 1;
     }
 
-    // Create message queues for inter process communication
-    messageNotify        = xSemaphoreCreateBinary();
-    inferenceInputQueue  = xQueueCreate(10, sizeof(ethosu_core_inference_req));
-    inferenceOutputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_rsp));
+    TaskParams taskParams;
 
-    // Task for handling incoming messages from the remote host
-    ret = xTaskCreate(messageTask, "messageTask", 1024, nullptr, 2, nullptr);
+    // Task for handling incoming /outgoing messages from the remote host
+    ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr);
     if (ret != pdPASS) {
         printf("Failed to create 'messageTask'\n");
         return ret;
     }
 
+    InferenceTaskParams infParams[NUM_PARALLEL_TASKS];
+
     // One inference task for each NPU
     for (size_t n = 0; n < NUM_PARALLEL_TASKS; n++) {
-        ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &tensorArena[n], 3, nullptr);
+        infParams[n].taskParams = &taskParams;
+        infParams[n].arena      = reinterpret_cast<uint8_t *>(&tensorArena[n]);
+        ret                     = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &infParams[n], 3, nullptr);
         if (ret != pdPASS) {
             printf("Failed to create 'inferenceTask%d'\n", n);
             return ret;
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp
index f06d056..f109dc8 100644
--- a/applications/message_handler/message_handler.cpp
+++ b/applications/message_handler/message_handler.cpp
@@ -138,12 +138,13 @@
 
 }; // namespace
 
-IncomingMessageHandler::IncomingMessageHandler(EthosU::ethosu_core_queue &_inputMessageQueue,
-                                               EthosU::ethosu_core_queue &_outputMessageQueue,
-                                               Mailbox::Mailbox &_mailbox,
-                                               QueueHandle_t _inferenceInputQueue,
-                                               QueueHandle_t _inferenceOutputQueue,
-                                               SemaphoreHandle_t _messageNotify) :
+IncomingMessageHandler::IncomingMessageHandler(
+    EthosU::ethosu_core_queue &_inputMessageQueue,
+    EthosU::ethosu_core_queue &_outputMessageQueue,
+    Mailbox::Mailbox &_mailbox,
+    std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> _inferenceInputQueue,
+    QueueHandle_t _inferenceOutputQueue,
+    SemaphoreHandle_t _messageNotify) :
     inputMessageQueue(_inputMessageQueue),
     outputMessageQueue(_outputMessageQueue), mailbox(_mailbox), inferenceInputQueue(_inferenceInputQueue),
     inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify) {
@@ -166,7 +167,7 @@
         return;
     }
     IncomingMessageHandler *_this = reinterpret_cast<IncomingMessageHandler *>(userArg);
-    xSemaphoreGive(_this->messageNotify);
+    xSemaphoreGiveFromISR(_this->messageNotify, nullptr);
 }
 
 void IncomingMessageHandler::sendErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) {
@@ -287,7 +288,7 @@
         }
         printf("]\n");
 
-        if (pdTRUE != xQueueSend(inferenceInputQueue, &req, 0)) {
+        if (!inferenceInputQueue->push(req)) {
             printf("Msg: Inference queue full. Rejecting inference user_arg=0x%" PRIx64 "\n", req.user_arg);
             sendFailedInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_REJECTED);
         }
@@ -303,7 +304,15 @@
                req.user_arg,
                req.inference_handle);
 
-        sendCancelInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_ERROR);
+        bool found =
+            inferenceInputQueue->erase([req](auto &inf_req) { return inf_req.user_arg == req.inference_handle; });
+
+        // NOTE: send an inference response with status ABORTED if the inference has been droped from the queue
+        if (found) {
+            sendFailedInferenceRsp(req.inference_handle, ETHOSU_CORE_STATUS_ABORTED);
+        }
+
+        sendCancelInferenceRsp(req.user_arg, found ? ETHOSU_CORE_STATUS_OK : ETHOSU_CORE_STATUS_ERROR);
         break;
     }
     case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: {
@@ -450,22 +459,20 @@
  * InferenceHandler
  ****************************************************************************/
 
-InferenceHandler::InferenceHandler(uint8_t *_tensorArena,
-                                   size_t _arenaSize,
-                                   QueueHandle_t _inferenceInputQueue,
+InferenceHandler::InferenceHandler(uint8_t *tensorArena,
+                                   size_t arenaSize,
+                                   std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> _inferenceInputQueue,
                                    QueueHandle_t _inferenceOutputQueue,
                                    SemaphoreHandle_t _messageNotify) :
     inferenceInputQueue(_inferenceInputQueue),
-    inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(_tensorArena, _arenaSize) {}
+    inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(tensorArena, arenaSize) {}
 
 void InferenceHandler::run() {
     ethosu_core_inference_req req;
     ethosu_core_inference_rsp rsp;
 
     while (true) {
-        if (pdTRUE != xQueueReceive(inferenceInputQueue, &req, portMAX_DELAY)) {
-            continue;
-        }
+        inferenceInputQueue->pop(req);
 
         runInference(req, rsp);
 
diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp
index fa79205..dd05059 100644
--- a/applications/message_handler/message_handler.hpp
+++ b/applications/message_handler/message_handler.hpp
@@ -24,6 +24,7 @@
 #include "semphr.h"
 
 #include "message_queue.hpp"
+#include <ethosu_core_interface.h>
 #if defined(ETHOSU)
 #include <ethosu_driver.h>
 #endif
@@ -31,20 +32,97 @@
 #include <inference_process.hpp>
 #include <mailbox.hpp>
 
+#include <algorithm>
 #include <cstddef>
 #include <cstdio>
+#include <list>
 #include <vector>
 
 namespace MessageHandler {
 
+template <typename T, size_t capacity = 10>
+class Queue {
+public:
+    using Predicate = std::function<bool(const T &data)>;
+
+    Queue() {
+        mutex = xSemaphoreCreateMutex();
+        size  = xSemaphoreCreateCounting(capacity, 0u);
+
+        if (mutex == nullptr || size == nullptr) {
+            printf("Error: failed to allocate memory for inference queue\n");
+        }
+    }
+
+    ~Queue() {
+        vSemaphoreDelete(mutex);
+        vSemaphoreDelete(size);
+    }
+
+    bool push(const T &data) {
+        xSemaphoreTake(mutex, portMAX_DELAY);
+        if (list.size() >= capacity) {
+            xSemaphoreGive(mutex);
+            return false;
+        }
+
+        list.push_back(data);
+        xSemaphoreGive(mutex);
+
+        // increase number of available inferences to pop
+        xSemaphoreGive(size);
+        return true;
+    }
+
+    void pop(T &data) {
+        // decrease the number of available inferences to pop
+        xSemaphoreTake(size, portMAX_DELAY);
+
+        xSemaphoreTake(mutex, portMAX_DELAY);
+        data = list.front();
+        list.pop_front();
+        xSemaphoreGive(mutex);
+    }
+
+    bool erase(Predicate pred) {
+        // let's optimistically assume we are removing an inference, so decrease pop
+        if (pdFALSE == xSemaphoreTake(size, 0)) {
+            // if there are no inferences return immediately
+            return false;
+        }
+
+        xSemaphoreTake(mutex, portMAX_DELAY);
+        auto found  = std::find_if(list.begin(), list.end(), pred);
+        bool erased = found != list.end();
+        if (erased) {
+            list.erase(found);
+        }
+        xSemaphoreGive(mutex);
+
+        if (!erased) {
+            // no inference erased, so let's put the size count back
+            xSemaphoreGive(size);
+        }
+
+        return erased;
+    }
+
+private:
+    std::list<T> list;
+
+    SemaphoreHandle_t mutex;
+    SemaphoreHandle_t size;
+};
+
 class IncomingMessageHandler {
 public:
     IncomingMessageHandler(EthosU::ethosu_core_queue &inputMessageQueue,
                            EthosU::ethosu_core_queue &outputMessageQueue,
                            Mailbox::Mailbox &mailbox,
-                           QueueHandle_t inferenceInputQueue,
+                           std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
                            QueueHandle_t inferenceOutputQueue,
                            SemaphoreHandle_t messageNotify);
+
     void run();
 
 private:
@@ -66,7 +144,7 @@
     MessageQueue::QueueImpl outputMessageQueue;
     Mailbox::Mailbox &mailbox;
     InferenceProcess::InferenceParser parser;
-    QueueHandle_t inferenceInputQueue;
+    std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
     QueueHandle_t inferenceOutputQueue;
     SemaphoreHandle_t messageNotify;
     EthosU::ethosu_core_msg_capabilities_rsp capabilities;
@@ -76,7 +154,7 @@
 public:
     InferenceHandler(uint8_t *tensorArena,
                      size_t arenaSize,
-                     QueueHandle_t inferenceInputQueue,
+                     std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
                      QueueHandle_t inferenceOutputQueue,
                      SemaphoreHandle_t messageNotify);
 
@@ -90,7 +168,7 @@
     friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg);
     friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg);
 #endif
-    QueueHandle_t inferenceInputQueue;
+    std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
     QueueHandle_t inferenceOutputQueue;
     SemaphoreHandle_t messageNotify;
     InferenceProcess::InferenceProcess inference;