blob: 037c810122076948451a5494e3367e62102ad930 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "LstmCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
14#include <backendsCommon/test/QuantizeHelper.hpp>
15
16#include <boost/log/trivial.hpp>
17#include <boost/numeric/conversion/cast.hpp>
18
19#include <array>
20#include <string>
21
22#include "InferenceTestImage.hpp"
23
24namespace
25{
26
27template<typename T, typename TParseElementFunc>
28std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
29{
30 std::vector<T> result;
31 // Processes line-by-line.
32 std::string line;
33 while (std::getline(stream, line))
34 {
35 std::vector<std::string> tokens;
36 try
37 {
38 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
39 boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
40 }
41 catch (const std::exception& e)
42 {
43 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
44 continue;
45 }
46 for (const std::string& token : tokens)
47 {
48 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
49 {
50 try
51 {
52 result.push_back(parseElementFunc(token));
53 }
54 catch (const std::exception&)
55 {
56 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
57 }
58 }
59 }
60 }
61
62 return result;
63}
64
65template<armnn::DataType NonQuantizedType>
66auto ParseDataArray(std::istream & stream);
67
68template<armnn::DataType QuantizedType>
69auto ParseDataArray(std::istream& stream,
70 const float& quantizationScale,
71 const int32_t& quantizationOffset);
72
Jim Flynnc2ebc632019-04-17 10:16:58 +010073// NOTE: declaring the template specialisations inline to prevent them
74// being flagged as unused functions when -Werror=unused-function is in effect
Jim Flynne571d332019-04-15 14:34:17 +010075template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010076inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010077{
78 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
79}
80
81template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010082inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010083{
84 return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
85}
86
87template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010088inline auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
Jim Flynne571d332019-04-15 14:34:17 +010089 const float& quantizationScale,
90 const int32_t& quantizationOffset)
91{
92 return ParseArrayImpl<uint8_t>(stream,
93 [&quantizationScale, &quantizationOffset](const std::string & s)
94 {
95 return boost::numeric_cast<uint8_t>(
96 armnn::Quantize<u_int8_t>(std::stof(s),
97 quantizationScale,
98 quantizationOffset));
99 });
100}
101
102struct DeepSpeechV1TestCaseData
103{
104 DeepSpeechV1TestCaseData(
105 const LstmInput& inputData,
106 const LstmInput& expectedOutputData)
107 : m_InputData(inputData)
108 , m_ExpectedOutputData(expectedOutputData)
109 {}
110
111 LstmInput m_InputData;
112 LstmInput m_ExpectedOutputData;
113};
114
115class DeepSpeechV1Database
116{
117public:
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100118 explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
119 const std::string& prevStateCDir, const std::string& logitsDir,
120 const std::string& newStateHDir, const std::string& newStateCDir);
Jim Flynne571d332019-04-15 14:34:17 +0100121
122 std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
123
124private:
125 std::string m_InputSeqDir;
Jim Flynne571d332019-04-15 14:34:17 +0100126 std::string m_PrevStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100127 std::string m_PrevStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100128 std::string m_LogitsDir;
Jim Flynne571d332019-04-15 14:34:17 +0100129 std::string m_NewStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100130 std::string m_NewStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100131};
132
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100133DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
134 const std::string& prevStateCDir, const std::string& logitsDir,
135 const std::string& newStateHDir, const std::string& newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100136 : m_InputSeqDir(inputSeqDir)
Jim Flynne571d332019-04-15 14:34:17 +0100137 , m_PrevStateHDir(prevStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100138 , m_PrevStateCDir(prevStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100139 , m_LogitsDir(logitsDir)
Jim Flynne571d332019-04-15 14:34:17 +0100140 , m_NewStateHDir(newStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100141 , m_NewStateCDir(newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100142{}
143
144std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
145{
146 // Load test case input
147 const std::string inputSeqPath = m_InputSeqDir + "input_node_0_flat.txt";
148 const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
149 const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
150
151 std::vector<float> inputSeqData;
152 std::vector<float> prevStateCData;
153 std::vector<float> prevStateHData;
154
155 std::ifstream inputSeqFile(inputSeqPath);
156 std::ifstream prevStateCTensorFile(prevStateCPath);
157 std::ifstream prevStateHTensorFile(prevStateHPath);
158
159 try
160 {
161 inputSeqData = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
162 prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
163 prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
164 }
165 catch (const InferenceTestImageException& e)
166 {
167 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
168 return nullptr;
169 }
170
171 // Prepare test case expected output
172 const std::string logitsPath = m_LogitsDir + "logits.txt";
173 const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
174 const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
175
176 std::vector<float> logitsData;
177 std::vector<float> expectedNewStateCData;
178 std::vector<float> expectedNewStateHData;
179
180 std::ifstream logitsTensorFile(logitsPath);
181 std::ifstream newStateCTensorFile(newStateCPath);
182 std::ifstream newStateHTensorFile(newStateHPath);
183
184 try
185 {
186 logitsData = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
187 expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
188 expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
189 }
190 catch (const InferenceTestImageException& e)
191 {
192 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
193 return nullptr;
194 }
195
196 // use the struct for representing input and output data
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100197 LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100198
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100199 LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100200
201 return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
202}
203
204} // anonymous namespace
205