blob: 33be972fa6b2078da3fec2810f9d80929d7f3f1b [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
surmeh01deb3bdb2018-07-05 12:06:04 +01008#include "ArmnnDriver.hpp"
Matteo Martincighe48bdff2018-09-03 13:50:50 +01009#include "ArmnnDriverImpl.hpp"
arovir01b0717b52018-09-05 17:03:25 +010010#include "RequestThread.hpp"
surmeh01deb3bdb2018-07-05 12:06:04 +010011
telsoa01ce3e84a2018-08-31 09:31:35 +010012#include <NeuralNetworks.h>
13#include <armnn/ArmNN.hpp>
14
telsoa015307bc12018-03-09 13:51:08 +000015#include <string>
16#include <vector>
17
18namespace armnn_driver
19{
Mike Kelly65c42dc2019-07-22 14:06:00 +010020using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;
21
22struct ArmnnCallback_1_0
23{
24 armnnExecuteCallback_1_0 callback;
25};
telsoa015307bc12018-03-09 13:51:08 +000026
Matteo Martincighe48bdff2018-09-03 13:50:50 +010027template <typename HalVersion>
Matthew Bentham912b3622019-05-03 15:49:14 +010028class ArmnnPreparedModel : public V1_0::IPreparedModel
telsoa015307bc12018-03-09 13:51:08 +000029{
30public:
Matteo Martincighe48bdff2018-09-03 13:50:50 +010031 using HalModel = typename HalVersion::Model;
32
telsoa015307bc12018-03-09 13:51:08 +000033 ArmnnPreparedModel(armnn::NetworkId networkId,
34 armnn::IRuntime* runtime,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010035 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +010036 const std::string& requestInputsAndOutputsDumpDir,
37 const bool gpuProfilingEnabled);
telsoa015307bc12018-03-09 13:51:08 +000038
39 virtual ~ArmnnPreparedModel();
40
41 virtual Return<ErrorStatus> execute(const Request& request,
Matthew Bentham912b3622019-05-03 15:49:14 +010042 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
telsoa015307bc12018-03-09 13:51:08 +000043
44 /// execute the graph prepared from the request
45 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
46 std::shared_ptr<armnn::InputTensors>& pInputTensors,
47 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
Mike Kelly65c42dc2019-07-22 14:06:00 +010048 ArmnnCallback_1_0 callback);
telsoa015307bc12018-03-09 13:51:08 +000049
50 /// Executes this model with dummy inputs (e.g. all zeroes).
Matthew Bentham16196e22019-04-01 17:17:58 +010051 /// \return false on failure, otherwise true
52 bool ExecuteWithDummyInputs();
telsoa015307bc12018-03-09 13:51:08 +000053
54private:
telsoa015307bc12018-03-09 13:51:08 +000055 template <typename TensorBindingCollection>
56 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
57
Mike Kelly65c42dc2019-07-22 14:06:00 +010058 armnn::NetworkId m_NetworkId;
59 armnn::IRuntime* m_Runtime;
60 HalModel m_Model;
telsoa015307bc12018-03-09 13:51:08 +000061 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
62 // It is specific to this class, so it is declared as static here
Mike Kelly65c42dc2019-07-22 14:06:00 +010063 static RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> m_RequestThread;
64 uint32_t m_RequestCount;
65 const std::string& m_RequestInputsAndOutputsDumpDir;
66 const bool m_GpuProfilingEnabled;
telsoa015307bc12018-03-09 13:51:08 +000067};
68
69}