blob: 79f309a32ecb6b38998bd24cf1a9a2654a026606 [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"

#include <CpuExecutor.h>
#include <armnn/ArmNN.hpp>
18
19namespace armnn_driver
20{
// Time point type measured on the monotonic steady_clock.
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
// Smallest representable TimePoint (presumably used as an "unset" marker by callers -- verify).
static const TimePoint g_Min = TimePoint::min();
telsoa015307bc12018-03-09 13:51:08 +000023
Derek Lamberti4de83c52020-03-17 13:40:18 +000024template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
telsoa015307bc12018-03-09 13:51:08 +000025class RequestThread
26{
27public:
28 /// Constructor creates the thread
29 RequestThread();
30
31 /// Destructor terminates the thread
32 ~RequestThread();
33
34 /// Add a message to the thread queue.
35 /// @param[in] model pointer to the prepared model handling the request
36 /// @param[in] memPools pointer to the memory pools vector for the tensors
37 /// @param[in] inputTensors pointer to the input tensors for the request
38 /// @param[in] outputTensors pointer to the output tensors for the request
39 /// @param[in] callback the android notification callback
Mike Kellyb5fdf382019-06-11 16:35:25 +010040 void PostMsg(PreparedModel<HalVersion>* model,
telsoa015307bc12018-03-09 13:51:08 +000041 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
42 std::shared_ptr<armnn::InputTensors>& inputTensors,
43 std::shared_ptr<armnn::OutputTensors>& outputTensors,
Derek Lamberti4de83c52020-03-17 13:40:18 +000044 CallbackContext callbackContext);
telsoa015307bc12018-03-09 13:51:08 +000045
46private:
47 RequestThread(const RequestThread&) = delete;
48 RequestThread& operator=(const RequestThread&) = delete;
49
50 /// storage for a prepared model and args for the asyncExecute call
51 struct AsyncExecuteData
52 {
Mike Kellyb5fdf382019-06-11 16:35:25 +010053 AsyncExecuteData(PreparedModel<HalVersion>* model,
telsoa015307bc12018-03-09 13:51:08 +000054 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
55 std::shared_ptr<armnn::InputTensors>& inputTensors,
56 std::shared_ptr<armnn::OutputTensors>& outputTensors,
Derek Lamberti4de83c52020-03-17 13:40:18 +000057 CallbackContext callbackContext)
telsoa015307bc12018-03-09 13:51:08 +000058 : m_Model(model)
59 , m_MemPools(memPools)
60 , m_InputTensors(inputTensors)
61 , m_OutputTensors(outputTensors)
Derek Lamberti4de83c52020-03-17 13:40:18 +000062 , m_CallbackContext(callbackContext)
telsoa015307bc12018-03-09 13:51:08 +000063 {
64 }
65
Mike Kellyb5fdf382019-06-11 16:35:25 +010066 PreparedModel<HalVersion>* m_Model;
telsoa015307bc12018-03-09 13:51:08 +000067 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
68 std::shared_ptr<armnn::InputTensors> m_InputTensors;
69 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
Derek Lamberti4de83c52020-03-17 13:40:18 +000070 CallbackContext m_CallbackContext;
telsoa015307bc12018-03-09 13:51:08 +000071 };
telsoa015307bc12018-03-09 13:51:08 +000072 enum class ThreadMsgType
73 {
74 EXIT, // exit the thread
75 REQUEST // user request to process
76 };
77
78 /// storage for the thread message type and data
79 struct ThreadMsg
80 {
81 ThreadMsg(ThreadMsgType msgType,
82 std::shared_ptr<AsyncExecuteData>& msgData)
83 : type(msgType)
84 , data(msgData)
85 {
86 }
87
88 ThreadMsgType type;
89 std::shared_ptr<AsyncExecuteData> data;
90 };
91
92 /// Add a prepared thread message to the thread queue.
93 /// @param[in] threadMsg the message to add to the queue
94 void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg);
95
96 /// Entry point for the request thread
97 void Process();
98
99 std::unique_ptr<std::thread> m_Thread;
100 std::queue<std::shared_ptr<ThreadMsg>> m_Queue;
101 std::mutex m_Mutex;
102 std::condition_variable m_Cv;
103};
104
Matthew Bentham9e80cd22019-05-03 22:54:36 +0100105} // namespace armnn_driver