blob: 9115a8f351ed9b35ef29d0cc8ec7cfd2c4151e99 [file] [log] [blame]
Jim Flynne571d332019-04-15 14:34:17 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "InferenceTest.hpp"
8#include "DeepSpeechV1Database.hpp"
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Colm Donelan9a5ce4a2020-10-29 11:39:14 +000012#include <armnnUtils/FloatingPointComparison.hpp>
Jim Flynne571d332019-04-15 14:34:17 +010013
14#include <vector>
15
16namespace
17{
18
19template<typename Model>
20class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
21{
22public:
23 DeepSpeechV1TestCase(Model& model,
24 unsigned int testCaseId,
25 const DeepSpeechV1TestCaseData& testCaseData)
26 : InferenceModelTestCase<Model>(model,
27 testCaseId,
28 { testCaseData.m_InputData.m_InputSeq,
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010029 testCaseData.m_InputData.m_StateH,
30 testCaseData.m_InputData.m_StateC},
Jim Flynne571d332019-04-15 14:34:17 +010031 { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +010032 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
33 testCaseData.m_ExpectedOutputData.m_StateC})
Jim Flynne571d332019-04-15 14:34:17 +010034 {}
35
36 TestCaseResult ProcessResult(const InferenceTestOptions& options) override
37 {
Jan Eilers8eb25602020-03-09 12:13:48 +000038 armnn::IgnoreUnused(options);
James Ward6d9f5c52020-09-28 11:56:35 +010039 const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010040 ARMNN_ASSERT(output1.size() == k_OutputSize1);
Jim Flynne571d332019-04-15 14:34:17 +010041
James Ward6d9f5c52020-09-28 11:56:35 +010042 const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010043 ARMNN_ASSERT(output2.size() == k_OutputSize2);
Jim Flynne571d332019-04-15 14:34:17 +010044
James Ward6d9f5c52020-09-28 11:56:35 +010045 const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010046 ARMNN_ASSERT(output3.size() == k_OutputSize3);
Jim Flynne571d332019-04-15 14:34:17 +010047
48 // Check each output to see whether it is the expected value
49 for (unsigned int j = 0u; j < output1.size(); j++)
50 {
Colm Donelan9a5ce4a2020-10-29 11:39:14 +000051 if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
Jim Flynne571d332019-04-15 14:34:17 +010052 {
Derek Lamberti08446972019-11-26 16:38:31 +000053 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010054 " is incorrect at" << j;
55 return TestCaseResult::Failed;
56 }
57 }
58
59 for (unsigned int j = 0u; j < output2.size(); j++)
60 {
Colm Donelan9a5ce4a2020-10-29 11:39:14 +000061 if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j]))
Jim Flynne571d332019-04-15 14:34:17 +010062 {
Derek Lamberti08446972019-11-26 16:38:31 +000063 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010064 " is incorrect";
65 return TestCaseResult::Failed;
66 }
67 }
68
69 for (unsigned int j = 0u; j < output3.size(); j++)
70 {
Colm Donelan9a5ce4a2020-10-29 11:39:14 +000071 if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j]))
Jim Flynne571d332019-04-15 14:34:17 +010072 {
Derek Lamberti08446972019-11-26 16:38:31 +000073 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
Jim Flynne571d332019-04-15 14:34:17 +010074 " is incorrect";
75 return TestCaseResult::Failed;
76 }
77 }
78 return TestCaseResult::Ok;
79 }
80
81private:
82
83 static constexpr unsigned int k_OutputSize1 = 464u;
84 static constexpr unsigned int k_OutputSize2 = 2048u;
85 static constexpr unsigned int k_OutputSize3 = 2048u;
86
Jim Flynne571d332019-04-15 14:34:17 +010087 LstmInput m_ExpectedOutputs;
88};
89
90template <typename Model>
91class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
92{
93public:
94 template <typename TConstructModelCallable>
95 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
96 : m_ConstructModel(constructModel)
97 {}
98
James Wardc89829f2020-10-12 14:17:36 +010099 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
Jim Flynne571d332019-04-15 14:34:17 +0100100 {
James Wardc89829f2020-10-12 14:17:36 +0100101 options
102 .allow_unrecognised_options()
103 .add_options()
104 ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq",
105 cxxopts::value<std::string>(m_InputSeqDir))
106 ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH",
107 cxxopts::value<std::string>(m_PrevStateHDir))
108 ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC",
109 cxxopts::value<std::string>(m_PrevStateCDir))
110 ("l,logits-dir", "Path to directory containing test data for m_Logits",
111 cxxopts::value<std::string>(m_LogitsDir))
112 ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH",
113 cxxopts::value<std::string>(m_NewStateHDir))
114 ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC",
115 cxxopts::value<std::string>(m_NewStateCDir));
Jim Flynne571d332019-04-15 14:34:17 +0100116
James Wardc89829f2020-10-12 14:17:36 +0100117 required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir",
118 "new-state-h-dir", "new-state-c-dir"});
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100119
James Wardc89829f2020-10-12 14:17:36 +0100120 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
Jim Flynne571d332019-04-15 14:34:17 +0100121 }
122
Jim Flynnc2ebc632019-04-17 10:16:58 +0100123 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
Jim Flynne571d332019-04-15 14:34:17 +0100124 {
125 if (!ValidateDirectory(m_InputSeqDir))
126 {
127 return false;
128 }
129
130 if (!ValidateDirectory(m_PrevStateCDir))
131 {
132 return false;
133 }
134
135 if (!ValidateDirectory(m_PrevStateHDir))
136 {
137 return false;
138 }
139
140 if (!ValidateDirectory(m_LogitsDir))
141 {
142 return false;
143 }
144
145 if (!ValidateDirectory(m_NewStateCDir))
146 {
147 return false;
148 }
149
150 if (!ValidateDirectory(m_NewStateHDir))
151 {
152 return false;
153 }
154
Jim Flynnc2ebc632019-04-17 10:16:58 +0100155 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
Jim Flynne571d332019-04-15 14:34:17 +0100156 if (!m_Model)
157 {
158 return false;
159 }
Narumol Prangnawarat04a8b052019-04-26 13:48:57 +0100160 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
161 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
162 m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
Jim Flynne571d332019-04-15 14:34:17 +0100163 if (!m_Database)
164 {
165 return false;
166 }
167
168 return true;
169 }
170
171 std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
172 {
173 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
174 if (!testCaseData)
175 {
176 return nullptr;
177 }
178
179 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
180 }
181
182private:
183 typename Model::CommandLineOptions m_ModelCommandLineOptions;
Jim Flynnc2ebc632019-04-17 10:16:58 +0100184 std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
185 typename Model::CommandLineOptions)> m_ConstructModel;
Jim Flynne571d332019-04-15 14:34:17 +0100186 std::unique_ptr<Model> m_Model;
187
188 std::string m_InputSeqDir;
189 std::string m_PrevStateCDir;
190 std::string m_PrevStateHDir;
191 std::string m_LogitsDir;
192 std::string m_NewStateCDir;
193 std::string m_NewStateHDir;
194
195 std::unique_ptr<DeepSpeechV1Database> m_Database;
196};
197
198} // anonymous namespace