blob: 3195d2bb1421f5de668e403005a384a01f6d53c7 [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "DeepSpeechV1Database.hpp"
9
10#include <boost/assert.hpp>
11#include <boost/log/trivial.hpp>
12#include <boost/numeric/conversion/cast.hpp>
13#include <boost/test/tools/floating_point_comparison.hpp>
14
15#include <vector>
16
17namespace
18{
19
20template<typename Model>
21class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
22{
23public:
24 DeepSpeechV1TestCase(Model& model,
25 unsigned int testCaseId,
26 const DeepSpeechV1TestCaseData& testCaseData)
27 : InferenceModelTestCase<Model>(model,
28 testCaseId,
29 { testCaseData.m_InputData.m_InputSeq,
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010030 testCaseData.m_InputData.m_StateH,
31 testCaseData.m_InputData.m_StateC},
Jim Flynne571d332019-04-15 14:34:17 +010032 { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
33 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010034 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
35 testCaseData.m_ExpectedOutputData.m_StateC})
Jim Flynne571d332019-04-15 14:34:17 +010036 {}
37
38 TestCaseResult ProcessResult(const InferenceTestOptions& options) override
39 {
40 const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
41 BOOST_ASSERT(output1.size() == k_OutputSize1);
42
43 const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
44 BOOST_ASSERT(output2.size() == k_OutputSize2);
45
46 const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
47 BOOST_ASSERT(output3.size() == k_OutputSize3);
48
49 // Check each output to see whether it is the expected value
50 for (unsigned int j = 0u; j < output1.size(); j++)
51 {
52 if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
53 {
54 BOOST_LOG_TRIVIAL(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
55 " is incorrect at" << j;
56 return TestCaseResult::Failed;
57 }
58 }
59
60 for (unsigned int j = 0u; j < output2.size(); j++)
61 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010062 if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
Jim Flynne571d332019-04-15 14:34:17 +010063 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010064 BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010065 " is incorrect";
66 return TestCaseResult::Failed;
67 }
68 }
69
70 for (unsigned int j = 0u; j < output3.size(); j++)
71 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010072 if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
Jim Flynne571d332019-04-15 14:34:17 +010073 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010074 BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010075 " is incorrect";
76 return TestCaseResult::Failed;
77 }
78 }
79 return TestCaseResult::Ok;
80 }
81
82private:
83
84 static constexpr unsigned int k_OutputSize1 = 464u;
85 static constexpr unsigned int k_OutputSize2 = 2048u;
86 static constexpr unsigned int k_OutputSize3 = 2048u;
87
88 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
89 LstmInput m_ExpectedOutputs;
90};
91
92template <typename Model>
93class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
94{
95public:
96 template <typename TConstructModelCallable>
97 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
98 : m_ConstructModel(constructModel)
99 {}
100
101 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
102 {
103 namespace po = boost::program_options;
104
105 options.add_options()
106 ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
107 "Path to directory containing test data for m_InputSeq");
108 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100109 ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
110 "Path to directory containing test data for m_PrevStateH");
111 options.add_options()
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100112 ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
113 "Path to directory containing test data for m_PrevStateC");
114 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100115 ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
116 "Path to directory containing test data for m_Logits");
117 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100118 ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
119 "Path to directory containing test data for m_NewStateH");
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100120 options.add_options()
121 ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
122 "Path to directory containing test data for m_NewStateC");
123
Jim Flynne571d332019-04-15 14:34:17 +0100124
125 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
126 }
127
Jim Flynnc2ebc632019-04-17 10:16:58 +0100128 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
Jim Flynne571d332019-04-15 14:34:17 +0100129 {
130 if (!ValidateDirectory(m_InputSeqDir))
131 {
132 return false;
133 }
134
135 if (!ValidateDirectory(m_PrevStateCDir))
136 {
137 return false;
138 }
139
140 if (!ValidateDirectory(m_PrevStateHDir))
141 {
142 return false;
143 }
144
145 if (!ValidateDirectory(m_LogitsDir))
146 {
147 return false;
148 }
149
150 if (!ValidateDirectory(m_NewStateCDir))
151 {
152 return false;
153 }
154
155 if (!ValidateDirectory(m_NewStateHDir))
156 {
157 return false;
158 }
159
Jim Flynnc2ebc632019-04-17 10:16:58 +0100160 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
Jim Flynne571d332019-04-15 14:34:17 +0100161 if (!m_Model)
162 {
163 return false;
164 }
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100165 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
166 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
167 m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
Jim Flynne571d332019-04-15 14:34:17 +0100168 if (!m_Database)
169 {
170 return false;
171 }
172
173 return true;
174 }
175
176 std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
177 {
178 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
179 if (!testCaseData)
180 {
181 return nullptr;
182 }
183
184 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
185 }
186
187private:
188 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Jim Flynnc2ebc632019-04-17 10:16:58 +0100189 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
190 typename Model::CommandLineOptions)> m_ConstructModel;
Jim Flynne571d332019-04-15 14:34:17 +0100191 std::unique_ptr<Model> m_Model;
192
193 std::string m_InputSeqDir;
194 std::string m_PrevStateCDir;
195 std::string m_PrevStateHDir;
196 std::string m_LogitsDir;
197 std::string m_NewStateCDir;
198 std::string m_NewStateHDir;
199
200 std::unique_ptr<DeepSpeechV1Database> m_Database;
201};
202
203} // anonymous namespace