blob: 253d104c44e5c6c203388d3b8f440bd69f3dac40 [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"

#include <CpuExecutor.h>
#include <armnn/ArmNN.hpp>
18
19namespace armnn_driver
20{
/// Monotonic-clock time point type used by the driver.
using TimePoint = std::chrono::steady_clock::time_point;

/// The smallest representable TimePoint.
/// NOTE(review): presumably used as an "unset" / sentinel timestamp by callers — confirm.
static constexpr TimePoint g_Min{TimePoint::min()};
telsoa015307bc12018-03-09 13:51:08 +000023
Mike Kelly65c42dc2019-07-22 14:06:00 +010024template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
telsoa015307bc12018-03-09 13:51:08 +000025class RequestThread
26{
27public:
28 /// Constructor creates the thread
29 RequestThread();
30
31 /// Destructor terminates the thread
32 ~RequestThread();
33
34 /// Add a message to the thread queue.
35 /// @param[in] model pointer to the prepared model handling the request
36 /// @param[in] memPools pointer to the memory pools vector for the tensors
37 /// @param[in] inputTensors pointer to the input tensors for the request
38 /// @param[in] outputTensors pointer to the output tensors for the request
39 /// @param[in] callback the android notification callback
Mike Kellyb5fdf382019-06-11 16:35:25 +010040 void PostMsg(PreparedModel<HalVersion>* model,
telsoa015307bc12018-03-09 13:51:08 +000041 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
42 std::shared_ptr<armnn::InputTensors>& inputTensors,
43 std::shared_ptr<armnn::OutputTensors>& outputTensors,
Mike Kelly65c42dc2019-07-22 14:06:00 +010044 Callback callback);
telsoa015307bc12018-03-09 13:51:08 +000045
46private:
47 RequestThread(const RequestThread&) = delete;
48 RequestThread& operator=(const RequestThread&) = delete;
49
50 /// storage for a prepared model and args for the asyncExecute call
51 struct AsyncExecuteData
52 {
Mike Kellyb5fdf382019-06-11 16:35:25 +010053 AsyncExecuteData(PreparedModel<HalVersion>* model,
telsoa015307bc12018-03-09 13:51:08 +000054 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
55 std::shared_ptr<armnn::InputTensors>& inputTensors,
56 std::shared_ptr<armnn::OutputTensors>& outputTensors,
Mike Kelly65c42dc2019-07-22 14:06:00 +010057 Callback callback)
telsoa015307bc12018-03-09 13:51:08 +000058 : m_Model(model)
59 , m_MemPools(memPools)
60 , m_InputTensors(inputTensors)
61 , m_OutputTensors(outputTensors)
Mike Kelly65c42dc2019-07-22 14:06:00 +010062 , m_Callback(callback)
telsoa015307bc12018-03-09 13:51:08 +000063 {
64 }
65
Mike Kellyb5fdf382019-06-11 16:35:25 +010066 PreparedModel<HalVersion>* m_Model;
telsoa015307bc12018-03-09 13:51:08 +000067 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
68 std::shared_ptr<armnn::InputTensors> m_InputTensors;
69 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
Mike Kelly65c42dc2019-07-22 14:06:00 +010070 Callback m_Callback;
telsoa015307bc12018-03-09 13:51:08 +000071 };
72
73 enum class ThreadMsgType
74 {
75 EXIT, // exit the thread
76 REQUEST // user request to process
77 };
78
79 /// storage for the thread message type and data
80 struct ThreadMsg
81 {
82 ThreadMsg(ThreadMsgType msgType,
83 std::shared_ptr<AsyncExecuteData>& msgData)
84 : type(msgType)
85 , data(msgData)
86 {
87 }
88
89 ThreadMsgType type;
90 std::shared_ptr<AsyncExecuteData> data;
91 };
92
93 /// Add a prepared thread message to the thread queue.
94 /// @param[in] threadMsg the message to add to the queue
95 void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg);
96
97 /// Entry point for the request thread
98 void Process();
99
100 std::unique_ptr<std::thread> m_Thread;
101 std::queue<std::shared_ptr<ThreadMsg>> m_Queue;
102 std::mutex m_Mutex;
103 std::condition_variable m_Cv;
104};
105
Matthew Bentham9e80cd22019-05-03 22:54:36 +0100106} // namespace armnn_driver