Network info

Add message for fetching meta data about built in network models.

Change-Id: I757094c20848d4cb018db68b0455297bb03be463
diff --git a/applications/message_handler/CMakeLists.txt b/applications/message_handler/CMakeLists.txt
index 72d930f..27a4815 100644
--- a/applications/message_handler/CMakeLists.txt
+++ b/applications/message_handler/CMakeLists.txt
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2020-2022 Arm Limited. All rights reserved.
+# Copyright (c) 2020-2022 Arm Limited.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp
index 9b36f84..8a36325 100644
--- a/applications/message_handler/main.cpp
+++ b/applications/message_handler/main.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -182,7 +182,7 @@
     outputQueue    = xQueueCreate(10, sizeof(OutputMessage));
 
     // Task for handling incoming messages from the remote host
-    ret = xTaskCreate(inputMessageTask, "inputMessageTask", 512, nullptr, 2, nullptr);
+    ret = xTaskCreate(inputMessageTask, "inputMessageTask", 1024, nullptr, 2, nullptr);
     if (ret != pdPASS) {
         printf("Failed to create 'inputMessageTask'\n");
         return ret;
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp
index e530712..4b77389 100644
--- a/applications/message_handler/message_handler.cpp
+++ b/applications/message_handler/message_handler.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -19,6 +19,7 @@
 #include "message_handler.hpp"
 
 #include "cmsis_compiler.h"
+#include "tensorflow/lite/schema/schema_generated.h"
 
 #ifdef ETHOSU
 #include <ethosu_driver.h>
@@ -31,6 +32,7 @@
 
 #include <cstring>
 #include <inttypes.h>
+#include <vector>
 
 #define XSTRINGIFY(src) #src
 #define STRINGIFY(src)  XSTRINGIFY(src)
@@ -79,6 +81,36 @@
  ****************************************************************************/
 
 namespace {
+
+template <typename T, typename U>
+class Array {
+public:
+    Array() = delete;
+    Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {}
+
+    auto size() const {
+        return _size;
+    }
+
+    auto capacity() const {
+        return _capacity;
+    }
+
+    void push_back(const T &data) {
+        _data[_size++] = data;
+    }
+
+private:
+    T *const _data;
+    U &_size;
+    const size_t _capacity{};
+};
+
+template <typename T, typename U>
+Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
+    return Array<T, U>{data, size, capacity};
+}
+
 bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) {
     data = reinterpret_cast<void *>(buffer.ptr);
     size = buffer.size;
@@ -134,6 +166,119 @@
         return true;
     }
 }
+
+bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) {
+    size = 1;
+
+    if (shape == nullptr) {
+        printf("Warning: nullptr shape size.\n");
+        return true;
+    }
+
+    if (shape->Length() == 0) {
+        printf("Warning: shape zero length.\n");
+        return true;
+    }
+
+    for (auto it = shape->begin(); it != shape->end(); ++it) {
+        size *= *it;
+    }
+
+    return false;
+}
+
+bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) {
+    switch (type) {
+    case tflite::TensorType::TensorType_UINT8:
+    case tflite::TensorType::TensorType_INT8:
+        size = 1;
+        break;
+    case tflite::TensorType::TensorType_INT16:
+        size = 2;
+        break;
+    case tflite::TensorType::TensorType_INT32:
+    case tflite::TensorType::TensorType_FLOAT32:
+        size = 4;
+        break;
+    default:
+        printf("Warning: Unsupported tensor type\n");
+        return true;
+    }
+
+    return false;
+}
+
+template <typename T>
+bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) {
+    if (subgraph == nullptr || tensorMap == nullptr) {
+        printf("Warning: nullptr subgraph or tensormap.\n");
+        return true;
+    }
+
+    if ((dims.capacity() - dims.size()) < tensorMap->size()) {
+        printf("Warning: tensormap size is larger than dimension capacity.\n");
+        return true;
+    }
+
+    for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
+        auto tensor = subgraph->tensors()->Get(*index);
+        size_t size;
+        size_t tensorSize;
+
+        bool failed = getShapeSize(tensor->shape(), size);
+        if (failed) {
+            return true;
+        }
+
+        failed = getTensorTypeSize(tensor->type(), tensorSize);
+        if (failed) {
+            return true;
+        }
+
+        size *= tensorSize;
+
+        if (size > 0) {
+            dims.push_back(size);
+        }
+    }
+
+    return false;
+}
+
+template <typename T, typename U, size_t S>
+bool parseModel(const ethosu_core_network_buffer &buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) {
+    void *data;
+    size_t size;
+    bool failed = getNetwork(buffer, data, size);
+    if (failed) {
+        return true;
+    }
+
+    // Create model handle
+    const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data));
+    if (model->subgraphs() == nullptr) {
+        printf("Warning: nullptr subgraph\n");
+        return true;
+    }
+
+    strncpy(description, model->description()->c_str(), sizeof(description));
+
+    // Get input dimensions for first subgraph
+    auto *subgraph = *model->subgraphs()->begin();
+    failed         = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims);
+    if (failed) {
+        return true;
+    }
+
+    // Get output dimensions for last subgraph
+    subgraph = *model->subgraphs()->rbegin();
+    failed   = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims);
+    if (failed) {
+        return true;
+    }
+
+    return false;
+}
 }; // namespace
 
 IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue,
@@ -281,6 +426,32 @@
         xQueueSend(inferenceQueue, &inference, portMAX_DELAY);
         break;
     }
+    case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: {
+        ethosu_core_network_info_req req;
+
+        if (!messageQueue.read(req)) {
+            queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "NetworkInfoReq. Failed to read payload");
+            break;
+        }
+
+        printf("Msg: NetworkInfoReq. user_arg=0x%" PRIx64 "\n", req.user_arg);
+
+        OutputMessage message(ETHOSU_CORE_MSG_NETWORK_INFO_RSP);
+        ethosu_core_network_info_rsp &rsp = message.data.networkInfo;
+        rsp.user_arg                      = req.user_arg;
+        rsp.ifm_count                     = 0;
+        rsp.ofm_count                     = 0;
+
+        bool failed = parseModel(req.network,
+                                 rsp.desc,
+                                 makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX),
+                                 makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX));
+
+        rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
+
+        xQueueSend(outputQueue, &message, portMAX_DELAY);
+        break;
+    }
     default: {
         char errMsg[128];
 
@@ -428,6 +599,9 @@
         case ETHOSU_CORE_MSG_ERR:
             sendErrorRsp(message.data.error);
             break;
+        case ETHOSU_CORE_MSG_NETWORK_INFO_RSP:
+            sendNetworkInfoRsp(message.data.networkInfo);
+            break;
         default:
             printf("Dropping unknown outcome of type %d\n", message.type);
             break;
@@ -476,6 +650,14 @@
     }
 }
 
+void OutgoingMessageHandler::sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo) {
+    if (!messageQueue.write(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, networkInfo)) {
+        printf("ERROR: Msg: Failed to write network info response. No mailbox message sent\n");
+    } else {
+        mailbox.sendMessage();
+    }
+}
+
 void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) {
     printf("ERROR: Msg: \"%s\"\n", error.msg);
 
diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp
index ee063de..90b1cd2 100644
--- a/applications/message_handler/message_handler.hpp
+++ b/applications/message_handler/message_handler.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -81,6 +81,7 @@
     EthosU::ethosu_core_msg_type type;
     union {
         EthosU::ethosu_core_inference_rsp inference;
+        EthosU::ethosu_core_network_info_rsp networkInfo;
         EthosU::ethosu_core_msg_err error;
         uint64_t userArg;
     } data;
@@ -99,6 +100,7 @@
     void sendVersionRsp();
     void sendCapabilitiesRsp(uint64_t userArg);
     void sendInferenceRsp(EthosU::ethosu_core_inference_rsp &inference);
+    void sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo);
     void readCapabilties(EthosU::ethosu_core_msg_capabilities_rsp &rsp);
 
     MessageQueue::QueueImpl messageQueue;
diff --git a/applications/message_handler/message_queue.cpp b/applications/message_handler/message_queue.cpp
index e896349..c3890fe 100644
--- a/applications/message_handler/message_queue.cpp
+++ b/applications/message_handler/message_queue.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
diff --git a/applications/message_handler/message_queue.hpp b/applications/message_handler/message_queue.hpp
index 7c59e75..4140c62 100644
--- a/applications/message_handler/message_queue.hpp
+++ b/applications/message_handler/message_queue.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
diff --git a/applications/message_handler/model_template.hpp b/applications/message_handler/model_template.hpp
index 06636b2..353d7d3 100644
--- a/applications/message_handler/model_template.hpp
+++ b/applications/message_handler/model_template.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * Copyright (c) 2022 Arm Limited.
  *
  * SPDX-License-Identifier: Apache-2.0
  *