blob: d218433d93c60a94317a975e3aec2f1009a143a1 [file] [log] [blame]
Ferran Balaguer1b941722019-08-28 16:57:18 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "PeriodicCounterSelectionCommandHandler.hpp"
#include "ProfilingUtils.hpp"

#include <armnn/Types.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/format.hpp>

#include <algorithm>
#include <iterator>
#include <set>
#include <utility>
#include <vector>
Ferran Balaguer1b941722019-08-28 16:57:18 +010014
15namespace armnn
16{
17
18namespace profiling
19{
20
Ferran Balaguer1b941722019-08-28 16:57:18 +010021void PeriodicCounterSelectionCommandHandler::ParseData(const Packet& packet, CaptureData& captureData)
22{
23 std::vector<uint16_t> counterIds;
Matteo Martincighe8485382019-10-10 14:08:21 +010024 uint32_t sizeOfUint32 = boost::numeric_cast<uint32_t>(sizeof(uint32_t));
25 uint32_t sizeOfUint16 = boost::numeric_cast<uint32_t>(sizeof(uint16_t));
Ferran Balaguer1b941722019-08-28 16:57:18 +010026 uint32_t offset = 0;
27
Matteo Martincighe8485382019-10-10 14:08:21 +010028 if (packet.GetLength() < 4)
Ferran Balaguer1b941722019-08-28 16:57:18 +010029 {
Matteo Martincighe8485382019-10-10 14:08:21 +010030 // Insufficient packet size
31 return;
32 }
33
34 // Parse the capture period
35 uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
36
37 // Set the capture period
38 captureData.SetCapturePeriod(capturePeriod);
39
40 // Parse the counter ids
41 unsigned int counters = (packet.GetLength() - 4) / 2;
42 if (counters > 0)
43 {
44 counterIds.reserve(counters);
45 offset += sizeOfUint32;
46 for (unsigned int i = 0; i < counters; ++i)
Ferran Balaguer1b941722019-08-28 16:57:18 +010047 {
Matteo Martincighe8485382019-10-10 14:08:21 +010048 // Parse the counter id
49 uint16_t counterId = ReadUint16(packet.GetData(), offset);
50 counterIds.emplace_back(counterId);
51 offset += sizeOfUint16;
Ferran Balaguer1b941722019-08-28 16:57:18 +010052 }
53 }
Matteo Martincighe8485382019-10-10 14:08:21 +010054
55 // Set the counter ids
56 captureData.SetCounterIds(counterIds);
Ferran Balaguer1b941722019-08-28 16:57:18 +010057}
58
59void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
60{
Matteo Martincighe8485382019-10-10 14:08:21 +010061 ProfilingState currentState = m_StateMachine.GetCurrentState();
62 switch (currentState)
63 {
64 case ProfilingState::Uninitialised:
65 case ProfilingState::NotConnected:
66 case ProfilingState::WaitingForAck:
67 throw RuntimeException(boost::str(boost::format("Periodic Counter Selection Command Handler invoked while in "
68 "an wrong state: %1%")
69 % GetProfilingStateName(currentState)));
70 case ProfilingState::Active:
71 {
72 // Process the packet
73 if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
74 {
75 throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 4 but "
76 "received family = %1%, id = %2%")
77 % packet.GetPacketFamily()
78 % packet.GetPacketId()));
79 }
Ferran Balaguer1b941722019-08-28 16:57:18 +010080
Matteo Martincighe8485382019-10-10 14:08:21 +010081 // Parse the packet to get the capture period and counter UIDs
82 CaptureData captureData;
83 ParseData(packet, captureData);
Ferran Balaguer1b941722019-08-28 16:57:18 +010084
Matteo Martincighe8485382019-10-10 14:08:21 +010085 // Get the capture data
Colm Donelan02705242019-11-14 14:19:07 +000086 uint32_t capturePeriod = captureData.GetCapturePeriod();
87 // Validate that the capture period is within the acceptable range.
88 if (capturePeriod > 0 && capturePeriod < LOWEST_CAPTURE_PERIOD)
89 {
90 capturePeriod = LOWEST_CAPTURE_PERIOD;
91 }
Matteo Martincighe8485382019-10-10 14:08:21 +010092 const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
Ferran Balaguer1b941722019-08-28 16:57:18 +010093
Matteo Martincighe8485382019-10-10 14:08:21 +010094 // Check whether the selected counter UIDs are valid
95 std::vector<uint16_t> validCounterIds;
96 for (uint16_t counterId : counterIds)
97 {
98 // Check whether the counter is registered
99 if (!m_ReadCounterValues.IsCounterRegistered(counterId))
100 {
101 // Invalid counter UID, ignore it and continue
102 continue;
103 }
Matteo Martincighe8485382019-10-10 14:08:21 +0100104 // The counter is valid
Finn Williams032bc742020-02-12 11:02:34 +0000105 validCounterIds.emplace_back(counterId);
Matteo Martincighe8485382019-10-10 14:08:21 +0100106 }
Ferran Balaguer1b941722019-08-28 16:57:18 +0100107
Finn Williams032bc742020-02-12 11:02:34 +0000108 std::sort(validCounterIds.begin(), validCounterIds.end());
109
110 auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId)
111 {
112 return counterId > m_MaxArmCounterId;
113 });
114
115 std::set<armnn::BackendId> activeBackends;
116 std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end());
117
118 if (m_BackendCounterMap.size() != 0)
119 {
120 std::set<uint16_t> newCounterIds;
121 std::set<uint16_t> unusedCounterIds;
122
123 // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds
124 std::set_difference(backendCounterIds.begin(), backendCounterIds.end(),
125 m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
126 std::inserter(newCounterIds, newCounterIds.begin()));
127
128 // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds
129 std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
130 backendCounterIds.begin(), backendCounterIds.end(),
131 std::inserter(unusedCounterIds, unusedCounterIds.begin()));
132
133 activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds);
134 }
135 else
136 {
137 activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {});
138 }
139
140 // save the new backend counter ids for next time
141 m_PrevBackendCounterIds = backendCounterIds;
142
143
144 // Set the capture data with only the valid armnn counter UIDs
145 m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
Matteo Martincighe8485382019-10-10 14:08:21 +0100146
147 // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
148 m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
149
Finn Williamsf4d59a62019-10-14 15:55:18 +0100150 if (capturePeriod == 0 || validCounterIds.empty())
151 {
152 // No data capture stop the thread
153 m_PeriodicCounterCapture.Stop();
154 }
155 else
156 {
157 // Start the Period Counter Capture thread (if not running already)
158 m_PeriodicCounterCapture.Start();
159 }
Matteo Martincighe8485382019-10-10 14:08:21 +0100160
161 break;
162 }
163 default:
164 throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
165 % static_cast<int>(currentState)));
166 }
Ferran Balaguer1b941722019-08-28 16:57:18 +0100167}
168
Finn Williams032bc742020-02-12 11:02:34 +0000169std::set<armnn::BackendId> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
170 const u_int32_t capturePeriod,
171 std::set<uint16_t> newCounterIds,
172 std::set<uint16_t> unusedCounterIds)
173{
174 std::set<armnn::BackendId> changedBackends;
175 std::set<armnn::BackendId> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
176
177 for (uint16_t counterId : newCounterIds)
178 {
179 auto backendId = m_CounterIdMap.GetBackendId(counterId);
180 m_BackendCounterMap[backendId.second].emplace_back(backendId.first);
181 changedBackends.insert(backendId.second);
182 }
183 // Add any new backends to active backends
184 activeBackends.insert(changedBackends.begin(), changedBackends.end());
185
186 for (uint16_t counterId : unusedCounterIds)
187 {
188 auto backendId = m_CounterIdMap.GetBackendId(counterId);
189 std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second];
190
191 backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first));
192
193 if(backendCounters.size() == 0)
194 {
195 // If a backend has no counters associated with it we remove it from active backends and
196 // send a capture period of zero with an empty vector, this will deactivate all the backends counters
197 activeBackends.erase(backendId.second);
198 ActivateBackedCounters(backendId.second, 0, {});
199 }
200 else
201 {
202 changedBackends.insert(backendId.second);
203 }
204 }
205
206 // If the capture period remains the same we only need to update the backends who's counters have changed
207 if(capturePeriod == m_PrevCapturePeriod)
208 {
209 for (auto backend : changedBackends)
210 {
211 ActivateBackedCounters(backend, capturePeriod, m_BackendCounterMap[backend]);
212 }
213 }
214 // Otherwise update all the backends with the new capture period and any new/unused counters
215 else
216 {
217 for (auto backend : m_BackendCounterMap)
218 {
219 ActivateBackedCounters(backend.first, capturePeriod, backend.second);
220 }
221 if(capturePeriod == 0)
222 {
223 activeBackends = {};
224 }
225 m_PrevCapturePeriod = capturePeriod;
226 }
227
228 return activeBackends;
229}
230
Ferran Balaguer1b941722019-08-28 16:57:18 +0100231} // namespace profiling
232
Matteo Martincighe8485382019-10-10 14:08:21 +0100233} // namespace armnn