blob: 5b87927af2e8f33facaab2eb4220b1fc97aec014 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include <AsyncExecutionCallback.hpp>

#include <stdexcept>
7
8namespace armnn
9{
10
11namespace experimental
12{
13
14void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair timeTaken)
15{
16 {
Jim Flynn870b96c2022-03-25 21:24:56 +000017#if !defined(ARMNN_DISABLE_THREADS)
Keith Davise813d672021-04-22 10:10:34 +010018 std::lock_guard<std::mutex> hold(m_Mutex);
Jim Flynn870b96c2022-03-25 21:24:56 +000019#endif
Keith Davise813d672021-04-22 10:10:34 +010020 // store results and mark as notified
21 m_Status = status;
22 m_StartTime = timeTaken.first;
23 m_EndTime = timeTaken.second;
Finn Williamsf364d532021-06-09 17:07:33 +010024 m_NotificationQueue.push(m_InferenceId);
Keith Davise813d672021-04-22 10:10:34 +010025 }
Jim Flynn870b96c2022-03-25 21:24:56 +000026#if !defined(ARMNN_DISABLE_THREADS)
Keith Davise813d672021-04-22 10:10:34 +010027 m_Condition.notify_all();
Jim Flynn870b96c2022-03-25 21:24:56 +000028#endif
Keith Davise813d672021-04-22 10:10:34 +010029}
30
Keith Davise813d672021-04-22 10:10:34 +010031armnn::Status AsyncExecutionCallback::GetStatus() const
32{
Keith Davise813d672021-04-22 10:10:34 +010033 return m_Status;
34}
35
36HighResolutionClock AsyncExecutionCallback::GetStartTime() const
37{
Keith Davise813d672021-04-22 10:10:34 +010038 return m_StartTime;
39}
40
41HighResolutionClock AsyncExecutionCallback::GetEndTime() const
42{
Keith Davise813d672021-04-22 10:10:34 +010043 return m_EndTime;
44}
45
Finn Williamsf364d532021-06-09 17:07:33 +010046std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback()
47{
Jim Flynn870b96c2022-03-25 21:24:56 +000048 auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue
49#if !defined(ARMNN_DISABLE_THREADS)
50 , m_Mutex
51 , m_Condition
52#endif
53 );
Finn Williamsf364d532021-06-09 17:07:33 +010054 InferenceId id = cb->GetInferenceId();
55 m_Callbacks.insert({id, std::move(cb)});
56
57 return m_Callbacks.at(id);
58}
59
60std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback()
61{
Jim Flynn870b96c2022-03-25 21:24:56 +000062#if !defined(ARMNN_DISABLE_THREADS)
Finn Williamsf364d532021-06-09 17:07:33 +010063 std::unique_lock<std::mutex> lock(m_Mutex);
64
65 m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); });
Jim Flynn870b96c2022-03-25 21:24:56 +000066#endif
Finn Williamsf364d532021-06-09 17:07:33 +010067 InferenceId id = m_NotificationQueue.front();
68 m_NotificationQueue.pop();
69
70 std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id);
71 m_Callbacks.erase(id);
72 return callback;
73}
74
Keith Davise813d672021-04-22 10:10:34 +010075} // namespace experimental
76
Jim Flynn870b96c2022-03-25 21:24:56 +000077} // namespace armnn