blob: d48f80737daf65061c6595be0579e693f2cda45b [file] [log] [blame]
Keith Davise813d672021-04-22 10:10:34 +01001//
Ryan OSheab5540542022-07-06 09:52:52 +01002// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
Keith Davise813d672021-04-22 10:10:34 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/IAsyncExecutionCallback.hpp>
Finn Williamsf364d532021-06-09 17:07:33 +01009#include <armnn/IWorkingMemHandle.hpp>
Keith Davise813d672021-04-22 10:10:34 +010010#include <armnn/Types.hpp>
Keith Davise813d672021-04-22 10:10:34 +010011
Finn Williamsf364d532021-06-09 17:07:33 +010012#include <condition_variable>
Keith Davise813d672021-04-22 10:10:34 +010013#include <mutex>
14#include <thread>
Finn Williamsf364d532021-06-09 17:07:33 +010015#include <queue>
16#include <unordered_map>
Keith Davise813d672021-04-22 10:10:34 +010017
18namespace armnn
19{
20
21namespace experimental
22{
23
Finn Williamsf364d532021-06-09 17:07:33 +010024using InferenceId = uint64_t;
Keith Davise813d672021-04-22 10:10:34 +010025class AsyncExecutionCallback final : public IAsyncExecutionCallback
26{
Finn Williamsf364d532021-06-09 17:07:33 +010027private:
28 static InferenceId nextID;
29
Keith Davise813d672021-04-22 10:10:34 +010030public:
Jim Flynn870b96c2022-03-25 21:24:56 +000031 AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue
32#if !defined(ARMNN_DISABLE_THREADS)
33 , std::mutex& mutex
34 , std::condition_variable& condition
35#endif
36 )
Finn Williamsf364d532021-06-09 17:07:33 +010037 : m_NotificationQueue(notificationQueue)
Jim Flynn870b96c2022-03-25 21:24:56 +000038#if !defined(ARMNN_DISABLE_THREADS)
Finn Williamsf364d532021-06-09 17:07:33 +010039 , m_Mutex(mutex)
40 , m_Condition(condition)
Jim Flynn870b96c2022-03-25 21:24:56 +000041#endif
Finn Williamsf364d532021-06-09 17:07:33 +010042 , m_InferenceId(++nextID)
Keith Davise813d672021-04-22 10:10:34 +010043 {}
Finn Williamsf364d532021-06-09 17:07:33 +010044
Keith Davise813d672021-04-22 10:10:34 +010045 ~AsyncExecutionCallback()
46 {}
47
48 void Notify(armnn::Status status, InferenceTimingPair timeTaken);
Finn Williamsf364d532021-06-09 17:07:33 +010049
50 InferenceId GetInferenceId()
51 {
52 return m_InferenceId;
53 }
Keith Davise813d672021-04-22 10:10:34 +010054
55 armnn::Status GetStatus() const;
56 HighResolutionClock GetStartTime() const;
57 HighResolutionClock GetEndTime() const;
58
59private:
Finn Williamsf364d532021-06-09 17:07:33 +010060 std::queue<InferenceId>& m_NotificationQueue;
Jim Flynn870b96c2022-03-25 21:24:56 +000061#if !defined(ARMNN_DISABLE_THREADS)
Finn Williamsf364d532021-06-09 17:07:33 +010062 std::mutex& m_Mutex;
63 std::condition_variable& m_Condition;
Jim Flynn870b96c2022-03-25 21:24:56 +000064#endif
Keith Davise813d672021-04-22 10:10:34 +010065
66 HighResolutionClock m_StartTime;
67 HighResolutionClock m_EndTime;
Finn Williamsf364d532021-06-09 17:07:33 +010068 armnn::Status m_Status = Status::Failure;
69 InferenceId m_InferenceId;
70};
Finn Williamsf364d532021-06-09 17:07:33 +010071
72// Manager to create and monitor AsyncExecutionCallbacks
73// GetNewCallback will create a callback for use in Threadpool::Schedule
74// GetNotifiedCallback will return the first callback to be notified (finished execution)
75class AsyncCallbackManager
76{
77public:
78 std::shared_ptr<AsyncExecutionCallback> GetNewCallback();
79 std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback();
80
81private:
Jim Flynn870b96c2022-03-25 21:24:56 +000082#if !defined(ARMNN_DISABLE_THREADS)
Finn Williamsf364d532021-06-09 17:07:33 +010083 std::mutex m_Mutex;
84 std::condition_variable m_Condition;
Jim Flynn870b96c2022-03-25 21:24:56 +000085#endif
Finn Williamsf364d532021-06-09 17:07:33 +010086 std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks;
87 std::queue<InferenceId> m_NotificationQueue;
Keith Davise813d672021-04-22 10:10:34 +010088};
89
90} // namespace experimental
91
Jim Flynn870b96c2022-03-25 21:24:56 +000092} // namespace armnn