blob: 3782a0f7e40edfe36684f3d6a5caf3dc4b168eb3 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Sadik Armagancab588a2020-02-17 11:33:31 +00008#include <Holder.hpp>
Jim Flynn64063552020-02-14 10:18:08 +00009#include <IProfilingConnectionFactory.hpp>
10#include <IProfilingService.hpp>
11#include <ProfilingGuidGenerator.hpp>
12#include <ProfilingUtils.hpp>
13#include <SendCounterPacket.hpp>
14#include <SendThread.hpp>
15
16#include <armnn/Exceptions.hpp>
17#include <armnn/Optional.hpp>
18#include <armnn/Conversion.hpp>
19
20#include <boost/assert.hpp>
21#include <boost/core/ignore_unused.hpp>
22#include <boost/numeric/conversion/cast.hpp>
23
24#include <atomic>
25#include <condition_variable>
26#include <mutex>
27#include <thread>
28
29namespace armnn
30{
31
32namespace profiling
33{
34
35class MockProfilingConnection : public IProfilingConnection
36{
37public:
38 MockProfilingConnection()
39 : m_IsOpen(true)
40 , m_WrittenData()
41 , m_Packet()
42 {}
43
44 enum class PacketType
45 {
46 StreamMetaData,
47 ConnectionAcknowledge,
48 CounterDirectory,
49 ReqCounterDirectory,
50 PeriodicCounterSelection,
51 PerJobCounterSelection,
52 TimelineMessageDirectory,
53 PeriodicCounterCapture,
54 Unknown
55 };
56
57 bool IsOpen() const override
58 {
59 std::lock_guard<std::mutex> lock(m_Mutex);
60
61 return m_IsOpen;
62 }
63
64 void Close() override
65 {
66 std::lock_guard<std::mutex> lock(m_Mutex);
67
68 m_IsOpen = false;
69 }
70
71 bool WritePacket(const unsigned char* buffer, uint32_t length) override
72 {
73 if (buffer == nullptr || length == 0)
74 {
75 return false;
76 }
77
78 uint32_t header = ReadUint32(buffer, 0);
79
80 uint32_t packetFamily = (header >> 26);
81 uint32_t packetId = ((header >> 16) & 1023);
82
83 PacketType packetType;
84
85 switch (packetFamily)
86 {
87 case 0:
88 packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown;
89 break;
90 case 1:
91 packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
92 break;
93 case 3:
94 packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
95 break;
96 default:
97 packetType = PacketType::Unknown;
98 }
99
100 std::lock_guard<std::mutex> lock(m_Mutex);
101
102 m_WrittenData.push_back({ packetType, length });
103 return true;
104 }
105
106 long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
107 {
108 std::lock_guard<std::mutex> lock(m_Mutex);
109
110 if(packetInfo.second != 0)
111 {
112 return std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo);
113 }
114 else
115 {
116 return std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
117 [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; });
118 }
119 }
120
121 bool WritePacket(Packet&& packet)
122 {
123 std::lock_guard<std::mutex> lock(m_Mutex);
124
125 m_Packet = std::move(packet);
126 return true;
127 }
128
129 Packet ReadPacket(uint32_t timeout) override
130 {
131 boost::ignore_unused(timeout);
132
133 // Simulate a delay in the reading process. The default timeout is way too long.
134 std::this_thread::sleep_for(std::chrono::milliseconds(5));
135 std::lock_guard<std::mutex> lock(m_Mutex);
136 return std::move(m_Packet);
137 }
138
139 unsigned long GetWrittenDataSize()
140 {
141 std::lock_guard<std::mutex> lock(m_Mutex);
142
143 return m_WrittenData.size();
144 }
145
146 void Clear()
147 {
148 std::lock_guard<std::mutex> lock(m_Mutex);
149
150 m_WrittenData.clear();
151 }
152
153private:
154 bool m_IsOpen;
155 std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
156 Packet m_Packet;
157 mutable std::mutex m_Mutex;
158};
159
160class MockProfilingConnectionFactory : public IProfilingConnectionFactory
161{
162public:
163 IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
164 {
165 boost::ignore_unused(options);
166 return std::make_unique<MockProfilingConnection>();
167 }
168};
169
170class MockPacketBuffer : public IPacketBuffer
171{
172public:
173 MockPacketBuffer(unsigned int maxSize)
174 : m_MaxSize(maxSize)
175 , m_Size(0)
176 , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
177 {}
178
179 ~MockPacketBuffer() {}
180
181 const unsigned char* GetReadableData() const override { return m_Data.get(); }
182
183 unsigned int GetSize() const override { return m_Size; }
184
185 void MarkRead() override { m_Size = 0; }
186
187 void Commit(unsigned int size) override { m_Size = size; }
188
189 void Release() override { m_Size = 0; }
190
191 unsigned char* GetWritableData() override { return m_Data.get(); }
192
193private:
194 unsigned int m_MaxSize;
195 unsigned int m_Size;
196 std::unique_ptr<unsigned char[]> m_Data;
197};
198
199class MockBufferManager : public IBufferManager
200{
201public:
202 MockBufferManager(unsigned int size)
203 : m_BufferSize(size),
204 m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
205
206 ~MockBufferManager() {}
207
208 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
209 {
210 if (requestedSize > m_BufferSize)
211 {
212 reservedSize = m_BufferSize;
213 }
214 else
215 {
216 reservedSize = requestedSize;
217 }
218
219 return std::move(m_Buffer);
220 }
221
222 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
223 {
224 packetBuffer->Commit(size);
225 m_Buffer = std::move(packetBuffer);
226
227 if (notifyConsumer)
228 {
229 FlushReadList();
230 }
231 }
232
233 IPacketBufferPtr GetReadableBuffer() override
234 {
235 return std::move(m_Buffer);
236 }
237
238 void Release(IPacketBufferPtr& packetBuffer) override
239 {
240 packetBuffer->Release();
241 m_Buffer = std::move(packetBuffer);
242 }
243
244 void MarkRead(IPacketBufferPtr& packetBuffer) override
245 {
246 packetBuffer->MarkRead();
247 m_Buffer = std::move(packetBuffer);
248 }
249
250 void SetConsumer(IConsumer* consumer) override
251 {
252 if (consumer != nullptr)
253 {
254 m_Consumer = consumer;
255 }
256 }
257
258 void FlushReadList() override
259 {
260 // notify consumer that packet is ready to read
261 if (m_Consumer != nullptr)
262 {
263 m_Consumer->SetReadyToRead();
264 }
265 }
266
267private:
268 unsigned int m_BufferSize;
269 IPacketBufferPtr m_Buffer;
270 IConsumer* m_Consumer = nullptr;
271};
272
273class MockStreamCounterBuffer : public IBufferManager
274{
275public:
276 MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
277 : m_MaxBufferSize(maxBufferSize)
278 , m_BufferList()
279 , m_CommittedSize(0)
280 , m_ReadableSize(0)
281 , m_ReadSize(0)
282 {}
283 ~MockStreamCounterBuffer() {}
284
285 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
286 {
287 std::lock_guard<std::mutex> lock(m_Mutex);
288
289 reservedSize = 0;
290 if (requestedSize > m_MaxBufferSize)
291 {
292 throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
293 std::to_string(m_MaxBufferSize) + "] bytes");
294 }
295 reservedSize = requestedSize;
296 return std::make_unique<MockPacketBuffer>(requestedSize);
297 }
298
299 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
300 {
301 std::lock_guard<std::mutex> lock(m_Mutex);
302
303 packetBuffer->Commit(size);
304 m_BufferList.push_back(std::move(packetBuffer));
305 m_CommittedSize += size;
306
307 if (notifyConsumer)
308 {
309 FlushReadList();
310 }
311 }
312
313 void Release(IPacketBufferPtr& packetBuffer) override
314 {
315 std::lock_guard<std::mutex> lock(m_Mutex);
316
317 packetBuffer->Release();
318 }
319
320 IPacketBufferPtr GetReadableBuffer() override
321 {
322 std::lock_guard<std::mutex> lock(m_Mutex);
323
324 if (m_BufferList.empty())
325 {
326 return nullptr;
327 }
328 IPacketBufferPtr buffer = std::move(m_BufferList.back());
329 m_BufferList.pop_back();
330 m_ReadableSize += buffer->GetSize();
331 return buffer;
332 }
333
334 void MarkRead(IPacketBufferPtr& packetBuffer) override
335 {
336 std::lock_guard<std::mutex> lock(m_Mutex);
337
338 m_ReadSize += packetBuffer->GetSize();
339 packetBuffer->MarkRead();
340 }
341
342 void SetConsumer(IConsumer* consumer) override
343 {
344 if (consumer != nullptr)
345 {
346 m_Consumer = consumer;
347 }
348 }
349
350 void FlushReadList() override
351 {
352 // notify consumer that packet is ready to read
353 if (m_Consumer != nullptr)
354 {
355 m_Consumer->SetReadyToRead();
356 }
357 }
358
359 unsigned int GetCommittedSize() const { return m_CommittedSize; }
360 unsigned int GetReadableSize() const { return m_ReadableSize; }
361 unsigned int GetReadSize() const { return m_ReadSize; }
362
363private:
364 // The maximum buffer size when creating a new buffer
365 unsigned int m_MaxBufferSize;
366
367 // A list of buffers
368 std::vector<IPacketBufferPtr> m_BufferList;
369
370 // The mutex to synchronize this mock's methods
371 std::mutex m_Mutex;
372
373 // The total size of the buffers that has been committed for reading
374 unsigned int m_CommittedSize;
375
376 // The total size of the buffers that can be read
377 unsigned int m_ReadableSize;
378
379 // The total size of the buffers that has already been read
380 unsigned int m_ReadSize;
381
382 // Consumer thread to notify packet is ready to read
383 IConsumer* m_Consumer = nullptr;
384};
385
386class MockSendCounterPacket : public ISendCounterPacket
387{
388public:
389 MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
390
391 void SendStreamMetaDataPacket() override
392 {
393 std::string message("SendStreamMetaDataPacket");
394 unsigned int reserved = 0;
395 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
396 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
397 m_BufferManager.Commit(buffer, reserved, false);
398 }
399
400 void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
401 {
402 boost::ignore_unused(counterDirectory);
403
404 std::string message("SendCounterDirectoryPacket");
405 unsigned int reserved = 0;
406 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
407 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
408 m_BufferManager.Commit(buffer, reserved);
409 }
410
411 void SendPeriodicCounterCapturePacket(uint64_t timestamp,
412 const std::vector<CounterValue>& values) override
413 {
414 boost::ignore_unused(timestamp, values);
415
416 std::string message("SendPeriodicCounterCapturePacket");
417 unsigned int reserved = 0;
418 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
419 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
420 m_BufferManager.Commit(buffer, reserved);
421 }
422
423 void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
424 const std::vector<uint16_t>& selectedCounterIds) override
425 {
426 boost::ignore_unused(capturePeriod, selectedCounterIds);
427
428 std::string message("SendPeriodicCounterSelectionPacket");
429 unsigned int reserved = 0;
430 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
431 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
432 m_BufferManager.Commit(buffer, reserved);
433 }
434
435private:
436 IBufferManager& m_BufferManager;
437};
438
439class MockCounterDirectory : public ICounterDirectory
440{
441public:
442 MockCounterDirectory() = default;
443 ~MockCounterDirectory() = default;
444
445 // Register profiling objects
446 const Category* RegisterCategory(const std::string& categoryName,
447 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
448 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
449 {
450 // Get the device UID
451 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
452
453 // Get the counter set UID
454 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
455
456 // Create the category
457 CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
458 BOOST_ASSERT(category);
459
460 // Get the raw category pointer
461 const Category* categoryPtr = category.get();
462 BOOST_ASSERT(categoryPtr);
463
464 // Register the category
465 m_Categories.insert(std::move(category));
466
467 return categoryPtr;
468 }
469
470 const Device* RegisterDevice(const std::string& deviceName,
471 uint16_t cores = 0,
472 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
473 {
474 // Get the device UID
475 uint16_t deviceUid = GetNextUid();
476
477 // Create the device
478 DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
479 BOOST_ASSERT(device);
480
481 // Get the raw device pointer
482 const Device* devicePtr = device.get();
483 BOOST_ASSERT(devicePtr);
484
485 // Register the device
486 m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
487
488 // Connect the counter set to the parent category, if required
489 if (parentCategoryName.has_value())
490 {
491 // Set the counter set UID in the parent category
492 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
493 BOOST_ASSERT(parentCategory);
494 parentCategory->m_DeviceUid = deviceUid;
495 }
496
497 return devicePtr;
498 }
499
500 const CounterSet* RegisterCounterSet(
501 const std::string& counterSetName,
502 uint16_t count = 0,
503 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
504 {
505 // Get the counter set UID
506 uint16_t counterSetUid = GetNextUid();
507
508 // Create the counter set
509 CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
510 BOOST_ASSERT(counterSet);
511
512 // Get the raw counter set pointer
513 const CounterSet* counterSetPtr = counterSet.get();
514 BOOST_ASSERT(counterSetPtr);
515
516 // Register the counter set
517 m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
518
519 // Connect the counter set to the parent category, if required
520 if (parentCategoryName.has_value())
521 {
522 // Set the counter set UID in the parent category
523 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
524 BOOST_ASSERT(parentCategory);
525 parentCategory->m_CounterSetUid = counterSetUid;
526 }
527
528 return counterSetPtr;
529 }
530
531 const Counter* RegisterCounter(const BackendId& backendId,
532 const uint16_t uid,
533 const std::string& parentCategoryName,
534 uint16_t counterClass,
535 uint16_t interpolation,
536 double multiplier,
537 const std::string& name,
538 const std::string& description,
539 const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
540 const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
541 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
542 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
543 {
544 boost::ignore_unused(backendId);
545
546 // Get the number of cores from the argument only
547 uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
548
549 // Get the device UID
550 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
551
552 // Get the counter set UID
553 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
554
555 // Get the counter UIDs and calculate the max counter UID
556 std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
557 BOOST_ASSERT(!counterUids.empty());
558 uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
559
560 // Get the counter units
561 const std::string unitsValue = units.has_value() ? units.value() : "";
562
563 // Create the counter
564 CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
565 counterUids.front(),
566 maxCounterUid,
567 counterClass,
568 interpolation,
569 multiplier,
570 name,
571 description,
572 unitsValue,
573 deviceUidValue,
574 counterSetUidValue);
575 BOOST_ASSERT(counter);
576
577 // Get the raw counter pointer
578 const Counter* counterPtr = counter.get();
579 BOOST_ASSERT(counterPtr);
580
581 // Process multiple counters if necessary
582 for (uint16_t counterUid : counterUids)
583 {
584 // Connect the counter to the parent category
585 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
586 BOOST_ASSERT(parentCategory);
587 parentCategory->m_Counters.push_back(counterUid);
588
589 // Register the counter
590 m_Counters.insert(std::make_pair(counterUid, counter));
591 }
592
593 return counterPtr;
594 }
595
596 // Getters for counts
597 uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
598 uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
599 uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
600 uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
601
602 // Getters for collections
603 const Categories& GetCategories() const override { return m_Categories; }
604 const Devices& GetDevices() const override { return m_Devices; }
605 const CounterSets& GetCounterSets() const override { return m_CounterSets; }
606 const Counters& GetCounters() const override { return m_Counters; }
607
608 // Getters for profiling objects
609 const Category* GetCategory(const std::string& name) const override
610 {
611 auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
612 {
613 BOOST_ASSERT(category);
614
615 return category->m_Name == name;
616 });
617
618 if (it == m_Categories.end())
619 {
620 return nullptr;
621 }
622
623 return it->get();
624 }
625
626 const Device* GetDevice(uint16_t uid) const override
627 {
628 boost::ignore_unused(uid);
629 return nullptr; // Not used by the unit tests
630 }
631
632 const CounterSet* GetCounterSet(uint16_t uid) const override
633 {
634 boost::ignore_unused(uid);
635 return nullptr; // Not used by the unit tests
636 }
637
638 const Counter* GetCounter(uint16_t uid) const override
639 {
640 boost::ignore_unused(uid);
641 return nullptr; // Not used by the unit tests
642 }
643
644private:
645 Categories m_Categories;
646 Devices m_Devices;
647 CounterSets m_CounterSets;
648 Counters m_Counters;
649};
650
Sadik Armagancab588a2020-02-17 11:33:31 +0000651class MockProfilingService : public IProfilingService, public IRegisterCounterMapping
652{
653public:
654 MockProfilingService(MockBufferManager& mockBufferManager,
655 bool isProfilingEnabled,
656 const CaptureData& captureData) :
657 m_SendCounterPacket(mockBufferManager),
658 m_IsProfilingEnabled(isProfilingEnabled),
659 m_CaptureData(captureData) {}
660
661 /// Return the next random Guid in the sequence
662 ProfilingDynamicGuid NextGuid() override
663 {
664 return m_GuidGenerator.NextGuid();
665 }
666
667 /// Create a ProfilingStaticGuid based on a hash of the string
668 ProfilingStaticGuid GenerateStaticId(const std::string& str) override
669 {
670 return m_GuidGenerator.GenerateStaticId(str);
671 }
672
673 std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override
674 {
675 return nullptr;
676 }
677
678 const ICounterMappings& GetCounterMappings() const override
679 {
680 return m_CounterMapping;
681 }
682
683 ISendCounterPacket& GetSendCounterPacket() override
684 {
685 return m_SendCounterPacket;
686 }
687
688 bool IsProfilingEnabled() const override
689 {
690 return m_IsProfilingEnabled;
691 }
692
693 CaptureData GetCaptureData() override
694 {
695 CaptureData copy(m_CaptureData);
696 return copy;
697 }
698
699 void RegisterMapping(uint16_t globalCounterId,
700 uint16_t backendCounterId,
701 const armnn::BackendId& backendId) override
702 {
703 m_CounterMapping.RegisterMapping(globalCounterId, backendCounterId, backendId);
704 }
705
706 void Reset() override
707 {
708 m_CounterMapping.Reset();
709 }
710
711private:
712 ProfilingGuidGenerator m_GuidGenerator;
713 CounterIdMap m_CounterMapping;
714 SendCounterPacket m_SendCounterPacket;
715 bool m_IsProfilingEnabled;
716 CaptureData m_CaptureData;
717};
718
Jim Flynn64063552020-02-14 10:18:08 +0000719} // namespace profiling
720
721} // namespace armnn