blob: 86b5c315e2c68836ba4e4952fde45ee95e3913db [file] [log] [blame]
Matteo Martincighd0613b52019-10-09 16:47:04 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "SendCounterPacketTests.hpp"
9
10#include <CommandHandlerFunctor.hpp>
11#include <IProfilingConnection.hpp>
Matteo Martincighd0613b52019-10-09 16:47:04 +010012#include <Logging.hpp>
13#include <ProfilingService.hpp>
14
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000015#include <boost/polymorphic_cast.hpp>
Matteo Martincighd0613b52019-10-09 16:47:04 +010016#include <boost/test/unit_test.hpp>
17
18#include <chrono>
Matteo Martincighd0613b52019-10-09 16:47:04 +010019#include <thread>
20
21namespace armnn
22{
23
24namespace profiling
25{
26
27struct LogLevelSwapper
28{
29public:
30 LogLevelSwapper(armnn::LogSeverity severity)
31 {
32 // Set the new log level
33 armnn::ConfigureLogging(true, true, severity);
34 }
35 ~LogLevelSwapper()
36 {
37 // The default log level for unit tests is "Fatal"
38 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
39 }
40};
41
Matteo Martincighd0613b52019-10-09 16:47:04 +010042struct StreamRedirector
43{
44public:
45 StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
46 : m_Stream(stream)
47 , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
48 {}
49 ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); }
50
51private:
52 std::ostream& m_Stream;
53 std::streambuf* m_BackupBuffer;
54};
55
56class TestProfilingConnectionBase : public IProfilingConnection
57{
58public:
59 TestProfilingConnectionBase() = default;
60 ~TestProfilingConnectionBase() = default;
61
62 bool IsOpen() const override { return true; }
63
64 void Close() override {}
65
66 bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
67
68 Packet ReadPacket(uint32_t timeout) override
69 {
70 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
71
72 // Return connection acknowledged packet
Matteo Martincigh67ef2a52019-10-10 13:29:02 +010073 return Packet(65536);
Matteo Martincighd0613b52019-10-09 16:47:04 +010074 }
75};
76
77class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
78{
79public:
80 TestProfilingConnectionTimeoutError()
81 : m_ReadRequests(0)
82 {}
83
84 Packet ReadPacket(uint32_t timeout) override
85 {
86 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
87
88 if (m_ReadRequests < 3)
89 {
90 m_ReadRequests++;
91 throw armnn::TimeoutException("Simulate a timeout error\n");
92 }
93
94 // Return connection acknowledged packet after three timeouts
Matteo Martincigh67ef2a52019-10-10 13:29:02 +010095 return Packet(65536);
Matteo Martincighd0613b52019-10-09 16:47:04 +010096 }
97
98private:
99 int m_ReadRequests;
100};
101
102class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
103{
104public:
105 Packet ReadPacket(uint32_t timeout) override
106 {
107 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
108
109 throw armnn::Exception("Simulate a non-timeout error");
110 }
111};
112
113class TestFunctorA : public CommandHandlerFunctor
114{
115public:
116 using CommandHandlerFunctor::CommandHandlerFunctor;
117
118 int GetCount() { return m_Count; }
119
120 void operator()(const Packet& packet) override
121 {
122 m_Count++;
123 }
124
125private:
126 int m_Count = 0;
127};
128
129class TestFunctorB : public TestFunctorA
130{
131 using TestFunctorA::TestFunctorA;
132};
133
134class TestFunctorC : public TestFunctorA
135{
136 using TestFunctorA::TestFunctorA;
137};
138
Matteo Martincighd0613b52019-10-09 16:47:04 +0100139class SwapProfilingConnectionFactoryHelper : public ProfilingService
140{
141public:
142 using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
143
144 SwapProfilingConnectionFactoryHelper()
145 : ProfilingService()
146 , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
147 , m_BackupProfilingConnectionFactory(nullptr)
148 {
149 BOOST_CHECK(m_MockProfilingConnectionFactory);
150 SwapProfilingConnectionFactory(ProfilingService::Instance(),
151 m_MockProfilingConnectionFactory.get(),
152 m_BackupProfilingConnectionFactory);
153 BOOST_CHECK(m_BackupProfilingConnectionFactory);
154 }
155 ~SwapProfilingConnectionFactoryHelper()
156 {
157 BOOST_CHECK(m_BackupProfilingConnectionFactory);
158 IProfilingConnectionFactory* temp = nullptr;
159 SwapProfilingConnectionFactory(ProfilingService::Instance(),
160 m_BackupProfilingConnectionFactory,
161 temp);
162 }
163
164 MockProfilingConnection* GetMockProfilingConnection()
165 {
166 IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
167 return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
168 }
169
Matteo Martincigh8efc5002019-10-10 14:30:29 +0100170 void ForceTransitionToState(ProfilingState newState)
171 {
172 TransitionToState(ProfilingService::Instance(), newState);
173 }
174
Matteo Martincighe8485382019-10-10 14:08:21 +0100175 void WaitForProfilingPacketsSent()
176 {
177 return WaitForPacketSent(ProfilingService::Instance());
178 }
179
Matteo Martincighd0613b52019-10-09 16:47:04 +0100180private:
181 MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
182 IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
183};
184
185} // namespace profiling
186
187} // namespace armnn