blob: fa7dda0d768d2f86dea631bfb057b02787dce7f1 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "LstmCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
Jim Flynne571d332019-04-15 14:34:17 +010014
Jim Flynne571d332019-04-15 14:34:17 +010015#include <boost/numeric/conversion/cast.hpp>
16
17#include <array>
18#include <string>
19
20#include "InferenceTestImage.hpp"
21
22namespace
23{
24
25template<typename T, typename TParseElementFunc>
26std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
27{
28 std::vector<T> result;
29 // Processes line-by-line.
30 std::string line;
31 while (std::getline(stream, line))
32 {
David Monahana8837bf2020-04-16 10:01:56 +010033 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
Jim Flynne571d332019-04-15 14:34:17 +010034 for (const std::string& token : tokens)
35 {
36 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
37 {
38 try
39 {
40 result.push_back(parseElementFunc(token));
41 }
42 catch (const std::exception&)
43 {
Derek Lamberti08446972019-11-26 16:38:31 +000044 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
Jim Flynne571d332019-04-15 14:34:17 +010045 }
46 }
47 }
48 }
49
50 return result;
51}
52
53template<armnn::DataType NonQuantizedType>
54auto ParseDataArray(std::istream & stream);
55
56template<armnn::DataType QuantizedType>
57auto ParseDataArray(std::istream& stream,
58 const float& quantizationScale,
59 const int32_t& quantizationOffset);
60
Jim Flynnc2ebc632019-04-17 10:16:58 +010061// NOTE: declaring the template specialisations inline to prevent them
62// being flagged as unused functions when -Werror=unused-function is in effect
Jim Flynne571d332019-04-15 14:34:17 +010063template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010064inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010065{
66 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
67}
68
69template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010070inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010071{
72 return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
73}
74
75template<>
Derek Lambertif90c56d2020-01-10 17:14:08 +000076inline auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
Jim Flynne571d332019-04-15 14:34:17 +010077 const float& quantizationScale,
78 const int32_t& quantizationOffset)
79{
80 return ParseArrayImpl<uint8_t>(stream,
81 [&quantizationScale, &quantizationOffset](const std::string & s)
82 {
83 return boost::numeric_cast<uint8_t>(
Finn Williamsbadcc3f2020-05-22 14:28:15 +010084 armnn::Quantize<uint8_t>(std::stof(s),
Jim Flynne571d332019-04-15 14:34:17 +010085 quantizationScale,
86 quantizationOffset));
87 });
88}
89
90struct DeepSpeechV1TestCaseData
91{
92 DeepSpeechV1TestCaseData(
93 const LstmInput& inputData,
94 const LstmInput& expectedOutputData)
95 : m_InputData(inputData)
96 , m_ExpectedOutputData(expectedOutputData)
97 {}
98
99 LstmInput m_InputData;
100 LstmInput m_ExpectedOutputData;
101};
102
103class DeepSpeechV1Database
104{
105public:
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100106 explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
107 const std::string& prevStateCDir, const std::string& logitsDir,
108 const std::string& newStateHDir, const std::string& newStateCDir);
Jim Flynne571d332019-04-15 14:34:17 +0100109
110 std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
111
112private:
113 std::string m_InputSeqDir;
Jim Flynne571d332019-04-15 14:34:17 +0100114 std::string m_PrevStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100115 std::string m_PrevStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100116 std::string m_LogitsDir;
Jim Flynne571d332019-04-15 14:34:17 +0100117 std::string m_NewStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100118 std::string m_NewStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100119};
120
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100121DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
122 const std::string& prevStateCDir, const std::string& logitsDir,
123 const std::string& newStateHDir, const std::string& newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100124 : m_InputSeqDir(inputSeqDir)
Jim Flynne571d332019-04-15 14:34:17 +0100125 , m_PrevStateHDir(prevStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100126 , m_PrevStateCDir(prevStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100127 , m_LogitsDir(logitsDir)
Jim Flynne571d332019-04-15 14:34:17 +0100128 , m_NewStateHDir(newStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100129 , m_NewStateCDir(newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100130{}
131
132std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
133{
134 // Load test case input
135 const std::string inputSeqPath = m_InputSeqDir + "input_node_0_flat.txt";
136 const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
137 const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
138
139 std::vector<float> inputSeqData;
140 std::vector<float> prevStateCData;
141 std::vector<float> prevStateHData;
142
143 std::ifstream inputSeqFile(inputSeqPath);
144 std::ifstream prevStateCTensorFile(prevStateCPath);
145 std::ifstream prevStateHTensorFile(prevStateHPath);
146
147 try
148 {
149 inputSeqData = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
150 prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
151 prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
152 }
153 catch (const InferenceTestImageException& e)
154 {
Derek Lamberti08446972019-11-26 16:38:31 +0000155 ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +0100156 return nullptr;
157 }
158
159 // Prepare test case expected output
160 const std::string logitsPath = m_LogitsDir + "logits.txt";
161 const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
162 const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
163
164 std::vector<float> logitsData;
165 std::vector<float> expectedNewStateCData;
166 std::vector<float> expectedNewStateHData;
167
168 std::ifstream logitsTensorFile(logitsPath);
169 std::ifstream newStateCTensorFile(newStateCPath);
170 std::ifstream newStateHTensorFile(newStateHPath);
171
172 try
173 {
174 logitsData = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
175 expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
176 expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
177 }
178 catch (const InferenceTestImageException& e)
179 {
Derek Lamberti08446972019-11-26 16:38:31 +0000180 ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +0100181 return nullptr;
182 }
183
184 // use the struct for representing input and output data
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100185 LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100186
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100187 LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100188
189 return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
190}
191
192} // anonymous namespace