Load dynamic backends for YoloV3

* Optional cmd option to dump optimized model to dot
* Optional cmd option to specify dynamic backends path
* Input image path is now optional; if given, it must exist
* Comparison files are now optional; if given, they must exist

Change-Id: I1499c9eb715be3cacdba2c227e1a93dd997f355d
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/tests/TfLiteYoloV3Big-Armnn/TfLiteYoloV3Big-Armnn.cpp b/tests/TfLiteYoloV3Big-Armnn/TfLiteYoloV3Big-Armnn.cpp
index 2d373cd..b896d26 100644
--- a/tests/TfLiteYoloV3Big-Armnn/TfLiteYoloV3Big-Armnn.cpp
+++ b/tests/TfLiteYoloV3Big-Armnn/TfLiteYoloV3Big-Armnn.cpp
@@ -97,12 +97,20 @@
     return outputTensors;
 }
 
+#define S_BOOL(name) enum class name {False=0, True=1}; // declares a two-valued scoped enum used as a named boolean flag
+
+S_BOOL(ImportMemory) // LoadModel: sets m_ImportEnabled and the INetworkProperties import flags
+S_BOOL(DumpToDot)    // LoadModel: serialize the optimized model to "<filename>.dot"
+S_BOOL(ExpectFile)   // ValidateFilePath: when True the path must be a regular file
+S_BOOL(OptionalArg)  // GetPathArgument: when True a missing argument is not an error
+
 int LoadModel(const char* filename,
               ITfLiteParser& parser,
               IRuntime& runtime,
               NetworkId& networkId,
               const std::vector<BackendId>& backendPreferences,
-              bool enableImport = false)
+              ImportMemory enableImport,
+              DumpToDot dumpToDot)
 {
     std::ifstream stream(filename, std::ios::in | std::ios::binary);
     if (!stream.is_open())
@@ -120,7 +128,7 @@
 
     // Optimize backbone model
     OptimizerOptions options;
-    options.m_ImportEnabled = enableImport;
+    options.m_ImportEnabled = enableImport != ImportMemory::False;
     auto optimizedModel = Optimize(*model, backendPreferences, runtime.GetDeviceSpec(), options);
     if (!optimizedModel)
     {
@@ -128,10 +136,18 @@
         return OPTIMIZE_NETWORK_ERROR;
     }
 
+    if (dumpToDot != DumpToDot::False)
+    {
+        std::stringstream ss;
+        ss << filename << ".dot";
+        std::ofstream dotStream(ss.str().c_str(), std::ofstream::out);
+        optimizedModel->SerializeToDot(dotStream);
+        dotStream.close();
+    }
     // Load model into runtime
     {
         std::string errorMessage;
-        INetworkProperties modelProps(enableImport, enableImport);
+        INetworkProperties modelProps(options.m_ImportEnabled, options.m_ImportEnabled);
         Status status = runtime.LoadNetwork(networkId, std::move(optimizedModel), errorMessage, modelProps);
         if (status != Status::Success)
         {
@@ -145,6 +161,10 @@
 
 std::vector<float> LoadImage(const char* filename)
 {
+    if (strlen(filename) == 0)
+    {
+        return std::vector<float>(1920*10180*3, 0.0f);
+    }
     struct Memory
     {
         ~Memory() {stbi_image_free(m_Data);}
@@ -185,14 +205,14 @@
 }
 
 
-bool ValidateFilePath(std::string& file)
+bool ValidateFilePath(std::string& file, ExpectFile expectFile)
 {
     if (!ghc::filesystem::exists(file))
     {
         std::cerr << "Given file path " << file << " does not exist" << std::endl;
         return false;
     }
-    if (!ghc::filesystem::is_regular_file(file))
+    if (!ghc::filesystem::is_regular_file(file) && expectFile == ExpectFile::True)
     {
         std::cerr << "Given file path " << file << " is not a regular file" << std::endl;
         return false;
@@ -330,7 +350,15 @@
                  "of yoloV3big e.g. 'CpuAcc,CpuRef' -> CpuAcc will be tried "
                  "first before falling back to CpuRef. NOTE: Backends are passed "
                  "as comma separated list without whitespaces.",
-                 cxxopts::value<std::vector<std::string>>()->default_value("CpuAcc,CpuRef"));
+                 cxxopts::value<std::vector<std::string>>()->default_value("CpuAcc,CpuRef"))
+
+                ("M, model-to-dot",
+                 "Dump the optimized model to a dot file for debugging/analysis",
+                 cxxopts::value<bool>()->default_value("false"))
+
+                ("Y, dynamic-backends-path",
+                 "Define a path from which to load any dynamic backends.",
+                 cxxopts::value<std::string>());
 
         auto result = options.parse(ac, av);
 
@@ -340,17 +368,23 @@
             exit(EXIT_SUCCESS);
         }
 
-        backboneDir = GetPathArgument(result, "backbone-path");
-        comparisonFiles = GetPathArgument(result["comparison-files"].as<std::vector<std::string>>());
-        detectorDir = GetPathArgument(result, "detector-path");
-        imageDir    = GetPathArgument(result, "image-path");
 
+        backboneDir = GetPathArgument(result, "backbone-path", ExpectFile::True, OptionalArg::False);
 
+        comparisonFiles = GetPathArgument(result["comparison-files"].as<std::vector<std::string>>(), OptionalArg::True);
+
+        detectorDir = GetPathArgument(result, "detector-path", ExpectFile::True, OptionalArg::False);
+
+        imageDir    = GetPathArgument(result, "image-path", ExpectFile::True, OptionalArg::True);
+
+        dynamicBackendPath = GetPathArgument(result, "dynamic-backends-path", ExpectFile::False, OptionalArg::True);
 
         prefBackendsBackbone = GetBackendIDs(result["preferred-backends-backbone"].as<std::vector<std::string>>());
         LogBackendsInfo(prefBackendsBackbone, "Backbone");
         prefBackendsDetector = GetBackendIDs(result["preferred-backends-detector"].as<std::vector<std::string>>());
         LogBackendsInfo(prefBackendsDetector, "detector");
+
+        dumpToDot = result["model-to-dot"].as<bool>() ? DumpToDot::True : DumpToDot::False;
     }
 
     /// Takes a vector of backend strings and returns a vector of backendIDs
@@ -367,27 +401,41 @@
     /// Verifies if the program argument with the name argName contains a valid file path.
     /// Returns the valid file path string if given argument is associated a valid file path.
     /// Otherwise throws an exception.
-    std::string GetPathArgument(cxxopts::ParseResult& result, std::string&& argName)
+    std::string GetPathArgument(cxxopts::ParseResult& result,
+                                std::string&& argName,
+                                ExpectFile expectFile,     // True: path must also be a regular file
+                                OptionalArg isOptionalArg) // True: absent argument yields "" instead of throwing
     {
         if (result.count(argName))
         {
-            std::string fileDir = result[argName].as<std::string>();
-            if (!ValidateFilePath(fileDir))
+            std::string path = result[argName].as<std::string>();
+            if (!ValidateFilePath(path, expectFile))
             {
-                throw cxxopts::option_syntax_exception("Argument given to backbone-path is not a valid file path");
+                std::stringstream ss;
+                ss << "Argument given to " << argName << " is not a valid file path"; // spaces keep the name readable
+                throw cxxopts::option_syntax_exception(ss.str().c_str());
             }
-            return fileDir;
+            return path;
         }
         else
         {
+            if (isOptionalArg == OptionalArg::True)
+            {
+                return ""; // optional argument not supplied: callers treat the empty string as "absent"
+            }
+
             throw cxxopts::missing_argument_exception(argName);
         }
     }
 
     /// Assigns vector of strings to struct member variable
-    std::vector<std::string> GetPathArgument(const std::vector<std::string>& pathStrings)
+    std::vector<std::string> GetPathArgument(const std::vector<std::string>& pathStrings, OptionalArg isOptional)
     {
         if (pathStrings.size() < 5){
+            if (isOptional == OptionalArg::True)
+            {
+                return std::vector<std::string>();
+            }
             throw cxxopts::option_syntax_exception("Comparison files requires 5 file paths.");
         }
 
@@ -395,7 +443,7 @@
         for (auto& path : pathStrings)
         {
             filePaths.push_back(path);
-            if (!ValidateFilePath(filePaths.back()))
+            if (!ValidateFilePath(filePaths.back(), ExpectFile::True))
             {
                 throw cxxopts::option_syntax_exception("Argument given to Comparison Files is not a valid file path");
             }
@@ -420,11 +468,14 @@
     std::vector<std::string> comparisonFiles;
     std::string detectorDir;
     std::string imageDir;
+    std::string dynamicBackendPath;
 
     std::vector<BackendId> prefBackendsBackbone;
     std::vector<BackendId> prefBackendsDetector;
 
     cxxopts::Options options;
+
+    DumpToDot dumpToDot;
 };
 
 int main(int argc, char* argv[])
@@ -438,6 +489,13 @@
 
     // Create runtime
     IRuntime::CreationOptions runtimeOptions; // default
+
+    if (!progArgs.dynamicBackendPath.empty()) // empty means no dynamic-backends-path was given
+    {
+        std::cout << "Loading backends from " << progArgs.dynamicBackendPath << "\n";
+        runtimeOptions.m_DynamicBackendsPath = progArgs.dynamicBackendPath;
+    }
+
     auto runtime = IRuntime::Create(runtimeOptions);
     if (!runtime)
     {
@@ -452,7 +510,14 @@
     // Load backbone model
     ARMNN_LOG(info) << "Loading backbone...";
     NetworkId backboneId;
-    CHECK_OK(LoadModel(progArgs.backboneDir.c_str(), *parser, *runtime, backboneId, progArgs.prefBackendsBackbone));
+    const DumpToDot dumpToDot = progArgs.dumpToDot;
+    CHECK_OK(LoadModel(progArgs.backboneDir.c_str(),
+                       *parser,
+                       *runtime,
+                       backboneId,
+                       progArgs.prefBackendsBackbone,
+                       ImportMemory::False,
+                       dumpToDot));
     auto inputId = parser->GetNetworkInputBindingInfo(0, "inputs");
     auto bbOut0Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_1");
     auto bbOut1Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_2");
@@ -460,11 +525,17 @@
     auto backboneProfile = runtime->GetProfiler(backboneId);
     backboneProfile->EnableProfiling(true);
 
+
     // Load detector model
     ARMNN_LOG(info) << "Loading detector...";
     NetworkId detectorId;
-    CHECK_OK(LoadModel(
-        progArgs.detectorDir.c_str(), *parser, *runtime, detectorId, progArgs.prefBackendsDetector, true));
+    CHECK_OK(LoadModel(progArgs.detectorDir.c_str(),
+                       *parser,
+                       *runtime,
+                       detectorId,
+                       progArgs.prefBackendsDetector,
+                       ImportMemory::True,
+                       dumpToDot));
     auto detectIn0Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_1");
     auto detectIn1Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_2");
     auto detectIn2Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_3");
@@ -574,9 +645,15 @@
     nmsProfileStream << "}" << "\n";
     nmsProfileStream.close();
 
-    CheckAccuracy(&intermediateMem0, &intermediateMem1,
-                  &intermediateMem2, &intermediateMem3,
-                  filtered_boxes, progArgs.comparisonFiles);
+    if (progArgs.comparisonFiles.size() > 0)
+    {
+        CheckAccuracy(&intermediateMem0,
+                      &intermediateMem1,
+                      &intermediateMem2,
+                      &intermediateMem3,
+                      filtered_boxes,
+                      progArgs.comparisonFiles);
+    }
 
     ARMNN_LOG(info) << "Run completed";
     return 0;