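Fix MemoryOptimizerStrategyLibrary search

GetMemoryOptimizerStrategy returned a ConstantMemoryStrategy for every
known name, and a missing comma in the name list merged
"SingleAxisPriorityList" and "StrategyValidator" into a single entry.
Look strategies up in a name-to-factory map instead, and add a unit test
that requests every registered name.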

Signed-off-by: Finn Williams <finn.williams@arm.com>
Change-Id: I4ca8d9196abd0e116d420a36c780e39edbca0eb3
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp
index aff0995..7b04f44 100644
--- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp
+++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp
@@ -12,17 +12,19 @@
 namespace armnn
 {
 
-class MemoryOptimizerStrategyFactory
+struct IMemoryOptimizerStrategyFactory // Interface for objects that create IMemoryOptimizerStrategy instances.
 {
-public:
-    MemoryOptimizerStrategyFactory() {}
+    virtual ~IMemoryOptimizerStrategyFactory() = default; // virtual: factories are owned via base-class pointers
+    virtual std::unique_ptr<IMemoryOptimizerStrategy> CreateMemoryOptimizerStrategy() = 0;
+};
 
-    template <typename T>
-    std::unique_ptr<IMemoryOptimizerStrategy> CreateMemoryOptimizerStrategy()
+template <typename T>
+struct StrategyFactory : public IMemoryOptimizerStrategyFactory // Factory that produces strategies of type T.
+{
+    std::unique_ptr<IMemoryOptimizerStrategy> CreateMemoryOptimizerStrategy() override
     {
         return std::make_unique<T>();
     }
-
 };
 
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
index 5fa1515..9814405 100644
--- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
+++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
@@ -6,49 +6,52 @@
 
 #include <armnn/backends/IMemoryOptimizerStrategy.hpp>
 #include "MemoryOptimizerStrategyFactory.hpp"
-#include <algorithm>
 
 #include "strategies/ConstantMemoryStrategy.hpp"
 #include "strategies/StrategyValidator.hpp"
 #include "strategies/SingleAxisPriorityList.hpp"
 
-namespace
-{
-// Default Memory Optimizer Strategies
-static const std::vector<std::string> memoryOptimizationStrategies(
-{
-    "ConstantMemoryStrategy",
-    "SingleAxisPriorityList"
-    "StrategyValidator"
-});
+#include <map>
 
-#define CREATE_MEMORY_OPTIMIZER_STRATEGY(strategyName, memoryOptimizerStrategy)                                  \
-{                                                                                                                \
-    MemoryOptimizerStrategyFactory memoryOptimizerStrategyFactory;                                               \
-    memoryOptimizerStrategy = memoryOptimizerStrategyFactory.CreateMemoryOptimizerStrategy<strategyName>();      \
-}                                                                                                                \
-
-} // anonymous namespace
 namespace armnn
 {
-    std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName)
-    {
-        auto doesStrategyExist = std::find(memoryOptimizationStrategies.begin(),
-                                           memoryOptimizationStrategies.end(),
-                                           strategyName) != memoryOptimizationStrategies.end();
-        if (doesStrategyExist)
-        {
-            std::unique_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = nullptr;
-            CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::ConstantMemoryStrategy,
-                                             memoryOptimizerStrategy);
-            return  memoryOptimizerStrategy;
-        }
-        return nullptr;
-    }
+namespace
+{
 
+static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>>& GetStrategyFactories()
+{
+    static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>> strategies;

-    const std::vector<std::string>& GetMemoryOptimizerStrategyNames()
+    if (strategies.size() == 0) // map still empty: register each strategy factory under its name
     {
-        return memoryOptimizationStrategies;
+        strategies["ConstantMemoryStrategy"] = std::make_unique<StrategyFactory<ConstantMemoryStrategy>>();
+        strategies["SingleAxisPriorityList"] = std::make_unique<StrategyFactory<SingleAxisPriorityList>>();
+        strategies["StrategyValidator"]      = std::make_unique<StrategyFactory<StrategyValidator>>();
     }
+    return strategies; // reference to the function-local static map, keyed by strategy name
+}
+
+} // anonymous namespace
+
+// Returns a new instance of the strategy registered as strategyName, or nullptr when the name is unknown.
+std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName)
+{
+    const auto& strategyFactoryMap = GetStrategyFactories();
+    const auto strategyFactory = strategyFactoryMap.find(strategyName);
+    // Test against end() of the map that was searched instead of fetching the registry a second time.
+    return strategyFactory != strategyFactoryMap.end()
+        ? strategyFactory->second->CreateMemoryOptimizerStrategy()
+        : nullptr;
+}
+
+// Lists the name of every registered strategy, in the order the factory map iterates its keys.
+const std::vector<std::string> GetMemoryOptimizerStrategyNames()
+{
+    std::vector<std::string> names;
+    for (const auto& entry : GetStrategyFactories())
+    {
+        names.emplace_back(entry.first);
+    }
+    return names;
+}
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt
index 3068b60..a82f718 100644
--- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt
@@ -7,6 +7,7 @@
             ConstMemoryStrategyTests.cpp
             ValidatorStrategyTests.cpp
             SingleAxisPriorityListTests.cpp
+            MemoryOptimizerStrategyLibraryTests.cpp
 )
 
 add_library(armnnMemoryOptimizationStrategiesUnitTests OBJECT ${armnnMemoryOptimizationStrategiesUnitTests_sources})
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp
new file mode 100644
index 0000000..482bc7d
--- /dev/null
+++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp
@@ -0,0 +1,28 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp>
+
+#include <doctest/doctest.h>
+
+using namespace armnn;
+
+TEST_SUITE("StrategyLibraryTestSuite")
+{
+
+// Every advertised strategy name must resolve to a strategy that reports that same name.
+TEST_CASE("StrategyLibraryTest")
+{
+    const std::vector<std::string> strategyNames = GetMemoryOptimizerStrategyNames();
+    CHECK(!strategyNames.empty());
+    for (const auto& strategyName : strategyNames)
+    {
+        auto strategy = GetMemoryOptimizerStrategy(strategyName);
+        REQUIRE(strategy); // stop this case on a null result: the next line dereferences the pointer
+        CHECK(strategy->GetName() == strategyName);
+    }
+}
+}
+