GitHub #709 Provide a CreateNetworkFromBinary method for the ONNX parser

 * Added CreateNetworkFromBinary to the ONNX parser

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I5ca72ee49c7b098f9fb4aaf55a8bc077230cb30e
diff --git a/include/armnnOnnxParser/IOnnxParser.hpp b/include/armnnOnnxParser/IOnnxParser.hpp
index ba7fc83..89c22c0 100644
--- a/include/armnnOnnxParser/IOnnxParser.hpp
+++ b/include/armnnOnnxParser/IOnnxParser.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #pragma once
@@ -27,6 +27,13 @@
     static IOnnxParserPtr Create();
     static void Destroy(IOnnxParser* parser);
 
+    /// Create the network from a protobuf binary vector
+    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
+
+    /// Create the network from a protobuf binary vector, with inputShapes specified
+    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+                                               const std::map<std::string, armnn::TensorShape>& inputShapes);
+
     /// Create the network from a protobuf binary file on disk
     armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
 
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 63fb603..552d4e4 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #include "OnnxParser.hpp"
@@ -50,6 +50,17 @@
     return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
 }
 
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
+{
+    return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent);
+}
+
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+    return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes);
+}
+
 armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
 {
     return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
@@ -731,6 +742,44 @@
     return CreateNetworkFromModel(*modelProto);
 }
 
+INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
+{
+    ResetParser();
+    ModelPtr modelProto = LoadModelFromBinary(binaryContent);
+    return CreateNetworkFromModel(*modelProto);
+}
+
+INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+                                                    const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+    ResetParser();
+    m_InputShapes = inputShapes;
+    ModelPtr modelProto = LoadModelFromBinary(binaryContent);
+    return CreateNetworkFromModel(*modelProto);
+}
+
+ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent)
+{
+    if (binaryContent.size() == 0)
+    {
+        throw ParseException(fmt::format("Missing binary content", CHECK_LOCATION().AsString()));
+    }
+    // Parse the file into a message
+    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
+
+    google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size()));
+    codedStream.SetTotalBytesLimit(INT_MAX);
+    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
+
+    if (!success)
+    {
+        std::stringstream error;
+        error << "Failed to parse graph";
+        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
+    }
+    return modelProto;
+}
+
 ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
 {
     FILE* fd = fopen(graphFile, "rb");
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index bb94472..c9f321a 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #pragma once
@@ -38,6 +38,13 @@
     armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile,
                                                    const std::map<std::string, armnn::TensorShape>& inputShapes);
 
+    /// Create the network from a protobuf binary
+    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
+
+    /// Create the network from a protobuf binary, with inputShapes specified
+    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+                                               const std::map<std::string, armnn::TensorShape>& inputShapes);
+
     /// Create the network from a protobuf text file on disk
     armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile);
 
@@ -64,6 +71,7 @@
     OnnxParserImpl();
     ~OnnxParserImpl() = default;
 
+    static ModelPtr LoadModelFromBinary(const std::vector<uint8_t>& binaryContent);
     static ModelPtr LoadModelFromBinaryFile(const char * fileName);
     static ModelPtr LoadModelFromTextFile(const char * fileName);
     static ModelPtr LoadModelFromString(const std::string& inputString);