blob: 208fb80865e48361bccb9d080a21489cb722eaab [file] [log] [blame]
Matteo Martincighd0613b52019-10-09 16:47:04 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "SendCounterPacketTests.hpp"
9
Derek Lamberti08446972019-11-26 16:38:31 +000010#include <armnn/Logging.hpp>
11
Matteo Martincighd0613b52019-10-09 16:47:04 +010012#include <CommandHandlerFunctor.hpp>
13#include <IProfilingConnection.hpp>
Matteo Martincighd0613b52019-10-09 16:47:04 +010014#include <ProfilingService.hpp>
15
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000016#include <boost/polymorphic_cast.hpp>
Matteo Martincighd0613b52019-10-09 16:47:04 +010017#include <boost/test/unit_test.hpp>
18
19#include <chrono>
Matteo Martincighd0613b52019-10-09 16:47:04 +010020#include <thread>
21
22namespace armnn
23{
24
25namespace profiling
26{
27
28struct LogLevelSwapper
29{
30public:
31 LogLevelSwapper(armnn::LogSeverity severity)
32 {
33 // Set the new log level
34 armnn::ConfigureLogging(true, true, severity);
35 }
36 ~LogLevelSwapper()
37 {
38 // The default log level for unit tests is "Fatal"
39 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
40 }
41};
42
/// RAII helper that temporarily redirects an output stream to a different stream buffer.
///
/// The constructor installs @p newStreamBuffer on @p stream and remembers the buffer it
/// replaced. The original buffer is put back by CancelRedirect() or, at the latest, by
/// the destructor.
///
/// The object is non-copyable: a copy would hold the same backup buffer and restore it a
/// second time when destroyed, possibly undoing a later redirection.
struct StreamRedirector
{
public:
    /// @param stream           Stream to redirect. Must outlive this object.
    /// @param newStreamBuffer  Buffer that receives the stream's output while redirected.
    StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
        : m_Stream(stream)
        , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer)) // rdbuf(sb) returns the previous buffer
    {}

    StreamRedirector(const StreamRedirector&) = delete;
    StreamRedirector& operator=(const StreamRedirector&) = delete;

    ~StreamRedirector() { CancelRedirect(); }

    /// Restores the stream's original buffer.
    ///
    /// Safe to call more than once; only the first call has an effect. A dedicated flag is
    /// used instead of testing m_BackupBuffer against nullptr, because the original buffer
    /// may legitimately be null and must still be restored.
    void CancelRedirect()
    {
        if (m_Redirected)
        {
            m_Stream.rdbuf(m_BackupBuffer);
            m_Redirected = false;
        }
    }

private:
    // Declaration order matters: m_Stream must be initialised before m_BackupBuffer,
    // whose initialiser uses it.
    std::ostream& m_Stream;
    std::streambuf* m_BackupBuffer;
    bool m_Redirected = true;
};
67
68class TestProfilingConnectionBase : public IProfilingConnection
69{
70public:
71 TestProfilingConnectionBase() = default;
72 ~TestProfilingConnectionBase() = default;
73
74 bool IsOpen() const override { return true; }
75
76 void Close() override {}
77
Derek Lamberti1dd75b32019-12-10 21:23:23 +000078 bool WritePacket(const unsigned char* buffer, uint32_t length) override
79 {
80 boost::ignore_unused(buffer, length);
81
82 return false;
83 }
Matteo Martincighd0613b52019-10-09 16:47:04 +010084
85 Packet ReadPacket(uint32_t timeout) override
86 {
Colm Donelan2ba48d22019-11-29 09:10:59 +000087 // First time we're called return a connection ack packet. After that always timeout.
88 if (m_FirstCall)
89 {
90 m_FirstCall = false;
91 // Return connection acknowledged packet
92 return Packet(65536);
93 }
94 else
95 {
96 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
97 throw armnn::TimeoutException("Simulate a timeout error\n");
98 }
Matteo Martincighd0613b52019-10-09 16:47:04 +010099 }
Colm Donelan2ba48d22019-11-29 09:10:59 +0000100
101 bool m_FirstCall = true;
Matteo Martincighd0613b52019-10-09 16:47:04 +0100102};
103
104class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
105{
106public:
107 TestProfilingConnectionTimeoutError()
108 : m_ReadRequests(0)
109 {}
110
111 Packet ReadPacket(uint32_t timeout) override
112 {
Colm Donelan2ba48d22019-11-29 09:10:59 +0000113 // Return connection acknowledged packet after three timeouts
114 if (m_ReadRequests % 3 == 0)
Matteo Martincighd0613b52019-10-09 16:47:04 +0100115 {
Colm Donelan2ba48d22019-11-29 09:10:59 +0000116 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
117 ++m_ReadRequests;
Matteo Martincighd0613b52019-10-09 16:47:04 +0100118 throw armnn::TimeoutException("Simulate a timeout error\n");
119 }
120
Matteo Martincigh67ef2a52019-10-10 13:29:02 +0100121 return Packet(65536);
Matteo Martincighd0613b52019-10-09 16:47:04 +0100122 }
123
Colm Donelan2ba48d22019-11-29 09:10:59 +0000124 int ReadCalledCount()
125 {
126 return m_ReadRequests.load();
127 }
128
Matteo Martincighd0613b52019-10-09 16:47:04 +0100129private:
Colm Donelan2ba48d22019-11-29 09:10:59 +0000130 std::atomic<int> m_ReadRequests;
Matteo Martincighd0613b52019-10-09 16:47:04 +0100131};
132
133class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
134{
135public:
Colm Donelan2ba48d22019-11-29 09:10:59 +0000136 TestProfilingConnectionArmnnError()
137 : m_ReadRequests(0)
138 {}
139
Matteo Martincighd0613b52019-10-09 16:47:04 +0100140 Packet ReadPacket(uint32_t timeout) override
141 {
Derek Lamberti1dd75b32019-12-10 21:23:23 +0000142 boost::ignore_unused(timeout);
Colm Donelan2ba48d22019-11-29 09:10:59 +0000143 ++m_ReadRequests;
Matteo Martincighd0613b52019-10-09 16:47:04 +0100144 throw armnn::Exception("Simulate a non-timeout error");
145 }
Colm Donelan2ba48d22019-11-29 09:10:59 +0000146
147 int ReadCalledCount()
148 {
149 return m_ReadRequests.load();
150 }
151
152private:
153 std::atomic<int> m_ReadRequests;
Matteo Martincighd0613b52019-10-09 16:47:04 +0100154};
155
156class TestFunctorA : public CommandHandlerFunctor
157{
158public:
159 using CommandHandlerFunctor::CommandHandlerFunctor;
160
161 int GetCount() { return m_Count; }
162
163 void operator()(const Packet& packet) override
164 {
Derek Lamberti1dd75b32019-12-10 21:23:23 +0000165 boost::ignore_unused(packet);
Matteo Martincighd0613b52019-10-09 16:47:04 +0100166 m_Count++;
167 }
168
169private:
170 int m_Count = 0;
171};
172
173class TestFunctorB : public TestFunctorA
174{
175 using TestFunctorA::TestFunctorA;
176};
177
178class TestFunctorC : public TestFunctorA
179{
180 using TestFunctorA::TestFunctorA;
181};
182
Matteo Martincighd0613b52019-10-09 16:47:04 +0100183class SwapProfilingConnectionFactoryHelper : public ProfilingService
184{
185public:
186 using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
187
188 SwapProfilingConnectionFactoryHelper()
189 : ProfilingService()
190 , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
191 , m_BackupProfilingConnectionFactory(nullptr)
192 {
193 BOOST_CHECK(m_MockProfilingConnectionFactory);
194 SwapProfilingConnectionFactory(ProfilingService::Instance(),
195 m_MockProfilingConnectionFactory.get(),
196 m_BackupProfilingConnectionFactory);
197 BOOST_CHECK(m_BackupProfilingConnectionFactory);
198 }
199 ~SwapProfilingConnectionFactoryHelper()
200 {
201 BOOST_CHECK(m_BackupProfilingConnectionFactory);
202 IProfilingConnectionFactory* temp = nullptr;
203 SwapProfilingConnectionFactory(ProfilingService::Instance(),
204 m_BackupProfilingConnectionFactory,
205 temp);
206 }
207
208 MockProfilingConnection* GetMockProfilingConnection()
209 {
210 IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
211 return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
212 }
213
Matteo Martincigh8efc5002019-10-10 14:30:29 +0100214 void ForceTransitionToState(ProfilingState newState)
215 {
216 TransitionToState(ProfilingService::Instance(), newState);
217 }
218
Colm Donelan2ba48d22019-11-29 09:10:59 +0000219 void WaitForProfilingPacketsSent(MockProfilingConnection* mockProfilingConnection, uint32_t timeout = 1000)
Matteo Martincighe8485382019-10-10 14:08:21 +0100220 {
Colm Donelan2ba48d22019-11-29 09:10:59 +0000221 if (!mockProfilingConnection->HasWrittenData())
222 {
223 WaitForPacketSent(ProfilingService::Instance(), timeout);
224 // It's possible the wait has timed out. Check there is some data.
225 if (!mockProfilingConnection->HasWrittenData())
226 {
227 throw RuntimeException("ProfilingTests::WaitForProfilingPacketsSent timeout waiting for packet.");
228 }
229 }
Matteo Martincighe8485382019-10-10 14:08:21 +0100230 }
231
Matteo Martincighd0613b52019-10-09 16:47:04 +0100232private:
233 MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
234 IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
235};
236
237} // namespace profiling
238
239} // namespace armnn