blob: 6ac1ebb08bdbcc211179bbc47a37e58df0bb246f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "DriverTestHelpers.hpp"
#include <boost/test/unit_test.hpp>
#include <log/log.h>

#include "../Utils.hpp"

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iterator>
#include <sstream>
#include <string>
#include <utility>
#include <boost/format.hpp>
#include <armnn/INetwork.hpp>
16
17BOOST_AUTO_TEST_SUITE(UtilsTests)
18
surmeh0176660052018-03-29 16:33:54 +010019using namespace android;
telsoa01ce3e84a2018-08-31 09:31:35 +010020using namespace android::nn;
21using namespace android::hardware;
22using namespace armnn_driver;
surmeh0176660052018-03-29 16:33:54 +010023
24// The following are helpers for writing unit tests for the driver.
25namespace
26{
27
28struct ExportNetworkGraphFixture
29{
30public:
31 // Setup: set the output dump directory and an empty dummy model (as only its memory address is used).
telsoa01ce3e84a2018-08-31 09:31:35 +010032 // Defaulting the output dump directory to "/data" because it should exist and be writable in all deployments.
surmeh0176660052018-03-29 16:33:54 +010033 ExportNetworkGraphFixture()
telsoa01ce3e84a2018-08-31 09:31:35 +010034 : ExportNetworkGraphFixture("/data")
surmeh0176660052018-03-29 16:33:54 +010035 {}
36 ExportNetworkGraphFixture(const std::string& requestInputsAndOutputsDumpDir)
37 : m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
38 , m_Model({})
39 , m_FileName()
40 , m_FileStream()
41 {
42 // Get the memory address of the model and convert it to a hex string (of at least a '0' character).
43 size_t modelAddress = uintptr_t(&m_Model);
44 std::stringstream ss;
45 ss << std::uppercase << std::hex << std::setfill('0') << std::setw(1) << modelAddress;
46 std::string modelAddressHexString = ss.str();
47
48 // Set the name of the output .dot file.
49 m_FileName = boost::str(boost::format("%1%/networkgraph_%2%.dot")
50 % m_RequestInputsAndOutputsDumpDir
51 % modelAddressHexString);
52 }
53
54 // Teardown: delete the dump file regardless of the outcome of the tests.
55 ~ExportNetworkGraphFixture()
56 {
57 // Close the file stream.
58 m_FileStream.close();
59
60 // Ignore any error (such as file not found).
surmeh0149b9e102018-05-17 14:11:25 +010061 (void)remove(m_FileName.c_str());
surmeh0176660052018-03-29 16:33:54 +010062 }
63
64 bool FileExists()
65 {
66 // Close any file opened in a previous session.
67 if (m_FileStream.is_open())
68 {
69 m_FileStream.close();
70 }
71
72 // Open the file.
73 m_FileStream.open(m_FileName, std::ifstream::in);
74
75 // Check that the file is open.
76 if (!m_FileStream.is_open())
77 {
78 return false;
79 }
80
81 // Check that the stream is readable.
82 return m_FileStream.good();
83 }
84
85 std::string GetFileContent()
86 {
87 // Check that the stream is readable.
88 if (!m_FileStream.good())
89 {
90 return "";
91 }
92
93 // Get all the contents of the file.
94 return std::string((std::istreambuf_iterator<char>(m_FileStream)),
95 (std::istreambuf_iterator<char>()));
96 }
97
98 std::string m_RequestInputsAndOutputsDumpDir;
Matteo Martincigh8b287c22018-09-07 09:25:10 +010099 V1_0::Model m_Model;
surmeh0176660052018-03-29 16:33:54 +0100100
101private:
102 std::string m_FileName;
103 std::ifstream m_FileStream;
104};
105
106class MockOptimizedNetwork final : public armnn::IOptimizedNetwork
107{
108public:
109 MockOptimizedNetwork(const std::string& mockSerializedContent)
110 : m_MockSerializedContent(mockSerializedContent)
111 {}
112 ~MockOptimizedNetwork() {}
113
114 armnn::Status PrintGraph() override { return armnn::Status::Failure; }
115 armnn::Status SerializeToDot(std::ostream& stream) const override
116 {
117 stream << m_MockSerializedContent;
118
119 return stream.good() ? armnn::Status::Success : armnn::Status::Failure;
120 }
121
122 void UpdateMockSerializedContent(const std::string& mockSerializedContent)
123 {
124 this->m_MockSerializedContent = mockSerializedContent;
125 }
126
127private:
128 std::string m_MockSerializedContent;
129};
130
131} // namespace
132
133BOOST_AUTO_TEST_CASE(ExportToEmptyDirectory)
134{
135 // Set the fixture for this test.
136 ExportNetworkGraphFixture fixture("");
137
138 // Set a mock content for the optimized network.
139 std::string mockSerializedContent = "This is a mock serialized content.";
140
141 // Set a mock optimized network.
142 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
143
144 // Export the mock optimized network.
145 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
146 fixture.m_RequestInputsAndOutputsDumpDir,
147 fixture.m_Model);
148
149 // Check that the output file does not exist.
150 BOOST_TEST(!fixture.FileExists());
151}
152
153BOOST_AUTO_TEST_CASE(ExportNetwork)
154{
155 // Set the fixture for this test.
156 ExportNetworkGraphFixture fixture;
157
158 // Set a mock content for the optimized network.
159 std::string mockSerializedContent = "This is a mock serialized content.";
160
161 // Set a mock optimized network.
162 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
163
164 // Export the mock optimized network.
165 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
166 fixture.m_RequestInputsAndOutputsDumpDir,
167 fixture.m_Model);
168
169 // Check that the output file exists and that it has the correct name.
170 BOOST_TEST(fixture.FileExists());
171
172 // Check that the content of the output file matches the mock content.
173 BOOST_TEST(fixture.GetFileContent() == mockSerializedContent);
174}
175
176BOOST_AUTO_TEST_CASE(ExportNetworkOverwriteFile)
177{
178 // Set the fixture for this test.
179 ExportNetworkGraphFixture fixture;
180
181 // Set a mock content for the optimized network.
182 std::string mockSerializedContent = "This is a mock serialized content.";
183
184 // Set a mock optimized network.
185 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
186
187 // Export the mock optimized network.
188 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
189 fixture.m_RequestInputsAndOutputsDumpDir,
190 fixture.m_Model);
191
192 // Check that the output file exists and that it has the correct name.
193 BOOST_TEST(fixture.FileExists());
194
195 // Check that the content of the output file matches the mock content.
196 BOOST_TEST(fixture.GetFileContent() == mockSerializedContent);
197
198 // Update the mock serialized content of the network.
199 mockSerializedContent = "This is ANOTHER mock serialized content!";
200 mockOptimizedNetwork.UpdateMockSerializedContent(mockSerializedContent);
201
202 // Export the mock optimized network.
203 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
204 fixture.m_RequestInputsAndOutputsDumpDir,
205 fixture.m_Model);
206
207 // Check that the output file still exists and that it has the correct name.
208 BOOST_TEST(fixture.FileExists());
209
210 // Check that the content of the output file matches the mock content.
211 BOOST_TEST(fixture.GetFileContent() == mockSerializedContent);
212}
213
214BOOST_AUTO_TEST_CASE(ExportMultipleNetworks)
215{
216 // Set the fixtures for this test.
217 ExportNetworkGraphFixture fixture1;
218 ExportNetworkGraphFixture fixture2;
219 ExportNetworkGraphFixture fixture3;
220
221 // Set a mock content for the optimized network.
222 std::string mockSerializedContent = "This is a mock serialized content.";
223
224 // Set a mock optimized network.
225 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
226
227 // Export the mock optimized network.
228 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
229 fixture1.m_RequestInputsAndOutputsDumpDir,
230 fixture1.m_Model);
231
232 // Check that the output file exists and that it has the correct name.
233 BOOST_TEST(fixture1.FileExists());
234
235 // Check that the content of the output file matches the mock content.
236 BOOST_TEST(fixture1.GetFileContent() == mockSerializedContent);
237
238 // Export the mock optimized network.
239 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
240 fixture2.m_RequestInputsAndOutputsDumpDir,
241 fixture2.m_Model);
242
243 // Check that the output file exists and that it has the correct name.
244 BOOST_TEST(fixture2.FileExists());
245
246 // Check that the content of the output file matches the mock content.
247 BOOST_TEST(fixture2.GetFileContent() == mockSerializedContent);
248
249 // Export the mock optimized network.
250 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
251 fixture3.m_RequestInputsAndOutputsDumpDir,
252 fixture3.m_Model);
253 // Check that the output file exists and that it has the correct name.
254 BOOST_TEST(fixture3.FileExists());
255
256 // Check that the content of the output file matches the mock content.
257 BOOST_TEST(fixture3.GetFileContent() == mockSerializedContent);
258}
259
260BOOST_AUTO_TEST_SUITE_END()