blob: eeb641e878c7e9cb058dfac6ef23068e9d08f549 [file] [log] [blame]
Jim Flynn64063552020-02-14 10:18:08 +00001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armagancab588a2020-02-17 11:33:31 +00008#include <Holder.hpp>
Jim Flynn64063552020-02-14 10:18:08 +00009#include <IProfilingConnectionFactory.hpp>
Sadik Armagan3184c902020-03-18 10:57:30 +000010#include <ProfilingService.hpp>
Jim Flynn64063552020-02-14 10:18:08 +000011#include <ProfilingGuidGenerator.hpp>
12#include <ProfilingUtils.hpp>
13#include <SendCounterPacket.hpp>
14#include <SendThread.hpp>
15
16#include <armnn/Exceptions.hpp>
17#include <armnn/Optional.hpp>
18#include <armnn/Conversion.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000019#include <armnn/utility/IgnoreUnused.hpp>
Jim Flynn64063552020-02-14 10:18:08 +000020
21#include <boost/assert.hpp>
Jim Flynn64063552020-02-14 10:18:08 +000022#include <boost/numeric/conversion/cast.hpp>
23
24#include <atomic>
25#include <condition_variable>
26#include <mutex>
27#include <thread>
28
29namespace armnn
30{
31
32namespace profiling
33{
34
35class MockProfilingConnection : public IProfilingConnection
36{
37public:
38 MockProfilingConnection()
39 : m_IsOpen(true)
40 , m_WrittenData()
41 , m_Packet()
42 {}
43
44 enum class PacketType
45 {
46 StreamMetaData,
47 ConnectionAcknowledge,
48 CounterDirectory,
49 ReqCounterDirectory,
50 PeriodicCounterSelection,
51 PerJobCounterSelection,
52 TimelineMessageDirectory,
53 PeriodicCounterCapture,
54 Unknown
55 };
56
57 bool IsOpen() const override
58 {
59 std::lock_guard<std::mutex> lock(m_Mutex);
60
61 return m_IsOpen;
62 }
63
64 void Close() override
65 {
66 std::lock_guard<std::mutex> lock(m_Mutex);
67
68 m_IsOpen = false;
69 }
70
71 bool WritePacket(const unsigned char* buffer, uint32_t length) override
72 {
73 if (buffer == nullptr || length == 0)
74 {
75 return false;
76 }
77
78 uint32_t header = ReadUint32(buffer, 0);
79
80 uint32_t packetFamily = (header >> 26);
81 uint32_t packetId = ((header >> 16) & 1023);
82
83 PacketType packetType;
84
85 switch (packetFamily)
86 {
87 case 0:
88 packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown;
89 break;
90 case 1:
91 packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
92 break;
93 case 3:
94 packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
95 break;
96 default:
97 packetType = PacketType::Unknown;
98 }
99
100 std::lock_guard<std::mutex> lock(m_Mutex);
101
102 m_WrittenData.push_back({ packetType, length });
103 return true;
104 }
105
106 long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
107 {
108 std::lock_guard<std::mutex> lock(m_Mutex);
109
110 if(packetInfo.second != 0)
111 {
112 return std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo);
113 }
114 else
115 {
116 return std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
117 [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; });
118 }
119 }
120
121 bool WritePacket(Packet&& packet)
122 {
123 std::lock_guard<std::mutex> lock(m_Mutex);
124
125 m_Packet = std::move(packet);
126 return true;
127 }
128
129 Packet ReadPacket(uint32_t timeout) override
130 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000131 IgnoreUnused(timeout);
Jim Flynn64063552020-02-14 10:18:08 +0000132
133 // Simulate a delay in the reading process. The default timeout is way too long.
134 std::this_thread::sleep_for(std::chrono::milliseconds(5));
135 std::lock_guard<std::mutex> lock(m_Mutex);
136 return std::move(m_Packet);
137 }
138
139 unsigned long GetWrittenDataSize()
140 {
141 std::lock_guard<std::mutex> lock(m_Mutex);
142
143 return m_WrittenData.size();
144 }
145
146 void Clear()
147 {
148 std::lock_guard<std::mutex> lock(m_Mutex);
149
150 m_WrittenData.clear();
151 }
152
153private:
154 bool m_IsOpen;
155 std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
156 Packet m_Packet;
157 mutable std::mutex m_Mutex;
158};
159
160class MockProfilingConnectionFactory : public IProfilingConnectionFactory
161{
162public:
163 IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
164 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000165 IgnoreUnused(options);
Jim Flynn64063552020-02-14 10:18:08 +0000166 return std::make_unique<MockProfilingConnection>();
167 }
168};
169
170class MockPacketBuffer : public IPacketBuffer
171{
172public:
173 MockPacketBuffer(unsigned int maxSize)
174 : m_MaxSize(maxSize)
175 , m_Size(0)
176 , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
177 {}
178
179 ~MockPacketBuffer() {}
180
181 const unsigned char* GetReadableData() const override { return m_Data.get(); }
182
183 unsigned int GetSize() const override { return m_Size; }
184
185 void MarkRead() override { m_Size = 0; }
186
187 void Commit(unsigned int size) override { m_Size = size; }
188
189 void Release() override { m_Size = 0; }
190
191 unsigned char* GetWritableData() override { return m_Data.get(); }
192
193private:
194 unsigned int m_MaxSize;
195 unsigned int m_Size;
196 std::unique_ptr<unsigned char[]> m_Data;
197};
198
199class MockBufferManager : public IBufferManager
200{
201public:
202 MockBufferManager(unsigned int size)
203 : m_BufferSize(size),
204 m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
205
206 ~MockBufferManager() {}
207
208 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
209 {
210 if (requestedSize > m_BufferSize)
211 {
212 reservedSize = m_BufferSize;
213 }
214 else
215 {
216 reservedSize = requestedSize;
217 }
218
219 return std::move(m_Buffer);
220 }
221
222 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
223 {
224 packetBuffer->Commit(size);
225 m_Buffer = std::move(packetBuffer);
226
227 if (notifyConsumer)
228 {
229 FlushReadList();
230 }
231 }
232
233 IPacketBufferPtr GetReadableBuffer() override
234 {
235 return std::move(m_Buffer);
236 }
237
238 void Release(IPacketBufferPtr& packetBuffer) override
239 {
240 packetBuffer->Release();
241 m_Buffer = std::move(packetBuffer);
242 }
243
244 void MarkRead(IPacketBufferPtr& packetBuffer) override
245 {
246 packetBuffer->MarkRead();
247 m_Buffer = std::move(packetBuffer);
248 }
249
250 void SetConsumer(IConsumer* consumer) override
251 {
252 if (consumer != nullptr)
253 {
254 m_Consumer = consumer;
255 }
256 }
257
258 void FlushReadList() override
259 {
260 // notify consumer that packet is ready to read
261 if (m_Consumer != nullptr)
262 {
263 m_Consumer->SetReadyToRead();
264 }
265 }
266
267private:
268 unsigned int m_BufferSize;
269 IPacketBufferPtr m_Buffer;
270 IConsumer* m_Consumer = nullptr;
271};
272
273class MockStreamCounterBuffer : public IBufferManager
274{
275public:
276 MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
277 : m_MaxBufferSize(maxBufferSize)
278 , m_BufferList()
279 , m_CommittedSize(0)
280 , m_ReadableSize(0)
281 , m_ReadSize(0)
282 {}
283 ~MockStreamCounterBuffer() {}
284
285 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
286 {
287 std::lock_guard<std::mutex> lock(m_Mutex);
288
289 reservedSize = 0;
290 if (requestedSize > m_MaxBufferSize)
291 {
292 throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
293 std::to_string(m_MaxBufferSize) + "] bytes");
294 }
295 reservedSize = requestedSize;
296 return std::make_unique<MockPacketBuffer>(requestedSize);
297 }
298
299 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
300 {
301 std::lock_guard<std::mutex> lock(m_Mutex);
302
303 packetBuffer->Commit(size);
304 m_BufferList.push_back(std::move(packetBuffer));
305 m_CommittedSize += size;
306
307 if (notifyConsumer)
308 {
309 FlushReadList();
310 }
311 }
312
313 void Release(IPacketBufferPtr& packetBuffer) override
314 {
315 std::lock_guard<std::mutex> lock(m_Mutex);
316
317 packetBuffer->Release();
318 }
319
320 IPacketBufferPtr GetReadableBuffer() override
321 {
322 std::lock_guard<std::mutex> lock(m_Mutex);
323
324 if (m_BufferList.empty())
325 {
326 return nullptr;
327 }
328 IPacketBufferPtr buffer = std::move(m_BufferList.back());
329 m_BufferList.pop_back();
330 m_ReadableSize += buffer->GetSize();
331 return buffer;
332 }
333
334 void MarkRead(IPacketBufferPtr& packetBuffer) override
335 {
336 std::lock_guard<std::mutex> lock(m_Mutex);
337
338 m_ReadSize += packetBuffer->GetSize();
339 packetBuffer->MarkRead();
340 }
341
342 void SetConsumer(IConsumer* consumer) override
343 {
344 if (consumer != nullptr)
345 {
346 m_Consumer = consumer;
347 }
348 }
349
350 void FlushReadList() override
351 {
352 // notify consumer that packet is ready to read
353 if (m_Consumer != nullptr)
354 {
355 m_Consumer->SetReadyToRead();
356 }
357 }
358
359 unsigned int GetCommittedSize() const { return m_CommittedSize; }
360 unsigned int GetReadableSize() const { return m_ReadableSize; }
361 unsigned int GetReadSize() const { return m_ReadSize; }
362
363private:
364 // The maximum buffer size when creating a new buffer
365 unsigned int m_MaxBufferSize;
366
367 // A list of buffers
368 std::vector<IPacketBufferPtr> m_BufferList;
369
370 // The mutex to synchronize this mock's methods
371 std::mutex m_Mutex;
372
373 // The total size of the buffers that has been committed for reading
374 unsigned int m_CommittedSize;
375
376 // The total size of the buffers that can be read
377 unsigned int m_ReadableSize;
378
379 // The total size of the buffers that has already been read
380 unsigned int m_ReadSize;
381
382 // Consumer thread to notify packet is ready to read
383 IConsumer* m_Consumer = nullptr;
384};
385
386class MockSendCounterPacket : public ISendCounterPacket
387{
388public:
389 MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
390
391 void SendStreamMetaDataPacket() override
392 {
393 std::string message("SendStreamMetaDataPacket");
394 unsigned int reserved = 0;
395 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
396 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
397 m_BufferManager.Commit(buffer, reserved, false);
398 }
399
400 void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
401 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000402 IgnoreUnused(counterDirectory);
Jim Flynn64063552020-02-14 10:18:08 +0000403
404 std::string message("SendCounterDirectoryPacket");
405 unsigned int reserved = 0;
406 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
407 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
408 m_BufferManager.Commit(buffer, reserved);
409 }
410
411 void SendPeriodicCounterCapturePacket(uint64_t timestamp,
412 const std::vector<CounterValue>& values) override
413 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000414 IgnoreUnused(timestamp, values);
Jim Flynn64063552020-02-14 10:18:08 +0000415
416 std::string message("SendPeriodicCounterCapturePacket");
417 unsigned int reserved = 0;
418 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
419 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
420 m_BufferManager.Commit(buffer, reserved);
421 }
422
423 void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
424 const std::vector<uint16_t>& selectedCounterIds) override
425 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000426 IgnoreUnused(capturePeriod, selectedCounterIds);
Jim Flynn64063552020-02-14 10:18:08 +0000427
428 std::string message("SendPeriodicCounterSelectionPacket");
429 unsigned int reserved = 0;
430 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
431 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
432 m_BufferManager.Commit(buffer, reserved);
433 }
434
435private:
436 IBufferManager& m_BufferManager;
437};
438
439class MockCounterDirectory : public ICounterDirectory
440{
441public:
442 MockCounterDirectory() = default;
443 ~MockCounterDirectory() = default;
444
445 // Register profiling objects
Sadik Armagan4c998992020-02-25 12:44:44 +0000446 const Category* RegisterCategory(const std::string& categoryName)
Jim Flynn64063552020-02-14 10:18:08 +0000447 {
Jim Flynn64063552020-02-14 10:18:08 +0000448 // Create the category
Sadik Armagan4c998992020-02-25 12:44:44 +0000449 CategoryPtr category = std::make_unique<Category>(categoryName);
Jim Flynn64063552020-02-14 10:18:08 +0000450 BOOST_ASSERT(category);
451
452 // Get the raw category pointer
453 const Category* categoryPtr = category.get();
454 BOOST_ASSERT(categoryPtr);
455
456 // Register the category
457 m_Categories.insert(std::move(category));
458
459 return categoryPtr;
460 }
461
462 const Device* RegisterDevice(const std::string& deviceName,
Sadik Armagan4c998992020-02-25 12:44:44 +0000463 uint16_t cores = 0)
Jim Flynn64063552020-02-14 10:18:08 +0000464 {
465 // Get the device UID
466 uint16_t deviceUid = GetNextUid();
467
468 // Create the device
469 DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
470 BOOST_ASSERT(device);
471
472 // Get the raw device pointer
473 const Device* devicePtr = device.get();
474 BOOST_ASSERT(devicePtr);
475
476 // Register the device
477 m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
478
Jim Flynn64063552020-02-14 10:18:08 +0000479 return devicePtr;
480 }
481
482 const CounterSet* RegisterCounterSet(
483 const std::string& counterSetName,
Sadik Armagan4c998992020-02-25 12:44:44 +0000484 uint16_t count = 0)
Jim Flynn64063552020-02-14 10:18:08 +0000485 {
486 // Get the counter set UID
487 uint16_t counterSetUid = GetNextUid();
488
489 // Create the counter set
490 CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
491 BOOST_ASSERT(counterSet);
492
493 // Get the raw counter set pointer
494 const CounterSet* counterSetPtr = counterSet.get();
495 BOOST_ASSERT(counterSetPtr);
496
497 // Register the counter set
498 m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
499
Jim Flynn64063552020-02-14 10:18:08 +0000500 return counterSetPtr;
501 }
502
503 const Counter* RegisterCounter(const BackendId& backendId,
504 const uint16_t uid,
505 const std::string& parentCategoryName,
506 uint16_t counterClass,
507 uint16_t interpolation,
508 double multiplier,
509 const std::string& name,
510 const std::string& description,
511 const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
512 const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
513 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
514 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
515 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000516 IgnoreUnused(backendId);
Jim Flynn64063552020-02-14 10:18:08 +0000517
518 // Get the number of cores from the argument only
519 uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
520
521 // Get the device UID
522 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
523
524 // Get the counter set UID
525 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
526
527 // Get the counter UIDs and calculate the max counter UID
528 std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
529 BOOST_ASSERT(!counterUids.empty());
530 uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
531
532 // Get the counter units
533 const std::string unitsValue = units.has_value() ? units.value() : "";
534
535 // Create the counter
536 CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
537 counterUids.front(),
538 maxCounterUid,
539 counterClass,
540 interpolation,
541 multiplier,
542 name,
543 description,
544 unitsValue,
545 deviceUidValue,
546 counterSetUidValue);
547 BOOST_ASSERT(counter);
548
549 // Get the raw counter pointer
550 const Counter* counterPtr = counter.get();
551 BOOST_ASSERT(counterPtr);
552
553 // Process multiple counters if necessary
554 for (uint16_t counterUid : counterUids)
555 {
556 // Connect the counter to the parent category
557 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
558 BOOST_ASSERT(parentCategory);
559 parentCategory->m_Counters.push_back(counterUid);
560
561 // Register the counter
562 m_Counters.insert(std::make_pair(counterUid, counter));
563 }
564
565 return counterPtr;
566 }
567
568 // Getters for counts
569 uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
570 uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
571 uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
572 uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
573
574 // Getters for collections
575 const Categories& GetCategories() const override { return m_Categories; }
576 const Devices& GetDevices() const override { return m_Devices; }
577 const CounterSets& GetCounterSets() const override { return m_CounterSets; }
578 const Counters& GetCounters() const override { return m_Counters; }
579
580 // Getters for profiling objects
581 const Category* GetCategory(const std::string& name) const override
582 {
583 auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
584 {
585 BOOST_ASSERT(category);
586
587 return category->m_Name == name;
588 });
589
590 if (it == m_Categories.end())
591 {
592 return nullptr;
593 }
594
595 return it->get();
596 }
597
598 const Device* GetDevice(uint16_t uid) const override
599 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000600 IgnoreUnused(uid);
Jim Flynn64063552020-02-14 10:18:08 +0000601 return nullptr; // Not used by the unit tests
602 }
603
604 const CounterSet* GetCounterSet(uint16_t uid) const override
605 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000606 IgnoreUnused(uid);
Jim Flynn64063552020-02-14 10:18:08 +0000607 return nullptr; // Not used by the unit tests
608 }
609
610 const Counter* GetCounter(uint16_t uid) const override
611 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000612 IgnoreUnused(uid);
Jim Flynn64063552020-02-14 10:18:08 +0000613 return nullptr; // Not used by the unit tests
614 }
615
616private:
617 Categories m_Categories;
618 Devices m_Devices;
619 CounterSets m_CounterSets;
620 Counters m_Counters;
621};
622
Sadik Armagan3184c902020-03-18 10:57:30 +0000623class MockProfilingService : public ProfilingService
Sadik Armagancab588a2020-02-17 11:33:31 +0000624{
625public:
626 MockProfilingService(MockBufferManager& mockBufferManager,
627 bool isProfilingEnabled,
628 const CaptureData& captureData) :
629 m_SendCounterPacket(mockBufferManager),
630 m_IsProfilingEnabled(isProfilingEnabled),
631 m_CaptureData(captureData) {}
632
633 /// Return the next random Guid in the sequence
634 ProfilingDynamicGuid NextGuid() override
635 {
636 return m_GuidGenerator.NextGuid();
637 }
638
639 /// Create a ProfilingStaticGuid based on a hash of the string
640 ProfilingStaticGuid GenerateStaticId(const std::string& str) override
641 {
642 return m_GuidGenerator.GenerateStaticId(str);
643 }
644
645 std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override
646 {
647 return nullptr;
648 }
649
650 const ICounterMappings& GetCounterMappings() const override
651 {
652 return m_CounterMapping;
653 }
654
655 ISendCounterPacket& GetSendCounterPacket() override
656 {
657 return m_SendCounterPacket;
658 }
659
660 bool IsProfilingEnabled() const override
661 {
662 return m_IsProfilingEnabled;
663 }
664
665 CaptureData GetCaptureData() override
666 {
667 CaptureData copy(m_CaptureData);
668 return copy;
669 }
670
671 void RegisterMapping(uint16_t globalCounterId,
672 uint16_t backendCounterId,
Sadik Armagan3184c902020-03-18 10:57:30 +0000673 const armnn::BackendId& backendId)
Sadik Armagancab588a2020-02-17 11:33:31 +0000674 {
675 m_CounterMapping.RegisterMapping(globalCounterId, backendCounterId, backendId);
676 }
677
Sadik Armagan3184c902020-03-18 10:57:30 +0000678 void Reset()
Sadik Armagancab588a2020-02-17 11:33:31 +0000679 {
680 m_CounterMapping.Reset();
681 }
682
683private:
684 ProfilingGuidGenerator m_GuidGenerator;
685 CounterIdMap m_CounterMapping;
686 SendCounterPacket m_SendCounterPacket;
687 bool m_IsProfilingEnabled;
688 CaptureData m_CaptureData;
689};
690
Jim Flynn64063552020-02-14 10:18:08 +0000691} // namespace profiling
692
693} // namespace armnn