blob: 07b55d2ab823416786decc9a7fd4103341c3d2c9 [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
#include "InferenceTest.hpp"
#include "DeepSpeechV1Database.hpp"

#include <armnn/utility/IgnoreUnused.hpp>

#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/test/tools/floating_point_comparison.hpp>

#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
17
18namespace
19{
20
21template<typename Model>
22class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
23{
24public:
25 DeepSpeechV1TestCase(Model& model,
26 unsigned int testCaseId,
27 const DeepSpeechV1TestCaseData& testCaseData)
28 : InferenceModelTestCase<Model>(model,
29 testCaseId,
30 { testCaseData.m_InputData.m_InputSeq,
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010031 testCaseData.m_InputData.m_StateH,
32 testCaseData.m_InputData.m_StateC},
Jim Flynne571d332019-04-15 14:34:17 +010033 { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
34 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010035 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
36 testCaseData.m_ExpectedOutputData.m_StateC})
Jim Flynne571d332019-04-15 14:34:17 +010037 {}
38
39 TestCaseResult ProcessResult(const InferenceTestOptions& options) override
40 {
Jan Eilers8eb25602020-03-09 12:13:48 +000041 armnn::IgnoreUnused(options);
Jim Flynne571d332019-04-15 14:34:17 +010042 const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
43 BOOST_ASSERT(output1.size() == k_OutputSize1);
44
45 const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
46 BOOST_ASSERT(output2.size() == k_OutputSize2);
47
48 const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
49 BOOST_ASSERT(output3.size() == k_OutputSize3);
50
51 // Check each output to see whether it is the expected value
52 for (unsigned int j = 0u; j < output1.size(); j++)
53 {
54 if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
55 {
Derek Lamberti08446972019-11-26 16:38:31 +000056 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010057 " is incorrect at" << j;
58 return TestCaseResult::Failed;
59 }
60 }
61
62 for (unsigned int j = 0u; j < output2.size(); j++)
63 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010064 if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
Jim Flynne571d332019-04-15 14:34:17 +010065 {
Derek Lamberti08446972019-11-26 16:38:31 +000066 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010067 " is incorrect";
68 return TestCaseResult::Failed;
69 }
70 }
71
72 for (unsigned int j = 0u; j < output3.size(); j++)
73 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010074 if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
Jim Flynne571d332019-04-15 14:34:17 +010075 {
Derek Lamberti08446972019-11-26 16:38:31 +000076 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010077 " is incorrect";
78 return TestCaseResult::Failed;
79 }
80 }
81 return TestCaseResult::Ok;
82 }
83
84private:
85
86 static constexpr unsigned int k_OutputSize1 = 464u;
87 static constexpr unsigned int k_OutputSize2 = 2048u;
88 static constexpr unsigned int k_OutputSize3 = 2048u;
89
90 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
91 LstmInput m_ExpectedOutputs;
92};
93
94template <typename Model>
95class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
96{
97public:
98 template <typename TConstructModelCallable>
99 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
100 : m_ConstructModel(constructModel)
101 {}
102
103 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
104 {
105 namespace po = boost::program_options;
106
107 options.add_options()
108 ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
109 "Path to directory containing test data for m_InputSeq");
110 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100111 ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
112 "Path to directory containing test data for m_PrevStateH");
113 options.add_options()
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100114 ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
115 "Path to directory containing test data for m_PrevStateC");
116 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100117 ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
118 "Path to directory containing test data for m_Logits");
119 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100120 ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
121 "Path to directory containing test data for m_NewStateH");
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100122 options.add_options()
123 ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
124 "Path to directory containing test data for m_NewStateC");
125
Jim Flynne571d332019-04-15 14:34:17 +0100126
127 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
128 }
129
Jim Flynnc2ebc632019-04-17 10:16:58 +0100130 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
Jim Flynne571d332019-04-15 14:34:17 +0100131 {
132 if (!ValidateDirectory(m_InputSeqDir))
133 {
134 return false;
135 }
136
137 if (!ValidateDirectory(m_PrevStateCDir))
138 {
139 return false;
140 }
141
142 if (!ValidateDirectory(m_PrevStateHDir))
143 {
144 return false;
145 }
146
147 if (!ValidateDirectory(m_LogitsDir))
148 {
149 return false;
150 }
151
152 if (!ValidateDirectory(m_NewStateCDir))
153 {
154 return false;
155 }
156
157 if (!ValidateDirectory(m_NewStateHDir))
158 {
159 return false;
160 }
161
Jim Flynnc2ebc632019-04-17 10:16:58 +0100162 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
Jim Flynne571d332019-04-15 14:34:17 +0100163 if (!m_Model)
164 {
165 return false;
166 }
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100167 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
168 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
169 m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
Jim Flynne571d332019-04-15 14:34:17 +0100170 if (!m_Database)
171 {
172 return false;
173 }
174
175 return true;
176 }
177
178 std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
179 {
180 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
181 if (!testCaseData)
182 {
183 return nullptr;
184 }
185
186 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
187 }
188
189private:
190 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Jim Flynnc2ebc632019-04-17 10:16:58 +0100191 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
192 typename Model::CommandLineOptions)> m_ConstructModel;
Jim Flynne571d332019-04-15 14:34:17 +0100193 std::unique_ptr<Model> m_Model;
194
195 std::string m_InputSeqDir;
196 std::string m_PrevStateCDir;
197 std::string m_PrevStateHDir;
198 std::string m_LogitsDir;
199 std::string m_NewStateCDir;
200 std::string m_NewStateHDir;
201
202 std::unique_ptr<DeepSpeechV1Database> m_Database;
203};
204
205} // anonymous namespace