//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <SendCounterPacket.hpp>
#include <ProfilingUtils.hpp>

#include <armnn/Exceptions.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Conversion.hpp>

#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>

#include <algorithm>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
Ferran Balaguer1b941722019-08-28 16:57:18 +010016
Matteo Martincigh24e8f922019-09-19 11:57:46 +010017namespace armnn
18{
Ferran Balaguer1b941722019-08-28 16:57:18 +010019
Matteo Martincigh24e8f922019-09-19 11:57:46 +010020namespace profiling
21{
22
23class MockProfilingConnection : public IProfilingConnection
24{
25public:
26 MockProfilingConnection()
27 : m_IsOpen(true)
28 {}
29
30 bool IsOpen() override { return m_IsOpen; }
31
32 void Close() override { m_IsOpen = false; }
33
34 bool WritePacket(const unsigned char* buffer, uint32_t length) override
35 {
36 return buffer != nullptr && length > 0;
37 }
38
39 Packet ReadPacket(uint32_t timeout) override { return Packet(); }
40
41private:
42 bool m_IsOpen;
43};
Ferran Balaguer1b941722019-08-28 16:57:18 +010044
45class MockBuffer : public IBufferWrapper
46{
47public:
48 MockBuffer(unsigned int size)
Matteo Martincigh24e8f922019-09-19 11:57:46 +010049 : m_BufferSize(size)
50 , m_Buffer(std::make_unique<unsigned char[]>(size))
51 {}
Ferran Balaguer1b941722019-08-28 16:57:18 +010052
53 unsigned char* Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
54 {
55 if (requestedSize > m_BufferSize)
56 {
57 reservedSize = m_BufferSize;
58 }
59 else
60 {
61 reservedSize = requestedSize;
62 }
63
64 return m_Buffer.get();
65 }
66
67 void Commit(unsigned int size) override {}
68
69 const unsigned char* GetReadBuffer(unsigned int& size) override
70 {
71 size = static_cast<unsigned int>(strlen(reinterpret_cast<const char*>(m_Buffer.get())) + 1);
72 return m_Buffer.get();
73 }
74
Matteo Martincigh24e8f922019-09-19 11:57:46 +010075 void Release(unsigned int size) override {}
Ferran Balaguer1b941722019-08-28 16:57:18 +010076
77private:
78 unsigned int m_BufferSize;
79 std::unique_ptr<unsigned char[]> m_Buffer;
80};
81
Matteo Martincigh24e8f922019-09-19 11:57:46 +010082class MockStreamCounterBuffer : public IBufferWrapper
83{
84public:
85 MockStreamCounterBuffer(unsigned int size)
86 : m_Buffer(size, 0)
87 , m_CommittedSize(0)
88 , m_ReadSize(0)
89 {}
90
91 unsigned char* Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
92 {
93 std::unique_lock<std::mutex>(m_Mutex);
94
95 // Get the buffer size and the available size in the buffer past the committed size
96 size_t bufferSize = m_Buffer.size();
97 size_t availableSize = bufferSize - m_CommittedSize;
98
99 // Check whether the buffer needs to be resized
100 if (requestedSize > availableSize)
101 {
102 // Resize the buffer
103 size_t newSize = m_CommittedSize + requestedSize;
104 m_Buffer.resize(newSize, 0);
105 }
106
107 // Set the reserved size
108 reservedSize = requestedSize;
109
110 // Get a pointer to the beginning of the part of buffer available for writing
111 return m_Buffer.data() + m_CommittedSize;
112 }
113
114 void Commit(unsigned int size) override
115 {
116 std::unique_lock<std::mutex>(m_Mutex);
117
118 // Update the committed size
119 m_CommittedSize += size;
120 }
121
122 const unsigned char* GetReadBuffer(unsigned int& size) override
123 {
124 std::unique_lock<std::mutex>(m_Mutex);
125
126 // Get the size available for reading
127 size = boost::numeric_cast<unsigned int>(m_CommittedSize - m_ReadSize);
128
129 // Get a pointer to the beginning of the part of buffer available for reading
130 const unsigned char* readBuffer = m_Buffer.data() + m_ReadSize;
131
132 // Update the read size
133 m_ReadSize = m_CommittedSize;
134
135 return readBuffer;
136 }
137
138 void Release(unsigned int size) override
139 {
140 std::unique_lock<std::mutex>(m_Mutex);
141
142 if (size == 0)
143 {
144 // Nothing to release
145 return;
146 }
147
148 // Get the buffer size
149 size_t bufferSize = m_Buffer.size();
150
151 // Remove the last "size" bytes from the buffer
152 if (size < bufferSize)
153 {
154 // Resize the buffer
155 size_t newSize = bufferSize - size;
156 m_Buffer.resize(newSize);
157 }
158 else
159 {
160 // Clear the whole buffer
161 m_Buffer.clear();
162 }
163 }
164
165 size_t GetBufferSize() const { return m_Buffer.size(); }
166 size_t GetCommittedSize() const { return m_CommittedSize; }
167 size_t GetReadSize() const { return m_ReadSize; }
168 const unsigned char* GetBuffer() const { return m_Buffer.data(); }
169
170private:
171 // This mock uses an ever-expanding vector to simulate a counter stream buffer
172 std::vector<unsigned char> m_Buffer;
173
174 // The size of the buffer that has been committed for reading
175 size_t m_CommittedSize;
176
177 // The size of the buffer that has already been read
178 size_t m_ReadSize;
179
180 // This mock buffer provides basic synchronization
181 std::mutex m_Mutex;
182};
183
Ferran Balaguer1b941722019-08-28 16:57:18 +0100184class MockSendCounterPacket : public ISendCounterPacket
185{
186public:
187 MockSendCounterPacket(IBufferWrapper& sendBuffer) : m_Buffer(sendBuffer) {}
188
189 void SendStreamMetaDataPacket() override
190 {
191 std::string message("SendStreamMetaDataPacket");
192 unsigned int reserved = 0;
193 unsigned char* buffer = m_Buffer.Reserve(1024, reserved);
194 memcpy(buffer, message.c_str(), static_cast<unsigned int>(message.size()) + 1);
195 }
196
Matteo Martincigh42f9d9e2019-09-05 12:02:04 +0100197 void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
Ferran Balaguer1b941722019-08-28 16:57:18 +0100198 {
199 std::string message("SendCounterDirectoryPacket");
200 unsigned int reserved = 0;
201 unsigned char* buffer = m_Buffer.Reserve(1024, reserved);
202 memcpy(buffer, message.c_str(), static_cast<unsigned int>(message.size()) + 1);
203 }
204
205 void SendPeriodicCounterCapturePacket(uint64_t timestamp,
206 const std::vector<std::pair<uint16_t, uint32_t>>& values) override
207 {
208 std::string message("SendPeriodicCounterCapturePacket");
209 unsigned int reserved = 0;
210 unsigned char* buffer = m_Buffer.Reserve(1024, reserved);
211 memcpy(buffer, message.c_str(), static_cast<unsigned int>(message.size()) + 1);
212 }
213
214 void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
215 const std::vector<uint16_t>& selectedCounterIds) override
216 {
217 std::string message("SendPeriodicCounterSelectionPacket");
218 unsigned int reserved = 0;
219 unsigned char* buffer = m_Buffer.Reserve(1024, reserved);
220 memcpy(buffer, message.c_str(), static_cast<unsigned int>(message.size()) + 1);
221 m_Buffer.Commit(reserved);
222 }
223
Matteo Martincigh24e8f922019-09-19 11:57:46 +0100224 void SetReadyToRead() override {}
Ferran Balaguer1b941722019-08-28 16:57:18 +0100225
226private:
227 IBufferWrapper& m_Buffer;
228};
Matteo Martincigh42f9d9e2019-09-05 12:02:04 +0100229
230class MockCounterDirectory : public ICounterDirectory
231{
232public:
233 MockCounterDirectory() = default;
234 ~MockCounterDirectory() = default;
235
236 // Register profiling objects
237 const Category* RegisterCategory(const std::string& categoryName,
238 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
239 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
240 {
241 // Get the device UID
242 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
243
244 // Get the counter set UID
245 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
246
247 // Create the category
248 CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
249 BOOST_ASSERT(category);
250
251 // Get the raw category pointer
252 const Category* categoryPtr = category.get();
253 BOOST_ASSERT(categoryPtr);
254
255 // Register the category
256 m_Categories.insert(std::move(category));
257
258 return categoryPtr;
259 }
260
261 const Device* RegisterDevice(const std::string& deviceName,
262 uint16_t cores = 0,
263 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
264 {
265 // Get the device UID
266 uint16_t deviceUid = GetNextUid();
267
268 // Create the device
269 DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
270 BOOST_ASSERT(device);
271
272 // Get the raw device pointer
273 const Device* devicePtr = device.get();
274 BOOST_ASSERT(devicePtr);
275
276 // Register the device
277 m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
278
279 // Connect the counter set to the parent category, if required
280 if (parentCategoryName.has_value())
281 {
282 // Set the counter set UID in the parent category
283 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
284 BOOST_ASSERT(parentCategory);
285 parentCategory->m_DeviceUid = deviceUid;
286 }
287
288 return devicePtr;
289 }
290
291 const CounterSet* RegisterCounterSet(
292 const std::string& counterSetName,
293 uint16_t count = 0,
294 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
295 {
296 // Get the counter set UID
297 uint16_t counterSetUid = GetNextUid();
298
299 // Create the counter set
300 CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
301 BOOST_ASSERT(counterSet);
302
303 // Get the raw counter set pointer
304 const CounterSet* counterSetPtr = counterSet.get();
305 BOOST_ASSERT(counterSetPtr);
306
307 // Register the counter set
308 m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
309
310 // Connect the counter set to the parent category, if required
311 if (parentCategoryName.has_value())
312 {
313 // Set the counter set UID in the parent category
314 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
315 BOOST_ASSERT(parentCategory);
316 parentCategory->m_CounterSetUid = counterSetUid;
317 }
318
319 return counterSetPtr;
320 }
321
322 const Counter* RegisterCounter(const std::string& parentCategoryName,
323 uint16_t counterClass,
324 uint16_t interpolation,
325 double multiplier,
326 const std::string& name,
327 const std::string& description,
328 const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
329 const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
330 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
331 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
332 {
333 // Get the number of cores from the argument only
334 uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
335
336 // Get the device UID
337 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
338
339 // Get the counter set UID
340 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
341
342 // Get the counter UIDs and calculate the max counter UID
343 std::vector<uint16_t> counterUids = GetNextCounterUids(deviceCores);
344 BOOST_ASSERT(!counterUids.empty());
345 uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
346
347 // Get the counter units
348 const std::string unitsValue = units.has_value() ? units.value() : "";
349
350 // Create the counter
351 CounterPtr counter = std::make_shared<Counter>(counterUids.front(),
352 maxCounterUid,
353 counterClass,
354 interpolation,
355 multiplier,
356 name,
357 description,
358 unitsValue,
359 deviceUidValue,
360 counterSetUidValue);
361 BOOST_ASSERT(counter);
362
363 // Get the raw counter pointer
364 const Counter* counterPtr = counter.get();
365 BOOST_ASSERT(counterPtr);
366
367 // Process multiple counters if necessary
368 for (uint16_t counterUid : counterUids)
369 {
370 // Connect the counter to the parent category
371 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
372 BOOST_ASSERT(parentCategory);
373 parentCategory->m_Counters.push_back(counterUid);
374
375 // Register the counter
376 m_Counters.insert(std::make_pair(counterUid, counter));
377 }
378
379 return counterPtr;
380 }
381
382 // Getters for counts
383 uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
384 uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
385 uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
386 uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
387
388 // Getters for collections
389 const Categories& GetCategories() const override { return m_Categories; }
390 const Devices& GetDevices() const override { return m_Devices; }
391 const CounterSets& GetCounterSets() const override { return m_CounterSets; }
392 const Counters& GetCounters() const override { return m_Counters; }
393
394 // Getters for profiling objects
395 const Category* GetCategory(const std::string& name) const override
396 {
397 auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
398 {
399 BOOST_ASSERT(category);
400
401 return category->m_Name == name;
402 });
403
404 if (it == m_Categories.end())
405 {
406 return nullptr;
407 }
408
409 return it->get();
410 }
411
412 const Device* GetDevice(uint16_t uid) const override
413 {
414 return nullptr; // Not used by the unit tests
415 }
416
417 const CounterSet* GetCounterSet(uint16_t uid) const override
418 {
419 return nullptr; // Not used by the unit tests
420 }
421
422 const Counter* GetCounter(uint16_t uid) const override
423 {
424 return nullptr; // Not used by the unit tests
425 }
426
427private:
428 Categories m_Categories;
429 Devices m_Devices;
430 CounterSets m_CounterSets;
431 Counters m_Counters;
432};
433
434class SendCounterPacketTest : public SendCounterPacket
435{
436public:
Matteo Martincigh24e8f922019-09-19 11:57:46 +0100437 SendCounterPacketTest(IProfilingConnection& profilingconnection, IBufferWrapper& buffer)
438 : SendCounterPacket(profilingconnection, buffer)
Matteo Martincigh42f9d9e2019-09-05 12:02:04 +0100439 {}
440
441 bool CreateDeviceRecordTest(const DevicePtr& device,
442 DeviceRecord& deviceRecord,
443 std::string& errorMessage)
444 {
445 return CreateDeviceRecord(device, deviceRecord, errorMessage);
446 }
447
448 bool CreateCounterSetRecordTest(const CounterSetPtr& counterSet,
449 CounterSetRecord& counterSetRecord,
450 std::string& errorMessage)
451 {
452 return CreateCounterSetRecord(counterSet, counterSetRecord, errorMessage);
453 }
454
455 bool CreateEventRecordTest(const CounterPtr& counter,
456 EventRecord& eventRecord,
457 std::string& errorMessage)
458 {
459 return CreateEventRecord(counter, eventRecord, errorMessage);
460 }
461
462 bool CreateCategoryRecordTest(const CategoryPtr& category,
463 const Counters& counters,
464 CategoryRecord& categoryRecord,
465 std::string& errorMessage)
466 {
467 return CreateCategoryRecord(category, counters, categoryRecord, errorMessage);
468 }
469};
Matteo Martincigh24e8f922019-09-19 11:57:46 +0100470
471} // namespace profiling
472
473} // namespace armnn