blob: 21c98723be93ca8896d57ce9723e4faa8b1ecf00 [file] [log] [blame]
Matteo Martincighd0613b52019-10-09 16:47:04 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "SendCounterPacketTests.hpp"
9
10#include <CommandHandlerFunctor.hpp>
11#include <IProfilingConnection.hpp>
Matteo Martincighd0613b52019-10-09 16:47:04 +010012#include <Logging.hpp>
13#include <ProfilingService.hpp>
14
15#include <boost/test/unit_test.hpp>
16
17#include <chrono>
Matteo Martincighd0613b52019-10-09 16:47:04 +010018#include <thread>
19
20namespace armnn
21{
22
23namespace profiling
24{
25
26struct LogLevelSwapper
27{
28public:
29 LogLevelSwapper(armnn::LogSeverity severity)
30 {
31 // Set the new log level
32 armnn::ConfigureLogging(true, true, severity);
33 }
34 ~LogLevelSwapper()
35 {
36 // The default log level for unit tests is "Fatal"
37 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
38 }
39};
40
Matteo Martincighd0613b52019-10-09 16:47:04 +010041struct StreamRedirector
42{
43public:
44 StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
45 : m_Stream(stream)
46 , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
47 {}
48 ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); }
49
50private:
51 std::ostream& m_Stream;
52 std::streambuf* m_BackupBuffer;
53};
54
55class TestProfilingConnectionBase : public IProfilingConnection
56{
57public:
58 TestProfilingConnectionBase() = default;
59 ~TestProfilingConnectionBase() = default;
60
61 bool IsOpen() const override { return true; }
62
63 void Close() override {}
64
65 bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
66
67 Packet ReadPacket(uint32_t timeout) override
68 {
69 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
70
71 // Return connection acknowledged packet
Matteo Martincigh67ef2a52019-10-10 13:29:02 +010072 return Packet(65536);
Matteo Martincighd0613b52019-10-09 16:47:04 +010073 }
74};
75
76class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
77{
78public:
79 TestProfilingConnectionTimeoutError()
80 : m_ReadRequests(0)
81 {}
82
83 Packet ReadPacket(uint32_t timeout) override
84 {
85 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
86
87 if (m_ReadRequests < 3)
88 {
89 m_ReadRequests++;
90 throw armnn::TimeoutException("Simulate a timeout error\n");
91 }
92
93 // Return connection acknowledged packet after three timeouts
Matteo Martincigh67ef2a52019-10-10 13:29:02 +010094 return Packet(65536);
Matteo Martincighd0613b52019-10-09 16:47:04 +010095 }
96
97private:
98 int m_ReadRequests;
99};
100
101class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
102{
103public:
104 Packet ReadPacket(uint32_t timeout) override
105 {
106 std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
107
108 throw armnn::Exception("Simulate a non-timeout error");
109 }
110};
111
112class TestFunctorA : public CommandHandlerFunctor
113{
114public:
115 using CommandHandlerFunctor::CommandHandlerFunctor;
116
117 int GetCount() { return m_Count; }
118
119 void operator()(const Packet& packet) override
120 {
121 m_Count++;
122 }
123
124private:
125 int m_Count = 0;
126};
127
128class TestFunctorB : public TestFunctorA
129{
130 using TestFunctorA::TestFunctorA;
131};
132
133class TestFunctorC : public TestFunctorA
134{
135 using TestFunctorA::TestFunctorA;
136};
137
Matteo Martincighd0613b52019-10-09 16:47:04 +0100138class SwapProfilingConnectionFactoryHelper : public ProfilingService
139{
140public:
141 using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
142
143 SwapProfilingConnectionFactoryHelper()
144 : ProfilingService()
145 , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
146 , m_BackupProfilingConnectionFactory(nullptr)
147 {
148 BOOST_CHECK(m_MockProfilingConnectionFactory);
149 SwapProfilingConnectionFactory(ProfilingService::Instance(),
150 m_MockProfilingConnectionFactory.get(),
151 m_BackupProfilingConnectionFactory);
152 BOOST_CHECK(m_BackupProfilingConnectionFactory);
153 }
154 ~SwapProfilingConnectionFactoryHelper()
155 {
156 BOOST_CHECK(m_BackupProfilingConnectionFactory);
157 IProfilingConnectionFactory* temp = nullptr;
158 SwapProfilingConnectionFactory(ProfilingService::Instance(),
159 m_BackupProfilingConnectionFactory,
160 temp);
161 }
162
163 MockProfilingConnection* GetMockProfilingConnection()
164 {
165 IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
166 return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
167 }
168
Matteo Martincigh8efc5002019-10-10 14:30:29 +0100169 void ForceTransitionToState(ProfilingState newState)
170 {
171 TransitionToState(ProfilingService::Instance(), newState);
172 }
173
Matteo Martincighe8485382019-10-10 14:08:21 +0100174 void WaitForProfilingPacketsSent()
175 {
176 return WaitForPacketSent(ProfilingService::Instance());
177 }
178
Matteo Martincighd0613b52019-10-09 16:47:04 +0100179private:
180 MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
181 IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
182};
183
184} // namespace profiling
185
186} // namespace armnn