blob: 49941e5669de974fcb08ff6357dbde3037de7a25 [file] [log] [blame]
surmeh0176660052018-03-29 16:33:54 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#define LOG_TAG "ArmnnDriverUtilsTests"
7//#define BOOST_TEST_MODULE armnn_driver_utils_tests
8#include <boost/test/unit_test.hpp>
9#include <log/log.h>
10
11#include "../ArmnnDriver.hpp"
12#include "../Utils.hpp"
13
14#include <fstream>
15#include <iomanip>
16#include <boost/format.hpp>
17#include <armnn/INetwork.hpp>
18
19BOOST_AUTO_TEST_SUITE(UtilsTests)
20
21using namespace armnn_driver;
22using namespace android::nn;
23using namespace android;
24
25// The following are helpers for writing unit tests for the driver.
26namespace
27{
28
29struct ExportNetworkGraphFixture
30{
31public:
32 // Setup: set the output dump directory and an empty dummy model (as only its memory address is used).
33 // Defaulting the output dump directory to "/sdcard" because it should exist and be writable in all deployments.
34 ExportNetworkGraphFixture()
35 : ExportNetworkGraphFixture("/sdcard")
36 {}
37 ExportNetworkGraphFixture(const std::string& requestInputsAndOutputsDumpDir)
38 : m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
39 , m_Model({})
40 , m_FileName()
41 , m_FileStream()
42 {
43 // Get the memory address of the model and convert it to a hex string (of at least a '0' character).
44 size_t modelAddress = uintptr_t(&m_Model);
45 std::stringstream ss;
46 ss << std::uppercase << std::hex << std::setfill('0') << std::setw(1) << modelAddress;
47 std::string modelAddressHexString = ss.str();
48
49 // Set the name of the output .dot file.
50 m_FileName = boost::str(boost::format("%1%/networkgraph_%2%.dot")
51 % m_RequestInputsAndOutputsDumpDir
52 % modelAddressHexString);
53 }
54
55 // Teardown: delete the dump file regardless of the outcome of the tests.
56 ~ExportNetworkGraphFixture()
57 {
58 // Close the file stream.
59 m_FileStream.close();
60
61 // Ignore any error (such as file not found).
62 remove(m_FileName.c_str());
63 }
64
65 bool FileExists()
66 {
67 // Close any file opened in a previous session.
68 if (m_FileStream.is_open())
69 {
70 m_FileStream.close();
71 }
72
73 // Open the file.
74 m_FileStream.open(m_FileName, std::ifstream::in);
75
76 // Check that the file is open.
77 if (!m_FileStream.is_open())
78 {
79 return false;
80 }
81
82 // Check that the stream is readable.
83 return m_FileStream.good();
84 }
85
86 std::string GetFileContent()
87 {
88 // Check that the stream is readable.
89 if (!m_FileStream.good())
90 {
91 return "";
92 }
93
94 // Get all the contents of the file.
95 return std::string((std::istreambuf_iterator<char>(m_FileStream)),
96 (std::istreambuf_iterator<char>()));
97 }
98
99 std::string m_RequestInputsAndOutputsDumpDir;
100 Model m_Model;
101
102private:
103 std::string m_FileName;
104 std::ifstream m_FileStream;
105};
106
107class MockOptimizedNetwork final : public armnn::IOptimizedNetwork
108{
109public:
110 MockOptimizedNetwork(const std::string& mockSerializedContent)
111 : m_MockSerializedContent(mockSerializedContent)
112 {}
113 ~MockOptimizedNetwork() {}
114
115 armnn::Status PrintGraph() override { return armnn::Status::Failure; }
116 armnn::Status SerializeToDot(std::ostream& stream) const override
117 {
118 stream << m_MockSerializedContent;
119
120 return stream.good() ? armnn::Status::Success : armnn::Status::Failure;
121 }
122
123 void UpdateMockSerializedContent(const std::string& mockSerializedContent)
124 {
125 this->m_MockSerializedContent = mockSerializedContent;
126 }
127
128private:
129 std::string m_MockSerializedContent;
130};
131
132} // namespace
133
134BOOST_AUTO_TEST_CASE(ExportToEmptyDirectory)
135{
136 // Set the fixture for this test.
137 ExportNetworkGraphFixture fixture("");
138
139 // Set a mock content for the optimized network.
140 std::string mockSerializedContent = "This is a mock serialized content.";
141
142 // Set a mock optimized network.
143 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
144
145 // Export the mock optimized network.
146 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
147 fixture.m_RequestInputsAndOutputsDumpDir,
148 fixture.m_Model);
149
150 // Check that the output file does not exist.
151 BOOST_TEST(!fixture.FileExists());
152}
153
154BOOST_AUTO_TEST_CASE(ExportNetwork)
155{
156 // Set the fixture for this test.
157 ExportNetworkGraphFixture fixture;
158
159 // Set a mock content for the optimized network.
160 std::string mockSerializedContent = "This is a mock serialized content.";
161
162 // Set a mock optimized network.
163 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
164
165 // Export the mock optimized network.
166 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
167 fixture.m_RequestInputsAndOutputsDumpDir,
168 fixture.m_Model);
169
170 // Check that the output file exists and that it has the correct name.
171 BOOST_TEST(fixture.FileExists());
172
173 // Check that the content of the output file matches the mock content.
174 BOOST_TEST(fixture.GetFileContent() == mockSerializedContent);
175}
176
177BOOST_AUTO_TEST_CASE(ExportNetworkOverwriteFile)
178{
179 // Set the fixture for this test.
180 ExportNetworkGraphFixture fixture;
181
182 // Set a mock content for the optimized network.
183 std::string mockSerializedContent = "This is a mock serialized content.";
184
185 // Set a mock optimized network.
186 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
187
188 // Export the mock optimized network.
189 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
190 fixture.m_RequestInputsAndOutputsDumpDir,
191 fixture.m_Model);
192
193 // Check that the output file exists and that it has the correct name.
194 BOOST_TEST(fixture.FileExists());
195
196 // Check that the content of the output file matches the mock content.
197 BOOST_TEST(fixture.GetFileContent() == mockSerializedContent);
198
199 // Update the mock serialized content of the network.
200 mockSerializedContent = "This is ANOTHER mock serialized content!";
201 mockOptimizedNetwork.UpdateMockSerializedContent(mockSerializedContent);
202
203 // Export the mock optimized network.
204 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
205 fixture.m_RequestInputsAndOutputsDumpDir,
206 fixture.m_Model);
207
208 // Check that the output file still exists and that it has the correct name.
209 BOOST_TEST(fixture.FileExists());
210
211 // Check that the content of the output file matches the mock content.
212 BOOST_TEST(fixture.GetFileContent() == mockSerializedContent);
213}
214
215BOOST_AUTO_TEST_CASE(ExportMultipleNetworks)
216{
217 // Set the fixtures for this test.
218 ExportNetworkGraphFixture fixture1;
219 ExportNetworkGraphFixture fixture2;
220 ExportNetworkGraphFixture fixture3;
221
222 // Set a mock content for the optimized network.
223 std::string mockSerializedContent = "This is a mock serialized content.";
224
225 // Set a mock optimized network.
226 MockOptimizedNetwork mockOptimizedNetwork(mockSerializedContent);
227
228 // Export the mock optimized network.
229 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
230 fixture1.m_RequestInputsAndOutputsDumpDir,
231 fixture1.m_Model);
232
233 // Check that the output file exists and that it has the correct name.
234 BOOST_TEST(fixture1.FileExists());
235
236 // Check that the content of the output file matches the mock content.
237 BOOST_TEST(fixture1.GetFileContent() == mockSerializedContent);
238
239 // Export the mock optimized network.
240 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
241 fixture2.m_RequestInputsAndOutputsDumpDir,
242 fixture2.m_Model);
243
244 // Check that the output file exists and that it has the correct name.
245 BOOST_TEST(fixture2.FileExists());
246
247 // Check that the content of the output file matches the mock content.
248 BOOST_TEST(fixture2.GetFileContent() == mockSerializedContent);
249
250 // Export the mock optimized network.
251 armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
252 fixture3.m_RequestInputsAndOutputsDumpDir,
253 fixture3.m_Model);
254 // Check that the output file exists and that it has the correct name.
255 BOOST_TEST(fixture3.FileExists());
256
257 // Check that the content of the output file matches the mock content.
258 BOOST_TEST(fixture3.GetFileContent() == mockSerializedContent);
259}
260
261BOOST_AUTO_TEST_SUITE_END()