blob: 9d1321345aed5708985a87cf8906880124463489 [file] [log] [blame]
Jim Flynn64063552020-02-14 10:18:08 +00001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <IProfilingConnectionFactory.hpp>
#include <IProfilingService.hpp>
#include <ProfilingGuidGenerator.hpp>
#include <ProfilingUtils.hpp>
#include <SendCounterPacket.hpp>
#include <SendThread.hpp>

#include <armnn/Exceptions.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Conversion.hpp>

#include <boost/assert.hpp>
#include <boost/core/ignore_unused.hpp>
#include <boost/numeric/conversion/cast.hpp>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
27
28namespace armnn
29{
30
31namespace profiling
32{
33
34class MockProfilingConnection : public IProfilingConnection
35{
36public:
37 MockProfilingConnection()
38 : m_IsOpen(true)
39 , m_WrittenData()
40 , m_Packet()
41 {}
42
43 enum class PacketType
44 {
45 StreamMetaData,
46 ConnectionAcknowledge,
47 CounterDirectory,
48 ReqCounterDirectory,
49 PeriodicCounterSelection,
50 PerJobCounterSelection,
51 TimelineMessageDirectory,
52 PeriodicCounterCapture,
53 Unknown
54 };
55
56 bool IsOpen() const override
57 {
58 std::lock_guard<std::mutex> lock(m_Mutex);
59
60 return m_IsOpen;
61 }
62
63 void Close() override
64 {
65 std::lock_guard<std::mutex> lock(m_Mutex);
66
67 m_IsOpen = false;
68 }
69
70 bool WritePacket(const unsigned char* buffer, uint32_t length) override
71 {
72 if (buffer == nullptr || length == 0)
73 {
74 return false;
75 }
76
77 uint32_t header = ReadUint32(buffer, 0);
78
79 uint32_t packetFamily = (header >> 26);
80 uint32_t packetId = ((header >> 16) & 1023);
81
82 PacketType packetType;
83
84 switch (packetFamily)
85 {
86 case 0:
87 packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown;
88 break;
89 case 1:
90 packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
91 break;
92 case 3:
93 packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
94 break;
95 default:
96 packetType = PacketType::Unknown;
97 }
98
99 std::lock_guard<std::mutex> lock(m_Mutex);
100
101 m_WrittenData.push_back({ packetType, length });
102 return true;
103 }
104
105 long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
106 {
107 std::lock_guard<std::mutex> lock(m_Mutex);
108
109 if(packetInfo.second != 0)
110 {
111 return std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo);
112 }
113 else
114 {
115 return std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
116 [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; });
117 }
118 }
119
120 bool WritePacket(Packet&& packet)
121 {
122 std::lock_guard<std::mutex> lock(m_Mutex);
123
124 m_Packet = std::move(packet);
125 return true;
126 }
127
128 Packet ReadPacket(uint32_t timeout) override
129 {
130 boost::ignore_unused(timeout);
131
132 // Simulate a delay in the reading process. The default timeout is way too long.
133 std::this_thread::sleep_for(std::chrono::milliseconds(5));
134 std::lock_guard<std::mutex> lock(m_Mutex);
135 return std::move(m_Packet);
136 }
137
138 unsigned long GetWrittenDataSize()
139 {
140 std::lock_guard<std::mutex> lock(m_Mutex);
141
142 return m_WrittenData.size();
143 }
144
145 void Clear()
146 {
147 std::lock_guard<std::mutex> lock(m_Mutex);
148
149 m_WrittenData.clear();
150 }
151
152private:
153 bool m_IsOpen;
154 std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
155 Packet m_Packet;
156 mutable std::mutex m_Mutex;
157};
158
159class MockProfilingConnectionFactory : public IProfilingConnectionFactory
160{
161public:
162 IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
163 {
164 boost::ignore_unused(options);
165 return std::make_unique<MockProfilingConnection>();
166 }
167};
168
169class MockPacketBuffer : public IPacketBuffer
170{
171public:
172 MockPacketBuffer(unsigned int maxSize)
173 : m_MaxSize(maxSize)
174 , m_Size(0)
175 , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
176 {}
177
178 ~MockPacketBuffer() {}
179
180 const unsigned char* GetReadableData() const override { return m_Data.get(); }
181
182 unsigned int GetSize() const override { return m_Size; }
183
184 void MarkRead() override { m_Size = 0; }
185
186 void Commit(unsigned int size) override { m_Size = size; }
187
188 void Release() override { m_Size = 0; }
189
190 unsigned char* GetWritableData() override { return m_Data.get(); }
191
192private:
193 unsigned int m_MaxSize;
194 unsigned int m_Size;
195 std::unique_ptr<unsigned char[]> m_Data;
196};
197
198class MockBufferManager : public IBufferManager
199{
200public:
201 MockBufferManager(unsigned int size)
202 : m_BufferSize(size),
203 m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
204
205 ~MockBufferManager() {}
206
207 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
208 {
209 if (requestedSize > m_BufferSize)
210 {
211 reservedSize = m_BufferSize;
212 }
213 else
214 {
215 reservedSize = requestedSize;
216 }
217
218 return std::move(m_Buffer);
219 }
220
221 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
222 {
223 packetBuffer->Commit(size);
224 m_Buffer = std::move(packetBuffer);
225
226 if (notifyConsumer)
227 {
228 FlushReadList();
229 }
230 }
231
232 IPacketBufferPtr GetReadableBuffer() override
233 {
234 return std::move(m_Buffer);
235 }
236
237 void Release(IPacketBufferPtr& packetBuffer) override
238 {
239 packetBuffer->Release();
240 m_Buffer = std::move(packetBuffer);
241 }
242
243 void MarkRead(IPacketBufferPtr& packetBuffer) override
244 {
245 packetBuffer->MarkRead();
246 m_Buffer = std::move(packetBuffer);
247 }
248
249 void SetConsumer(IConsumer* consumer) override
250 {
251 if (consumer != nullptr)
252 {
253 m_Consumer = consumer;
254 }
255 }
256
257 void FlushReadList() override
258 {
259 // notify consumer that packet is ready to read
260 if (m_Consumer != nullptr)
261 {
262 m_Consumer->SetReadyToRead();
263 }
264 }
265
266private:
267 unsigned int m_BufferSize;
268 IPacketBufferPtr m_Buffer;
269 IConsumer* m_Consumer = nullptr;
270};
271
272class MockStreamCounterBuffer : public IBufferManager
273{
274public:
275 MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
276 : m_MaxBufferSize(maxBufferSize)
277 , m_BufferList()
278 , m_CommittedSize(0)
279 , m_ReadableSize(0)
280 , m_ReadSize(0)
281 {}
282 ~MockStreamCounterBuffer() {}
283
284 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
285 {
286 std::lock_guard<std::mutex> lock(m_Mutex);
287
288 reservedSize = 0;
289 if (requestedSize > m_MaxBufferSize)
290 {
291 throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
292 std::to_string(m_MaxBufferSize) + "] bytes");
293 }
294 reservedSize = requestedSize;
295 return std::make_unique<MockPacketBuffer>(requestedSize);
296 }
297
298 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
299 {
300 std::lock_guard<std::mutex> lock(m_Mutex);
301
302 packetBuffer->Commit(size);
303 m_BufferList.push_back(std::move(packetBuffer));
304 m_CommittedSize += size;
305
306 if (notifyConsumer)
307 {
308 FlushReadList();
309 }
310 }
311
312 void Release(IPacketBufferPtr& packetBuffer) override
313 {
314 std::lock_guard<std::mutex> lock(m_Mutex);
315
316 packetBuffer->Release();
317 }
318
319 IPacketBufferPtr GetReadableBuffer() override
320 {
321 std::lock_guard<std::mutex> lock(m_Mutex);
322
323 if (m_BufferList.empty())
324 {
325 return nullptr;
326 }
327 IPacketBufferPtr buffer = std::move(m_BufferList.back());
328 m_BufferList.pop_back();
329 m_ReadableSize += buffer->GetSize();
330 return buffer;
331 }
332
333 void MarkRead(IPacketBufferPtr& packetBuffer) override
334 {
335 std::lock_guard<std::mutex> lock(m_Mutex);
336
337 m_ReadSize += packetBuffer->GetSize();
338 packetBuffer->MarkRead();
339 }
340
341 void SetConsumer(IConsumer* consumer) override
342 {
343 if (consumer != nullptr)
344 {
345 m_Consumer = consumer;
346 }
347 }
348
349 void FlushReadList() override
350 {
351 // notify consumer that packet is ready to read
352 if (m_Consumer != nullptr)
353 {
354 m_Consumer->SetReadyToRead();
355 }
356 }
357
358 unsigned int GetCommittedSize() const { return m_CommittedSize; }
359 unsigned int GetReadableSize() const { return m_ReadableSize; }
360 unsigned int GetReadSize() const { return m_ReadSize; }
361
362private:
363 // The maximum buffer size when creating a new buffer
364 unsigned int m_MaxBufferSize;
365
366 // A list of buffers
367 std::vector<IPacketBufferPtr> m_BufferList;
368
369 // The mutex to synchronize this mock's methods
370 std::mutex m_Mutex;
371
372 // The total size of the buffers that has been committed for reading
373 unsigned int m_CommittedSize;
374
375 // The total size of the buffers that can be read
376 unsigned int m_ReadableSize;
377
378 // The total size of the buffers that has already been read
379 unsigned int m_ReadSize;
380
381 // Consumer thread to notify packet is ready to read
382 IConsumer* m_Consumer = nullptr;
383};
384
385class MockSendCounterPacket : public ISendCounterPacket
386{
387public:
388 MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
389
390 void SendStreamMetaDataPacket() override
391 {
392 std::string message("SendStreamMetaDataPacket");
393 unsigned int reserved = 0;
394 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
395 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
396 m_BufferManager.Commit(buffer, reserved, false);
397 }
398
399 void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
400 {
401 boost::ignore_unused(counterDirectory);
402
403 std::string message("SendCounterDirectoryPacket");
404 unsigned int reserved = 0;
405 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
406 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
407 m_BufferManager.Commit(buffer, reserved);
408 }
409
410 void SendPeriodicCounterCapturePacket(uint64_t timestamp,
411 const std::vector<CounterValue>& values) override
412 {
413 boost::ignore_unused(timestamp, values);
414
415 std::string message("SendPeriodicCounterCapturePacket");
416 unsigned int reserved = 0;
417 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
418 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
419 m_BufferManager.Commit(buffer, reserved);
420 }
421
422 void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
423 const std::vector<uint16_t>& selectedCounterIds) override
424 {
425 boost::ignore_unused(capturePeriod, selectedCounterIds);
426
427 std::string message("SendPeriodicCounterSelectionPacket");
428 unsigned int reserved = 0;
429 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
430 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
431 m_BufferManager.Commit(buffer, reserved);
432 }
433
434private:
435 IBufferManager& m_BufferManager;
436};
437
438class MockCounterDirectory : public ICounterDirectory
439{
440public:
441 MockCounterDirectory() = default;
442 ~MockCounterDirectory() = default;
443
444 // Register profiling objects
445 const Category* RegisterCategory(const std::string& categoryName,
446 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
447 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
448 {
449 // Get the device UID
450 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
451
452 // Get the counter set UID
453 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
454
455 // Create the category
456 CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
457 BOOST_ASSERT(category);
458
459 // Get the raw category pointer
460 const Category* categoryPtr = category.get();
461 BOOST_ASSERT(categoryPtr);
462
463 // Register the category
464 m_Categories.insert(std::move(category));
465
466 return categoryPtr;
467 }
468
469 const Device* RegisterDevice(const std::string& deviceName,
470 uint16_t cores = 0,
471 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
472 {
473 // Get the device UID
474 uint16_t deviceUid = GetNextUid();
475
476 // Create the device
477 DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
478 BOOST_ASSERT(device);
479
480 // Get the raw device pointer
481 const Device* devicePtr = device.get();
482 BOOST_ASSERT(devicePtr);
483
484 // Register the device
485 m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
486
487 // Connect the counter set to the parent category, if required
488 if (parentCategoryName.has_value())
489 {
490 // Set the counter set UID in the parent category
491 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
492 BOOST_ASSERT(parentCategory);
493 parentCategory->m_DeviceUid = deviceUid;
494 }
495
496 return devicePtr;
497 }
498
499 const CounterSet* RegisterCounterSet(
500 const std::string& counterSetName,
501 uint16_t count = 0,
502 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
503 {
504 // Get the counter set UID
505 uint16_t counterSetUid = GetNextUid();
506
507 // Create the counter set
508 CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
509 BOOST_ASSERT(counterSet);
510
511 // Get the raw counter set pointer
512 const CounterSet* counterSetPtr = counterSet.get();
513 BOOST_ASSERT(counterSetPtr);
514
515 // Register the counter set
516 m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
517
518 // Connect the counter set to the parent category, if required
519 if (parentCategoryName.has_value())
520 {
521 // Set the counter set UID in the parent category
522 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
523 BOOST_ASSERT(parentCategory);
524 parentCategory->m_CounterSetUid = counterSetUid;
525 }
526
527 return counterSetPtr;
528 }
529
530 const Counter* RegisterCounter(const BackendId& backendId,
531 const uint16_t uid,
532 const std::string& parentCategoryName,
533 uint16_t counterClass,
534 uint16_t interpolation,
535 double multiplier,
536 const std::string& name,
537 const std::string& description,
538 const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
539 const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
540 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
541 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
542 {
543 boost::ignore_unused(backendId);
544
545 // Get the number of cores from the argument only
546 uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
547
548 // Get the device UID
549 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
550
551 // Get the counter set UID
552 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
553
554 // Get the counter UIDs and calculate the max counter UID
555 std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
556 BOOST_ASSERT(!counterUids.empty());
557 uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
558
559 // Get the counter units
560 const std::string unitsValue = units.has_value() ? units.value() : "";
561
562 // Create the counter
563 CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
564 counterUids.front(),
565 maxCounterUid,
566 counterClass,
567 interpolation,
568 multiplier,
569 name,
570 description,
571 unitsValue,
572 deviceUidValue,
573 counterSetUidValue);
574 BOOST_ASSERT(counter);
575
576 // Get the raw counter pointer
577 const Counter* counterPtr = counter.get();
578 BOOST_ASSERT(counterPtr);
579
580 // Process multiple counters if necessary
581 for (uint16_t counterUid : counterUids)
582 {
583 // Connect the counter to the parent category
584 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
585 BOOST_ASSERT(parentCategory);
586 parentCategory->m_Counters.push_back(counterUid);
587
588 // Register the counter
589 m_Counters.insert(std::make_pair(counterUid, counter));
590 }
591
592 return counterPtr;
593 }
594
595 // Getters for counts
596 uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
597 uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
598 uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
599 uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
600
601 // Getters for collections
602 const Categories& GetCategories() const override { return m_Categories; }
603 const Devices& GetDevices() const override { return m_Devices; }
604 const CounterSets& GetCounterSets() const override { return m_CounterSets; }
605 const Counters& GetCounters() const override { return m_Counters; }
606
607 // Getters for profiling objects
608 const Category* GetCategory(const std::string& name) const override
609 {
610 auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
611 {
612 BOOST_ASSERT(category);
613
614 return category->m_Name == name;
615 });
616
617 if (it == m_Categories.end())
618 {
619 return nullptr;
620 }
621
622 return it->get();
623 }
624
625 const Device* GetDevice(uint16_t uid) const override
626 {
627 boost::ignore_unused(uid);
628 return nullptr; // Not used by the unit tests
629 }
630
631 const CounterSet* GetCounterSet(uint16_t uid) const override
632 {
633 boost::ignore_unused(uid);
634 return nullptr; // Not used by the unit tests
635 }
636
637 const Counter* GetCounter(uint16_t uid) const override
638 {
639 boost::ignore_unused(uid);
640 return nullptr; // Not used by the unit tests
641 }
642
643private:
644 Categories m_Categories;
645 Devices m_Devices;
646 CounterSets m_CounterSets;
647 Counters m_Counters;
648};
649
650} // namespace profiling
651
652} // namespace armnn