blob: ac799cb45af2bd4d23db0e5acc8beded28a9e19e [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "DeepSpeechV1Database.hpp"
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
12
Jim Flynne571d332019-04-15 14:34:17 +010013#include <boost/test/tools/floating_point_comparison.hpp>
14
15#include <vector>
16
17namespace
18{
19
20template<typename Model>
21class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
22{
23public:
24 DeepSpeechV1TestCase(Model& model,
25 unsigned int testCaseId,
26 const DeepSpeechV1TestCaseData& testCaseData)
27 : InferenceModelTestCase<Model>(model,
28 testCaseId,
29 { testCaseData.m_InputData.m_InputSeq,
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010030 testCaseData.m_InputData.m_StateH,
31 testCaseData.m_InputData.m_StateC},
Jim Flynne571d332019-04-15 14:34:17 +010032 { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
33 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010034 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
35 testCaseData.m_ExpectedOutputData.m_StateC})
Jim Flynne571d332019-04-15 14:34:17 +010036 {}
37
38 TestCaseResult ProcessResult(const InferenceTestOptions& options) override
39 {
Jan Eilers8eb25602020-03-09 12:13:48 +000040 armnn::IgnoreUnused(options);
James Ward6d9f5c52020-09-28 11:56:35 +010041 const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010042 ARMNN_ASSERT(output1.size() == k_OutputSize1);
Jim Flynne571d332019-04-15 14:34:17 +010043
James Ward6d9f5c52020-09-28 11:56:35 +010044 const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010045 ARMNN_ASSERT(output2.size() == k_OutputSize2);
Jim Flynne571d332019-04-15 14:34:17 +010046
James Ward6d9f5c52020-09-28 11:56:35 +010047 const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010048 ARMNN_ASSERT(output3.size() == k_OutputSize3);
Jim Flynne571d332019-04-15 14:34:17 +010049
50 // Check each output to see whether it is the expected value
51 for (unsigned int j = 0u; j < output1.size(); j++)
52 {
53 if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
54 {
Derek Lamberti08446972019-11-26 16:38:31 +000055 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010056 " is incorrect at" << j;
57 return TestCaseResult::Failed;
58 }
59 }
60
61 for (unsigned int j = 0u; j < output2.size(); j++)
62 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010063 if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
Jim Flynne571d332019-04-15 14:34:17 +010064 {
Derek Lamberti08446972019-11-26 16:38:31 +000065 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010066 " is incorrect";
67 return TestCaseResult::Failed;
68 }
69 }
70
71 for (unsigned int j = 0u; j < output3.size(); j++)
72 {
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010073 if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
Jim Flynne571d332019-04-15 14:34:17 +010074 {
Derek Lamberti08446972019-11-26 16:38:31 +000075 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010076 " is incorrect";
77 return TestCaseResult::Failed;
78 }
79 }
80 return TestCaseResult::Ok;
81 }
82
83private:
84
85 static constexpr unsigned int k_OutputSize1 = 464u;
86 static constexpr unsigned int k_OutputSize2 = 2048u;
87 static constexpr unsigned int k_OutputSize3 = 2048u;
88
89 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
90 LstmInput m_ExpectedOutputs;
91};
92
93template <typename Model>
94class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
95{
96public:
97 template <typename TConstructModelCallable>
98 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
99 : m_ConstructModel(constructModel)
100 {}
101
James Wardc89829f2020-10-12 14:17:36 +0100102 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
Jim Flynne571d332019-04-15 14:34:17 +0100103 {
James Wardc89829f2020-10-12 14:17:36 +0100104 options
105 .allow_unrecognised_options()
106 .add_options()
107 ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq",
108 cxxopts::value<std::string>(m_InputSeqDir))
109 ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH",
110 cxxopts::value<std::string>(m_PrevStateHDir))
111 ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC",
112 cxxopts::value<std::string>(m_PrevStateCDir))
113 ("l,logits-dir", "Path to directory containing test data for m_Logits",
114 cxxopts::value<std::string>(m_LogitsDir))
115 ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH",
116 cxxopts::value<std::string>(m_NewStateHDir))
117 ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC",
118 cxxopts::value<std::string>(m_NewStateCDir));
Jim Flynne571d332019-04-15 14:34:17 +0100119
James Wardc89829f2020-10-12 14:17:36 +0100120 required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir",
121 "new-state-h-dir", "new-state-c-dir"});
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100122
James Wardc89829f2020-10-12 14:17:36 +0100123 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
Jim Flynne571d332019-04-15 14:34:17 +0100124 }
125
Jim Flynnc2ebc632019-04-17 10:16:58 +0100126 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
Jim Flynne571d332019-04-15 14:34:17 +0100127 {
128 if (!ValidateDirectory(m_InputSeqDir))
129 {
130 return false;
131 }
132
133 if (!ValidateDirectory(m_PrevStateCDir))
134 {
135 return false;
136 }
137
138 if (!ValidateDirectory(m_PrevStateHDir))
139 {
140 return false;
141 }
142
143 if (!ValidateDirectory(m_LogitsDir))
144 {
145 return false;
146 }
147
148 if (!ValidateDirectory(m_NewStateCDir))
149 {
150 return false;
151 }
152
153 if (!ValidateDirectory(m_NewStateHDir))
154 {
155 return false;
156 }
157
Jim Flynnc2ebc632019-04-17 10:16:58 +0100158 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
Jim Flynne571d332019-04-15 14:34:17 +0100159 if (!m_Model)
160 {
161 return false;
162 }
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100163 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
164 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
165 m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
Jim Flynne571d332019-04-15 14:34:17 +0100166 if (!m_Database)
167 {
168 return false;
169 }
170
171 return true;
172 }
173
174 std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
175 {
176 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
177 if (!testCaseData)
178 {
179 return nullptr;
180 }
181
182 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
183 }
184
185private:
186 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Jim Flynnc2ebc632019-04-17 10:16:58 +0100187 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
188 typename Model::CommandLineOptions)> m_ConstructModel;
Jim Flynne571d332019-04-15 14:34:17 +0100189 std::unique_ptr<Model> m_Model;
190
191 std::string m_InputSeqDir;
192 std::string m_PrevStateCDir;
193 std::string m_PrevStateHDir;
194 std::string m_LogitsDir;
195 std::string m_NewStateCDir;
196 std::string m_NewStateHDir;
197
198 std::unique_ptr<DeepSpeechV1Database> m_Database;
199};
200
201} // anonymous namespace