blob: c46fa5799f6018c842f10ab3fca95ae171d8c8ce [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "DeepSpeechV1Database.hpp"
9
10#include <boost/assert.hpp>
Jim Flynne571d332019-04-15 14:34:17 +010011#include <boost/numeric/conversion/cast.hpp>
12#include <boost/test/tools/floating_point_comparison.hpp>
13
14#include <vector>
15
16namespace
17{
18
19template<typename Model>
20class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
21{
22public:
23 DeepSpeechV1TestCase(Model& model,
24 unsigned int testCaseId,
25 const DeepSpeechV1TestCaseData& testCaseData)
26 : InferenceModelTestCase<Model>(model,
27 testCaseId,
28 { testCaseData.m_InputData.m_InputSeq,
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010029 testCaseData.m_InputData.m_StateH,
30 testCaseData.m_InputData.m_StateC},
Jim Flynne571d332019-04-15 14:34:17 +010031 { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
32 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010033 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
34 testCaseData.m_ExpectedOutputData.m_StateC})
Jim Flynne571d332019-04-15 14:34:17 +010035 {}
36
37 TestCaseResult ProcessResult(const InferenceTestOptions& options) override
38 {
39 const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
40 BOOST_ASSERT(output1.size() == k_OutputSize1);
41
42 const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
43 BOOST_ASSERT(output2.size() == k_OutputSize2);
44
45 const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
46 BOOST_ASSERT(output3.size() == k_OutputSize3);
47
48 // Check each output to see whether it is the expected value
49 for (unsigned int j = 0u; j < output1.size(); j++)
50 {
51 if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
52 {
Derek Lamberti08446972019-11-26 16:38:31 +000053 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010054 " is incorrect at" << j;
55 return TestCaseResult::Failed;
56 }
57 }
58
59 for (unsigned int j = 0u; j < output2.size(); j++)
60 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010061 if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
Jim Flynne571d332019-04-15 14:34:17 +010062 {
Derek Lamberti08446972019-11-26 16:38:31 +000063 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010064 " is incorrect";
65 return TestCaseResult::Failed;
66 }
67 }
68
69 for (unsigned int j = 0u; j < output3.size(); j++)
70 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010071 if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
Jim Flynne571d332019-04-15 14:34:17 +010072 {
Derek Lamberti08446972019-11-26 16:38:31 +000073 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010074 " is incorrect";
75 return TestCaseResult::Failed;
76 }
77 }
78 return TestCaseResult::Ok;
79 }
80
81private:
82
83 static constexpr unsigned int k_OutputSize1 = 464u;
84 static constexpr unsigned int k_OutputSize2 = 2048u;
85 static constexpr unsigned int k_OutputSize3 = 2048u;
86
87 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
88 LstmInput m_ExpectedOutputs;
89};
90
91template <typename Model>
92class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
93{
94public:
95 template <typename TConstructModelCallable>
96 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
97 : m_ConstructModel(constructModel)
98 {}
99
100 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
101 {
102 namespace po = boost::program_options;
103
104 options.add_options()
105 ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
106 "Path to directory containing test data for m_InputSeq");
107 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100108 ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
109 "Path to directory containing test data for m_PrevStateH");
110 options.add_options()
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100111 ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
112 "Path to directory containing test data for m_PrevStateC");
113 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100114 ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
115 "Path to directory containing test data for m_Logits");
116 options.add_options()
Jim Flynne571d332019-04-15 14:34:17 +0100117 ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
118 "Path to directory containing test data for m_NewStateH");
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100119 options.add_options()
120 ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
121 "Path to directory containing test data for m_NewStateC");
122
Jim Flynne571d332019-04-15 14:34:17 +0100123
124 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
125 }
126
Jim Flynnc2ebc632019-04-17 10:16:58 +0100127 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
Jim Flynne571d332019-04-15 14:34:17 +0100128 {
129 if (!ValidateDirectory(m_InputSeqDir))
130 {
131 return false;
132 }
133
134 if (!ValidateDirectory(m_PrevStateCDir))
135 {
136 return false;
137 }
138
139 if (!ValidateDirectory(m_PrevStateHDir))
140 {
141 return false;
142 }
143
144 if (!ValidateDirectory(m_LogitsDir))
145 {
146 return false;
147 }
148
149 if (!ValidateDirectory(m_NewStateCDir))
150 {
151 return false;
152 }
153
154 if (!ValidateDirectory(m_NewStateHDir))
155 {
156 return false;
157 }
158
Jim Flynnc2ebc632019-04-17 10:16:58 +0100159 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
Jim Flynne571d332019-04-15 14:34:17 +0100160 if (!m_Model)
161 {
162 return false;
163 }
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100164 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
165 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
166 m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
Jim Flynne571d332019-04-15 14:34:17 +0100167 if (!m_Database)
168 {
169 return false;
170 }
171
172 return true;
173 }
174
175 std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
176 {
177 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
178 if (!testCaseData)
179 {
180 return nullptr;
181 }
182
183 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
184 }
185
186private:
187 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Jim Flynnc2ebc632019-04-17 10:16:58 +0100188 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
189 typename Model::CommandLineOptions)> m_ConstructModel;
Jim Flynne571d332019-04-15 14:34:17 +0100190 std::unique_ptr<Model> m_Model;
191
192 std::string m_InputSeqDir;
193 std::string m_PrevStateCDir;
194 std::string m_PrevStateHDir;
195 std::string m_LogitsDir;
196 std::string m_NewStateCDir;
197 std::string m_NewStateHDir;
198
199 std::unique_ptr<DeepSpeechV1Database> m_Database;
200};
201
202} // anonymous namespace