blob: a690e3fecec614de4835bcdae5c2945572084c0b [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "LstmCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
Jim Flynne571d332019-04-15 14:34:17 +010014
Jim Flynne571d332019-04-15 14:34:17 +010015#include <boost/numeric/conversion/cast.hpp>
16
17#include <array>
18#include <string>
19
20#include "InferenceTestImage.hpp"
21
22namespace
23{
24
25template<typename T, typename TParseElementFunc>
26std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
27{
28 std::vector<T> result;
29 // Processes line-by-line.
30 std::string line;
31 while (std::getline(stream, line))
32 {
33 std::vector<std::string> tokens;
34 try
35 {
36 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
37 boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
38 }
39 catch (const std::exception& e)
40 {
Derek Lamberti08446972019-11-26 16:38:31 +000041 ARMNN_LOG(error) << "An error occurred when splitting tokens: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +010042 continue;
43 }
44 for (const std::string& token : tokens)
45 {
46 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
47 {
48 try
49 {
50 result.push_back(parseElementFunc(token));
51 }
52 catch (const std::exception&)
53 {
Derek Lamberti08446972019-11-26 16:38:31 +000054 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
Jim Flynne571d332019-04-15 14:34:17 +010055 }
56 }
57 }
58 }
59
60 return result;
61}
62
63template<armnn::DataType NonQuantizedType>
64auto ParseDataArray(std::istream & stream);
65
66template<armnn::DataType QuantizedType>
67auto ParseDataArray(std::istream& stream,
68 const float& quantizationScale,
69 const int32_t& quantizationOffset);
70
Jim Flynnc2ebc632019-04-17 10:16:58 +010071// NOTE: declaring the template specialisations inline to prevent them
72// being flagged as unused functions when -Werror=unused-function is in effect
Jim Flynne571d332019-04-15 14:34:17 +010073template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010074inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010075{
76 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
77}
78
79template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010080inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010081{
82 return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
83}
84
85template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010086inline auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
Jim Flynne571d332019-04-15 14:34:17 +010087 const float& quantizationScale,
88 const int32_t& quantizationOffset)
89{
90 return ParseArrayImpl<uint8_t>(stream,
91 [&quantizationScale, &quantizationOffset](const std::string & s)
92 {
93 return boost::numeric_cast<uint8_t>(
94 armnn::Quantize<u_int8_t>(std::stof(s),
95 quantizationScale,
96 quantizationOffset));
97 });
98}
99
100struct DeepSpeechV1TestCaseData
101{
102 DeepSpeechV1TestCaseData(
103 const LstmInput& inputData,
104 const LstmInput& expectedOutputData)
105 : m_InputData(inputData)
106 , m_ExpectedOutputData(expectedOutputData)
107 {}
108
109 LstmInput m_InputData;
110 LstmInput m_ExpectedOutputData;
111};
112
113class DeepSpeechV1Database
114{
115public:
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100116 explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
117 const std::string& prevStateCDir, const std::string& logitsDir,
118 const std::string& newStateHDir, const std::string& newStateCDir);
Jim Flynne571d332019-04-15 14:34:17 +0100119
120 std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
121
122private:
123 std::string m_InputSeqDir;
Jim Flynne571d332019-04-15 14:34:17 +0100124 std::string m_PrevStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100125 std::string m_PrevStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100126 std::string m_LogitsDir;
Jim Flynne571d332019-04-15 14:34:17 +0100127 std::string m_NewStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100128 std::string m_NewStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100129};
130
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100131DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
132 const std::string& prevStateCDir, const std::string& logitsDir,
133 const std::string& newStateHDir, const std::string& newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100134 : m_InputSeqDir(inputSeqDir)
Jim Flynne571d332019-04-15 14:34:17 +0100135 , m_PrevStateHDir(prevStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100136 , m_PrevStateCDir(prevStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100137 , m_LogitsDir(logitsDir)
Jim Flynne571d332019-04-15 14:34:17 +0100138 , m_NewStateHDir(newStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100139 , m_NewStateCDir(newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100140{}
141
142std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
143{
144 // Load test case input
145 const std::string inputSeqPath = m_InputSeqDir + "input_node_0_flat.txt";
146 const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
147 const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
148
149 std::vector<float> inputSeqData;
150 std::vector<float> prevStateCData;
151 std::vector<float> prevStateHData;
152
153 std::ifstream inputSeqFile(inputSeqPath);
154 std::ifstream prevStateCTensorFile(prevStateCPath);
155 std::ifstream prevStateHTensorFile(prevStateHPath);
156
157 try
158 {
159 inputSeqData = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
160 prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
161 prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
162 }
163 catch (const InferenceTestImageException& e)
164 {
Derek Lamberti08446972019-11-26 16:38:31 +0000165 ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +0100166 return nullptr;
167 }
168
169 // Prepare test case expected output
170 const std::string logitsPath = m_LogitsDir + "logits.txt";
171 const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
172 const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
173
174 std::vector<float> logitsData;
175 std::vector<float> expectedNewStateCData;
176 std::vector<float> expectedNewStateHData;
177
178 std::ifstream logitsTensorFile(logitsPath);
179 std::ifstream newStateCTensorFile(newStateCPath);
180 std::ifstream newStateHTensorFile(newStateHPath);
181
182 try
183 {
184 logitsData = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
185 expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
186 expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
187 }
188 catch (const InferenceTestImageException& e)
189 {
Derek Lamberti08446972019-11-26 16:38:31 +0000190 ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +0100191 return nullptr;
192 }
193
194 // use the struct for representing input and output data
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100195 LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100196
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100197 LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100198
199 return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
200}
201
202} // anonymous namespace