IVGCVSW-2855 Create TfLite reference test for DeepSpeechV1

Change-Id: I4492a85c8337bf4ea0eb998c88b9cbfc932dc4e6
Signed-off-by: Ruomei Yan <ruomei.yan@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
diff --git a/tests/DeepSpeechV1Database.hpp b/tests/DeepSpeechV1Database.hpp
new file mode 100644
index 0000000..4d2d591
--- /dev/null
+++ b/tests/DeepSpeechV1Database.hpp
@@ -0,0 +1,203 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "LstmCommon.hpp"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/test/QuantizeHelper.hpp>
+
+#include <boost/algorithm/string.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+#include <array>
+#include <fstream>
+#include "InferenceTestImage.hpp"
+
+namespace
+{
+
+// Splits each line of stream on any character in chars and converts every non-empty token with parseElementFunc.
+template<typename T, typename TParseElementFunc>
+std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
+{
+    std::vector<T> parsed;
+    std::string currentLine;
+    while (std::getline(stream, currentLine))
+    {
+        std::vector<std::string> pieces;
+        try
+        {
+            // Coverity flagged that boost::split() may throw boost::bad_function_call; such a line is skipped.
+            boost::split(pieces, currentLine, boost::algorithm::is_any_of(chars), boost::token_compress_on);
+        }
+        catch (const std::exception& e)
+        {
+            BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
+            continue;
+        }
+        for (const std::string& piece : pieces)
+        {
+            if (piece.empty())
+            {
+                continue; // See https://stackoverflow.com/questions/10437406/
+            }
+            try
+            {
+                parsed.push_back(parseElementFunc(piece));
+            }
+            catch (const std::exception&)
+            {
+                BOOST_LOG_TRIVIAL(error) << "'" << piece << "' is not a valid number. It has been ignored.";
+            }
+        }
+    }
+    return parsed;
+}
+
+template<armnn::DataType NonQuantizedType>
+auto ParseDataArray(std::istream & stream); // declaration only; specialised below for Float32 and Signed32
+
+template<armnn::DataType QuantizedType>
+auto ParseDataArray(std::istream& stream,
+                    const float& quantizationScale,
+                    const int32_t& quantizationOffset); // declaration only; specialised below for QuantisedAsymm8
+
+template<>
+auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
+{   // Float32 data: every token is converted with std::stof.
+    return ParseArrayImpl<float>(stream, [](const std::string& token) -> float { return std::stof(token); });
+}
+
+template<>
+auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
+{   // Signed32 data: every token is converted with std::stoi.
+    return ParseArrayImpl<int>(stream, [](const std::string& token) -> int { return std::stoi(token); });
+}
+
+// QuantisedAsymm8 data: each token is read as a float and quantized with the given scale and offset.
+template<>
+auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
+                                                      const float& quantizationScale,
+                                                      const int32_t& quantizationOffset)
+{
+    // uint8_t is the standard fixed-width type; u_int8_t is a non-standard BSD/glibc alias.
+    auto quantizeToken = [&quantizationScale, &quantizationOffset](const std::string& s)
+    {
+        return boost::numeric_cast<uint8_t>(
+            armnn::Quantize<uint8_t>(std::stof(s), quantizationScale, quantizationOffset));
+    };
+    return ParseArrayImpl<uint8_t>(stream, quantizeToken);
+}
+
+// Pairs the tensors fed to the network with the tensors it is expected to produce for one test case.
+struct DeepSpeechV1TestCaseData
+{
+    DeepSpeechV1TestCaseData(const LstmInput& inputData, const LstmInput& expectedOutputData)
+        : m_InputData(inputData)
+        , m_ExpectedOutputData(expectedOutputData)
+    {
+    }
+
+    LstmInput m_InputData;          // copy of inputData
+    LstmInput m_ExpectedOutputData; // copy of expectedOutputData
+};
+
+class DeepSpeechV1Database // Reads DeepSpeechV1 test tensors from text files found in six directories.
+{
+public:
+    explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir,
+                                  const std::string& prevStateHDir, const std::string& logitsDir,
+                                  const std::string& newStateCDir, const std::string& newStateHDir);
+
+    std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId); // may return nullptr
+
+private:
+    std::string m_InputSeqDir;   // prepended verbatim to "input_node_0_flat.txt"
+    std::string m_PrevStateCDir; // prepended verbatim to "previous_state_c_0.txt"
+    std::string m_PrevStateHDir; // prepended verbatim to "previous_state_h_0.txt"
+    std::string m_LogitsDir;     // prepended verbatim to "logits.txt"
+    std::string m_NewStateCDir;  // prepended verbatim to "new_state_c.txt"
+    std::string m_NewStateHDir;  // prepended verbatim to "new_state_h.txt"
+};
+
+// Stores the six data directories; no file is touched until GetTestCaseData() is called.
+DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateCDir,
+                                           const std::string& prevStateHDir, const std::string& logitsDir,
+                                           const std::string& newStateCDir, const std::string& newStateHDir)
+    : m_InputSeqDir(inputSeqDir), m_PrevStateCDir(prevStateCDir)
+    , m_PrevStateHDir(prevStateHDir), m_LogitsDir(logitsDir)
+    , m_NewStateCDir(newStateCDir), m_NewStateHDir(newStateHDir)
+{
+    // Nothing else to do: every member is set in the initialiser list above.
+}
+
+// Loads the three input tensors and the three expected-output tensors from fixed file names inside the
+// configured directories. Returns nullptr, after logging, if any file cannot be opened.
+// testCaseId is only used in log messages: every id reads the same six files.
+std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
+{
+    // Reads one text file of floats into data. The parser logs and skips bad tokens instead of throwing,
+    // so a file that fails to open is the only error that has to be detected here; without this check a
+    // missing file would silently produce an empty tensor.
+    auto loadTensor = [testCaseId](const std::string& path, std::vector<float>& data)
+    {
+        std::ifstream file(path);
+        if (!file.is_open())
+        {
+            BOOST_LOG_TRIVIAL(fatal) << "Failed to open " << path << " for test case " << testCaseId;
+            return false;
+        }
+        data = ParseDataArray<armnn::DataType::Float32>(file);
+        return true;
+    };
+
+    // Load test case input
+    const std::string inputSeqPath   = m_InputSeqDir + "input_node_0_flat.txt";
+    const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
+    const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
+
+    std::vector<float> inputSeqData;
+    std::vector<float> prevStateCData;
+    std::vector<float> prevStateHData;
+
+    if (!loadTensor(inputSeqPath, inputSeqData) ||
+        !loadTensor(prevStateCPath, prevStateCData) ||
+        !loadTensor(prevStateHPath, prevStateHData))
+    {
+        return nullptr;
+    }
+
+    // Prepare test case expected output
+    const std::string logitsPath    = m_LogitsDir + "logits.txt";
+    const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
+    const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
+
+    std::vector<float> logitsData;
+    std::vector<float> expectedNewStateCData;
+    std::vector<float> expectedNewStateHData;
+
+    if (!loadTensor(logitsPath, logitsData) ||
+        !loadTensor(newStateCPath, expectedNewStateCData) ||
+        !loadTensor(newStateHPath, expectedNewStateHData))
+    {
+        return nullptr;
+    }
+
+    // use the struct for representing input and output data
+    LstmInput inputDataSingleTest(inputSeqData, prevStateCData, prevStateHData);
+
+    LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateCData, expectedNewStateHData);
+
+    return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
+}
+
+} // anonymous namespace
+
diff --git a/tests/DeepSpeechV1InferenceTest.hpp b/tests/DeepSpeechV1InferenceTest.hpp
new file mode 100755
index 0000000..24e7dac
--- /dev/null
+++ b/tests/DeepSpeechV1InferenceTest.hpp
@@ -0,0 +1,201 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "InferenceTest.hpp"
+#include "DeepSpeechV1Database.hpp"
+
+#include <boost/assert.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+#include <boost/test/tools/floating_point_comparison.hpp>
+
+#include <vector>
+
+namespace
+{
+
+// Feeds one DeepSpeechV1 sample to the model and compares its three outputs with the expected tensors (1% tolerance).
+template<typename Model>
+class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
+{
+public:
+    DeepSpeechV1TestCase(Model& model,
+                         unsigned int testCaseId,
+                         const DeepSpeechV1TestCaseData& testCaseData)
+        : InferenceModelTestCase<Model>(model,
+                                        testCaseId,
+                                        { testCaseData.m_InputData.m_InputSeq,
+                                          testCaseData.m_InputData.m_StateC,
+                                          testCaseData.m_InputData.m_StateH},
+                                        { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
+        , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
+        , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateC,
+                             testCaseData.m_ExpectedOutputData.m_StateH})
+    {}
+
+    // Checks logits, new_state_c and new_state_h in that order; returns Failed at the first mismatch, else Ok.
+    TestCaseResult ProcessResult(const InferenceTestOptions& options) override
+    {
+        const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
+        BOOST_ASSERT(output1.size() == k_OutputSize1);
+
+        const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
+        BOOST_ASSERT(output2.size() == k_OutputSize2);
+
+        const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
+        BOOST_ASSERT(output3.size() == k_OutputSize3);
+
+        // Expected logits are stored in the m_InputSeq slot of the LstmInput holder.
+        const bool allMatch = OutputMatches(output1, m_ExpectedOutputs.m_InputSeq, "Logits") &&
+                              OutputMatches(output2, m_ExpectedOutputs.m_StateC, "StateC") &&
+                              OutputMatches(output3, m_ExpectedOutputs.m_StateH, "StateH");
+        return allMatch ? TestCaseResult::Ok : TestCaseResult::Failed;
+    }
+
+private:
+    // Returns true when every element of actual is within tolerance of the element of expected at the same
+    // index. name labels the tensor in the error log.
+    bool OutputMatches(const std::vector<float>& actual, const std::vector<float>& expected, const char* name)
+    {
+        // BOOST_ASSERT is compiled out when NDEBUG is defined, so guard the indexing below explicitly: short
+        // or empty expected data must fail the test rather than read past the end of the vector.
+        if (expected.size() < actual.size())
+        {
+            BOOST_LOG_TRIVIAL(error) << name << " for Lstm " << this->GetTestCaseId() << " has only " <<
+                                     expected.size() << " expected values for " << actual.size() << " outputs";
+            return false;
+        }
+
+        for (std::size_t j = 0u; j < actual.size(); j++)
+        {
+            if (!m_FloatComparer(actual[j], expected[j]))
+            {
+                BOOST_LOG_TRIVIAL(error) << name << " for Lstm " << this->GetTestCaseId() <<
+                                         " is incorrect at " << j;
+                return false;
+            }
+        }
+        return true;
+    }
+
+    static constexpr unsigned int k_OutputSize1 = 464u;
+    static constexpr unsigned int k_OutputSize2 = 2048u;
+    static constexpr unsigned int k_OutputSize3 = 2048u;
+
+    boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
+    LstmInput m_ExpectedOutputs;
+};
+
+// Supplies DeepSpeechV1 test cases to the inference test harness: declares the data-directory options,
+// validates them, builds the model and the database, and creates a test case for each requested id.
+template <typename Model>
+class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
+{
+public:
+    // constructModel is stored and later called with the parsed model options to build the model.
+    template <typename TConstructModelCallable>
+    explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
+        : m_ConstructModel(constructModel)
+    {}
+
+    // Registers one required directory option per tensor file, then the model's own options.
+    // NOTE(review): the one-letter aliases live in the same options_description as the options added by
+    // Model::AddCommandLineOptions - confirm that none of them clash.
+    virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
+    {
+        namespace po = boost::program_options;
+
+        options.add_options()
+                ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
+                 "Path to directory containing test data for m_InputSeq");
+        options.add_options()
+                ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
+                 "Path to directory containing test data for m_PrevStateC");
+        options.add_options()
+                ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
+                 "Path to directory containing test data for m_PrevStateH");
+        options.add_options()
+                ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
+                 "Path to directory containing test data for m_Logits");
+        options.add_options()
+                ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
+                 "Path to directory containing test data for m_NewStateC");
+        options.add_options()
+                ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
+                 "Path to directory containing test data for m_NewStateH");
+
+        Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
+    }
+
+    // Validates the six data directories, then builds the model and the database.
+    // Returns false when a directory is rejected or the model could not be constructed.
+    virtual bool ProcessCommandLineOptions() override
+    {
+        // ValidateDirectory is handed each member itself, in the order the options are declared, and
+        // checking stops at the first directory that is rejected.
+        std::string* const directories[] = { &m_InputSeqDir,
+                                             &m_PrevStateCDir,
+                                             &m_PrevStateHDir,
+                                             &m_LogitsDir,
+                                             &m_NewStateCDir,
+                                             &m_NewStateHDir };
+        for (std::string* directory : directories)
+        {
+            if (!ValidateDirectory(*directory))
+            {
+                return false;
+            }
+        }
+
+        m_Model = m_ConstructModel(m_ModelCommandLineOptions);
+        if (!m_Model)
+        {
+            return false;
+        }
+
+        // The constructor takes const std::string&, so the members are passed directly; going through
+        // c_str() only created temporary copies.
+        m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir, m_PrevStateCDir,
+                                                            m_PrevStateHDir, m_LogitsDir,
+                                                            m_NewStateCDir, m_NewStateHDir);
+
+        // std::make_unique signals failure by throwing and never returns null, so m_Database needs no
+        // null check here.
+        return true;
+    }
+
+    // Fetches the data for testCaseId from the database and wraps it in a DeepSpeechV1TestCase.
+    // Returns nullptr when the database returns no data.
+    std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
+    {
+        std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
+        if (!testCaseData)
+        {
+            return nullptr;
+        }
+
+        return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
+    }
+
+private:
+    // Model options, the factory that consumes them, and the model it produced.
+    typename Model::CommandLineOptions m_ModelCommandLineOptions;
+    std::function<std::unique_ptr<Model>(typename Model::CommandLineOptions)> m_ConstructModel;
+    std::unique_ptr<Model> m_Model;
+
+    // Data directories bound to the command-line options above.
+    std::string m_InputSeqDir;
+    std::string m_PrevStateCDir;
+    std::string m_PrevStateHDir;
+    std::string m_LogitsDir;
+    std::string m_NewStateCDir;
+    std::string m_NewStateHDir;
+
+    // Created by ProcessCommandLineOptions(); null until then.
+    std::unique_ptr<DeepSpeechV1Database> m_Database;
+};
+
+} // anonymous namespace
diff --git a/tests/LstmCommon.hpp b/tests/LstmCommon.hpp
new file mode 100755
index 0000000..31c4d04
--- /dev/null
+++ b/tests/LstmCommon.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector> // std::vector is used below; this header must not rely on the includer providing it
+
+namespace
+{
+
+// Three float tensors kept together; DeepSpeechV1TestCaseData uses it for both inputs and expected outputs.
+struct LstmInput
+{
+    LstmInput(const std::vector<float>& inputSeq,
+              const std::vector<float>& stateC,
+              const std::vector<float>& stateH)
+        : m_InputSeq(inputSeq), m_StateC(stateC), m_StateH(stateH)
+    {}
+
+    std::vector<float> m_InputSeq; // copy of inputSeq
+    std::vector<float> m_StateC;   // copy of stateC
+    std::vector<float> m_StateH;   // copy of stateH
+};
+
+using LstmInputs = std::pair<std::string, std::vector<LstmInput>>;
+
+} // anonymous namespace
\ No newline at end of file