blob: 3e6cf63efe15aadcd5f2e7f7bd222ed6490855bf [file] [log] [blame]
Matteo Martincighd0613b52019-10-09 16:47:04 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "SendCounterPacketTests.hpp"

#include <CommandHandlerFunctor.hpp>
#include <IProfilingConnection.hpp>
#include <IProfilingConnectionFactory.hpp>
#include <Logging.hpp>
#include <ProfilingService.hpp>

#include <boost/test/unit_test.hpp>

#include <chrono>
#include <iostream>
#include <memory>
#include <thread>
21
22namespace armnn
23{
24
25namespace profiling
26{
27
28struct LogLevelSwapper
29{
30public:
31 LogLevelSwapper(armnn::LogSeverity severity)
32 {
33 // Set the new log level
34 armnn::ConfigureLogging(true, true, severity);
35 }
36 ~LogLevelSwapper()
37 {
38 // The default log level for unit tests is "Fatal"
39 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
40 }
41};
42
/// Scope guard that redirects std::cout into another stream buffer.
///
/// The constructor installs @p newStreamBuffer as std::cout's buffer and remembers the buffer
/// it replaced. The destructor reinstalls that original buffer. Because std::cout writes
/// through @p newStreamBuffer while the guard is alive, the caller must keep that buffer alive
/// at least as long as the guard.
struct CoutRedirect
{
public:
    explicit CoutRedirect(std::streambuf* newStreamBuffer)
        : m_Old(std::cout.rdbuf(newStreamBuffer)) {}
    ~CoutRedirect() { std::cout.rdbuf(m_Old); }

    // Non-copyable: each copy would restore m_Old again on destruction, possibly after
    // another redirection has been installed.
    CoutRedirect(const CoutRedirect&) = delete;
    CoutRedirect& operator=(const CoutRedirect&) = delete;

private:
    // Buffer std::cout was using before this guard took over
    std::streambuf* m_Old;
};
53
/// Scope guard that redirects an arbitrary output stream into another stream buffer.
///
/// The constructor installs @p newStreamBuffer as @p stream's buffer and remembers the buffer
/// it replaced. The destructor reinstalls that original buffer. Both @p stream and
/// @p newStreamBuffer must outlive the guard.
struct StreamRedirector
{
public:
    StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
        : m_Stream(stream)
        // Relies on m_Stream being declared (and therefore initialised) before m_BackupBuffer
        , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
    {}
    ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); }

    // Non-copyable: a copy would restore the backup buffer a second time on destruction.
    StreamRedirector(const StreamRedirector&) = delete;
    StreamRedirector& operator=(const StreamRedirector&) = delete;

private:
    std::ostream& m_Stream;
    // Buffer m_Stream was using before this guard took over
    std::streambuf* m_BackupBuffer;
};
67
68class TestProfilingConnectionBase : public IProfilingConnection
69{
70public:
71 TestProfilingConnectionBase() = default;
72 ~TestProfilingConnectionBase() = default;
73
74 bool IsOpen() const override { return true; }
75
76 void Close() override {}
77
78 bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
79
80 Packet ReadPacket(uint32_t timeout) override
81 {
82 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
83
84 // Return connection acknowledged packet
85 std::unique_ptr<char[]> packetData;
86 return Packet(65536, 0, packetData);
87 }
88};
89
90class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
91{
92public:
93 TestProfilingConnectionTimeoutError()
94 : m_ReadRequests(0)
95 {}
96
97 Packet ReadPacket(uint32_t timeout) override
98 {
99 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
100
101 if (m_ReadRequests < 3)
102 {
103 m_ReadRequests++;
104 throw armnn::TimeoutException("Simulate a timeout error\n");
105 }
106
107 // Return connection acknowledged packet after three timeouts
108 std::unique_ptr<char[]> packetData;
109 return Packet(65536, 0, packetData);
110 }
111
112private:
113 int m_ReadRequests;
114};
115
116class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
117{
118public:
119 Packet ReadPacket(uint32_t timeout) override
120 {
121 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
122
123 throw armnn::Exception("Simulate a non-timeout error");
124 }
125};
126
127class TestFunctorA : public CommandHandlerFunctor
128{
129public:
130 using CommandHandlerFunctor::CommandHandlerFunctor;
131
132 int GetCount() { return m_Count; }
133
134 void operator()(const Packet& packet) override
135 {
136 m_Count++;
137 }
138
139private:
140 int m_Count = 0;
141};
142
143class TestFunctorB : public TestFunctorA
144{
145 using TestFunctorA::TestFunctorA;
146};
147
148class TestFunctorC : public TestFunctorA
149{
150 using TestFunctorA::TestFunctorA;
151};
152
153class MockProfilingConnectionFactory : public IProfilingConnectionFactory
154{
155public:
156 IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
157 {
158 return std::make_unique<MockProfilingConnection>();
159 }
160};
161
162class SwapProfilingConnectionFactoryHelper : public ProfilingService
163{
164public:
165 using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
166
167 SwapProfilingConnectionFactoryHelper()
168 : ProfilingService()
169 , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
170 , m_BackupProfilingConnectionFactory(nullptr)
171 {
172 BOOST_CHECK(m_MockProfilingConnectionFactory);
173 SwapProfilingConnectionFactory(ProfilingService::Instance(),
174 m_MockProfilingConnectionFactory.get(),
175 m_BackupProfilingConnectionFactory);
176 BOOST_CHECK(m_BackupProfilingConnectionFactory);
177 }
178 ~SwapProfilingConnectionFactoryHelper()
179 {
180 BOOST_CHECK(m_BackupProfilingConnectionFactory);
181 IProfilingConnectionFactory* temp = nullptr;
182 SwapProfilingConnectionFactory(ProfilingService::Instance(),
183 m_BackupProfilingConnectionFactory,
184 temp);
185 }
186
187 MockProfilingConnection* GetMockProfilingConnection()
188 {
189 IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
190 return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
191 }
192
193private:
194 MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
195 IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
196};
197
198} // namespace profiling
199
200} // namespace armnn