blob: dc1b535aed079cb4dc9789d148f8ed0a6299988b [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
8#include <queue>
9#include <thread>
10#include <mutex>
11#include <condition_variable>
12
telsoa01ce3e84a2018-08-31 09:31:35 +010013#include "ArmnnDriver.hpp"
Matteo Martincighe48bdff2018-09-03 13:50:50 +010014#include "ArmnnDriverImpl.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010015
16#include <CpuExecutor.h>
telsoa015307bc12018-03-09 13:51:08 +000017#include <armnn/ArmNN.hpp>
18
19namespace armnn_driver
20{
21
Mike Kellyb5fdf382019-06-11 16:35:25 +010022template<template <typename HalVersion> class PreparedModel, typename HalVersion>
telsoa015307bc12018-03-09 13:51:08 +000023class RequestThread
24{
25public:
26 /// Constructor creates the thread
27 RequestThread();
28
29 /// Destructor terminates the thread
30 ~RequestThread();
31
32 /// Add a message to the thread queue.
33 /// @param[in] model pointer to the prepared model handling the request
34 /// @param[in] memPools pointer to the memory pools vector for the tensors
35 /// @param[in] inputTensors pointer to the input tensors for the request
36 /// @param[in] outputTensors pointer to the output tensors for the request
37 /// @param[in] callback the android notification callback
Mike Kellyb5fdf382019-06-11 16:35:25 +010038 void PostMsg(PreparedModel<HalVersion>* model,
telsoa015307bc12018-03-09 13:51:08 +000039 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
40 std::shared_ptr<armnn::InputTensors>& inputTensors,
41 std::shared_ptr<armnn::OutputTensors>& outputTensors,
Matthew Bentham9e80cd22019-05-03 22:54:36 +010042 const ::android::sp<V1_0::IExecutionCallback>& callback);
telsoa015307bc12018-03-09 13:51:08 +000043
44private:
45 RequestThread(const RequestThread&) = delete;
46 RequestThread& operator=(const RequestThread&) = delete;
47
48 /// storage for a prepared model and args for the asyncExecute call
49 struct AsyncExecuteData
50 {
Mike Kellyb5fdf382019-06-11 16:35:25 +010051 AsyncExecuteData(PreparedModel<HalVersion>* model,
telsoa015307bc12018-03-09 13:51:08 +000052 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
53 std::shared_ptr<armnn::InputTensors>& inputTensors,
54 std::shared_ptr<armnn::OutputTensors>& outputTensors,
Matthew Bentham9e80cd22019-05-03 22:54:36 +010055 const ::android::sp<V1_0::IExecutionCallback>& cb)
telsoa015307bc12018-03-09 13:51:08 +000056 : m_Model(model)
57 , m_MemPools(memPools)
58 , m_InputTensors(inputTensors)
59 , m_OutputTensors(outputTensors)
60 , m_callback(cb)
61 {
62 }
63
Mike Kellyb5fdf382019-06-11 16:35:25 +010064 PreparedModel<HalVersion>* m_Model;
telsoa015307bc12018-03-09 13:51:08 +000065 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
66 std::shared_ptr<armnn::InputTensors> m_InputTensors;
67 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
Matthew Bentham9e80cd22019-05-03 22:54:36 +010068 const ::android::sp<V1_0::IExecutionCallback> m_callback;
telsoa015307bc12018-03-09 13:51:08 +000069 };
70
71 enum class ThreadMsgType
72 {
73 EXIT, // exit the thread
74 REQUEST // user request to process
75 };
76
77 /// storage for the thread message type and data
78 struct ThreadMsg
79 {
80 ThreadMsg(ThreadMsgType msgType,
81 std::shared_ptr<AsyncExecuteData>& msgData)
82 : type(msgType)
83 , data(msgData)
84 {
85 }
86
87 ThreadMsgType type;
88 std::shared_ptr<AsyncExecuteData> data;
89 };
90
91 /// Add a prepared thread message to the thread queue.
92 /// @param[in] threadMsg the message to add to the queue
93 void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg);
94
95 /// Entry point for the request thread
96 void Process();
97
98 std::unique_ptr<std::thread> m_Thread;
99 std::queue<std::shared_ptr<ThreadMsg>> m_Queue;
100 std::mutex m_Mutex;
101 std::condition_variable m_Cv;
102};
103
Matthew Bentham9e80cd22019-05-03 22:54:36 +0100104} // namespace armnn_driver