blob: 5e721c3cb94f12cbd8677d2dd1337ae540ff8855 [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "LstmCommon.hpp"
8
9#include <memory>
10#include <string>
11#include <vector>
12
13#include <armnn/TypesUtils.hpp>
Matthew Sloyan80c6b142020-09-08 12:00:32 +010014#include <armnn/utility/NumericCast.hpp>
Jim Flynne571d332019-04-15 14:34:17 +010015
16#include <array>
17#include <string>
18
19#include "InferenceTestImage.hpp"
20
21namespace
22{
23
24template<typename T, typename TParseElementFunc>
25std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
26{
27 std::vector<T> result;
28 // Processes line-by-line.
29 std::string line;
30 while (std::getline(stream, line))
31 {
David Monahana8837bf2020-04-16 10:01:56 +010032 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
Jim Flynne571d332019-04-15 14:34:17 +010033 for (const std::string& token : tokens)
34 {
35 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
36 {
37 try
38 {
39 result.push_back(parseElementFunc(token));
40 }
41 catch (const std::exception&)
42 {
Derek Lamberti08446972019-11-26 16:38:31 +000043 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
Jim Flynne571d332019-04-15 14:34:17 +010044 }
45 }
46 }
47 }
48
49 return result;
50}
51
52template<armnn::DataType NonQuantizedType>
53auto ParseDataArray(std::istream & stream);
54
55template<armnn::DataType QuantizedType>
56auto ParseDataArray(std::istream& stream,
57 const float& quantizationScale,
58 const int32_t& quantizationOffset);
59
Jim Flynnc2ebc632019-04-17 10:16:58 +010060// NOTE: declaring the template specialisations inline to prevent them
61// being flagged as unused functions when -Werror=unused-function is in effect
Jim Flynne571d332019-04-15 14:34:17 +010062template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010063inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010064{
65 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
66}
67
68template<>
Jim Flynnc2ebc632019-04-17 10:16:58 +010069inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
Jim Flynne571d332019-04-15 14:34:17 +010070{
71 return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
72}
73
74template<>
Derek Lambertif90c56d2020-01-10 17:14:08 +000075inline auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
Jim Flynne571d332019-04-15 14:34:17 +010076 const float& quantizationScale,
77 const int32_t& quantizationOffset)
78{
79 return ParseArrayImpl<uint8_t>(stream,
80 [&quantizationScale, &quantizationOffset](const std::string & s)
81 {
Matthew Sloyan80c6b142020-09-08 12:00:32 +010082 return armnn::numeric_cast<uint8_t>(
Finn Williamsbadcc3f2020-05-22 14:28:15 +010083 armnn::Quantize<uint8_t>(std::stof(s),
Jim Flynne571d332019-04-15 14:34:17 +010084 quantizationScale,
85 quantizationOffset));
86 });
87}
88
89struct DeepSpeechV1TestCaseData
90{
91 DeepSpeechV1TestCaseData(
92 const LstmInput& inputData,
93 const LstmInput& expectedOutputData)
94 : m_InputData(inputData)
95 , m_ExpectedOutputData(expectedOutputData)
96 {}
97
98 LstmInput m_InputData;
99 LstmInput m_ExpectedOutputData;
100};
101
102class DeepSpeechV1Database
103{
104public:
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100105 explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
106 const std::string& prevStateCDir, const std::string& logitsDir,
107 const std::string& newStateHDir, const std::string& newStateCDir);
Jim Flynne571d332019-04-15 14:34:17 +0100108
109 std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
110
111private:
112 std::string m_InputSeqDir;
Jim Flynne571d332019-04-15 14:34:17 +0100113 std::string m_PrevStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100114 std::string m_PrevStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100115 std::string m_LogitsDir;
Jim Flynne571d332019-04-15 14:34:17 +0100116 std::string m_NewStateHDir;
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100117 std::string m_NewStateCDir;
Jim Flynne571d332019-04-15 14:34:17 +0100118};
119
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100120DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
121 const std::string& prevStateCDir, const std::string& logitsDir,
122 const std::string& newStateHDir, const std::string& newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100123 : m_InputSeqDir(inputSeqDir)
Jim Flynne571d332019-04-15 14:34:17 +0100124 , m_PrevStateHDir(prevStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100125 , m_PrevStateCDir(prevStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100126 , m_LogitsDir(logitsDir)
Jim Flynne571d332019-04-15 14:34:17 +0100127 , m_NewStateHDir(newStateHDir)
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100128 , m_NewStateCDir(newStateCDir)
Jim Flynne571d332019-04-15 14:34:17 +0100129{}
130
131std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
132{
133 // Load test case input
134 const std::string inputSeqPath = m_InputSeqDir + "input_node_0_flat.txt";
135 const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
136 const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
137
138 std::vector<float> inputSeqData;
139 std::vector<float> prevStateCData;
140 std::vector<float> prevStateHData;
141
142 std::ifstream inputSeqFile(inputSeqPath);
143 std::ifstream prevStateCTensorFile(prevStateCPath);
144 std::ifstream prevStateHTensorFile(prevStateHPath);
145
146 try
147 {
148 inputSeqData = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
149 prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
150 prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
151 }
152 catch (const InferenceTestImageException& e)
153 {
Derek Lamberti08446972019-11-26 16:38:31 +0000154 ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +0100155 return nullptr;
156 }
157
158 // Prepare test case expected output
159 const std::string logitsPath = m_LogitsDir + "logits.txt";
160 const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
161 const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
162
163 std::vector<float> logitsData;
164 std::vector<float> expectedNewStateCData;
165 std::vector<float> expectedNewStateHData;
166
167 std::ifstream logitsTensorFile(logitsPath);
168 std::ifstream newStateCTensorFile(newStateCPath);
169 std::ifstream newStateHTensorFile(newStateHPath);
170
171 try
172 {
173 logitsData = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
174 expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
175 expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
176 }
177 catch (const InferenceTestImageException& e)
178 {
Derek Lamberti08446972019-11-26 16:38:31 +0000179 ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
Jim Flynne571d332019-04-15 14:34:17 +0100180 return nullptr;
181 }
182
183 // use the struct for representing input and output data
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100184 LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100185
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100186 LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
Jim Flynne571d332019-04-15 14:34:17 +0100187
188 return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
189}
190
191} // anonymous namespace