blob: c8abc5e5e7dc3729ad1df56e79c4e34fc87e744f [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"

#include <CpuExecutor.h>
#include <armnn/ArmNN.hpp>
18
19namespace armnn_driver
20{
21using TimePoint = std::chrono::steady_clock::time_point;
22
23template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
24class RequestThread_1_3
25{
26public:
27 /// Constructor creates the thread
28 RequestThread_1_3();
29
30 /// Destructor terminates the thread
31 ~RequestThread_1_3();
32
33 /// Add a message to the thread queue.
34 /// @param[in] model pointer to the prepared model handling the request
35 /// @param[in] memPools pointer to the memory pools vector for the tensors
36 /// @param[in] inputTensors pointer to the input tensors for the request
37 /// @param[in] outputTensors pointer to the output tensors for the request
38 /// @param[in] callback the android notification callback
39 void PostMsg(PreparedModel<HalVersion>* model,
40 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
41 std::shared_ptr<armnn::InputTensors>& inputTensors,
42 std::shared_ptr<armnn::OutputTensors>& outputTensors,
43 CallbackContext callbackContext);
44
45private:
46 RequestThread_1_3(const RequestThread_1_3&) = delete;
47 RequestThread_1_3& operator=(const RequestThread_1_3&) = delete;
48
49 /// storage for a prepared model and args for the asyncExecute call
50 struct AsyncExecuteData
51 {
52 AsyncExecuteData(PreparedModel<HalVersion>* model,
53 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
54 std::shared_ptr<armnn::InputTensors>& inputTensors,
55 std::shared_ptr<armnn::OutputTensors>& outputTensors,
56 CallbackContext callbackContext)
57 : m_Model(model)
58 , m_MemPools(memPools)
59 , m_InputTensors(inputTensors)
60 , m_OutputTensors(outputTensors)
61 , m_CallbackContext(callbackContext)
62 {
63 }
64
65 PreparedModel<HalVersion>* m_Model;
66 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
67 std::shared_ptr<armnn::InputTensors> m_InputTensors;
68 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
69 CallbackContext m_CallbackContext;
70 };
71 enum class ThreadMsgType
72 {
73 EXIT, // exit the thread
74 REQUEST // user request to process
75 };
76
77 /// storage for the thread message type and data
78 struct ThreadMsg
79 {
80 ThreadMsg(ThreadMsgType msgType,
81 std::shared_ptr<AsyncExecuteData>& msgData)
82 : type(msgType)
83 , data(msgData)
84 {
85 }
86
87 ThreadMsgType type;
88 std::shared_ptr<AsyncExecuteData> data;
89 };
90
91 /// Add a prepared thread message to the thread queue.
92 /// @param[in] threadMsg the message to add to the queue
93 void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg, V1_3::Priority priority = V1_3::Priority::MEDIUM);
94
95 /// Entry point for the request thread
96 void Process();
97
98 std::unique_ptr<std::thread> m_Thread;
99 std::queue<std::shared_ptr<ThreadMsg>> m_HighPriorityQueue;
100 std::queue<std::shared_ptr<ThreadMsg>> m_MediumPriorityQueue;
101 std::queue<std::shared_ptr<ThreadMsg>> m_LowPriorityQueue;
102 std::mutex m_Mutex;
103 std::condition_variable m_Cv;
104};
105
106} // namespace armnn_driver