blob: 7a65d52bb03f84c80008ec8dd8794a7b0c71e21a [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
surmeh01deb3bdb2018-07-05 12:06:04 +01008#include "ArmnnDriver.hpp"
Matteo Martincighe48bdff2018-09-03 13:50:50 +01009#include "ArmnnDriverImpl.hpp"
arovir01b0717b52018-09-05 17:03:25 +010010#include "RequestThread.hpp"
surmeh01deb3bdb2018-07-05 12:06:04 +010011
telsoa01ce3e84a2018-08-31 09:31:35 +010012#include <NeuralNetworks.h>
13#include <armnn/ArmNN.hpp>
14
telsoa015307bc12018-03-09 13:51:08 +000015#include <string>
16#include <vector>
17
18namespace armnn_driver
19{
20
Matteo Martincighe48bdff2018-09-03 13:50:50 +010021template <typename HalVersion>
Matthew Bentham912b3622019-05-03 15:49:14 +010022class ArmnnPreparedModel : public V1_0::IPreparedModel
telsoa015307bc12018-03-09 13:51:08 +000023{
24public:
Matteo Martincighe48bdff2018-09-03 13:50:50 +010025 using HalModel = typename HalVersion::Model;
26
telsoa015307bc12018-03-09 13:51:08 +000027 ArmnnPreparedModel(armnn::NetworkId networkId,
28 armnn::IRuntime* runtime,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010029 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +010030 const std::string& requestInputsAndOutputsDumpDir,
31 const bool gpuProfilingEnabled);
telsoa015307bc12018-03-09 13:51:08 +000032
33 virtual ~ArmnnPreparedModel();
34
35 virtual Return<ErrorStatus> execute(const Request& request,
Matthew Bentham912b3622019-05-03 15:49:14 +010036 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
telsoa015307bc12018-03-09 13:51:08 +000037
38 /// execute the graph prepared from the request
39 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
40 std::shared_ptr<armnn::InputTensors>& pInputTensors,
41 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
Matthew Bentham912b3622019-05-03 15:49:14 +010042 const ::android::sp<V1_0::IExecutionCallback>& callback);
telsoa015307bc12018-03-09 13:51:08 +000043
44 /// Executes this model with dummy inputs (e.g. all zeroes).
Matthew Bentham16196e22019-04-01 17:17:58 +010045 /// \return false on failure, otherwise true
46 bool ExecuteWithDummyInputs();
telsoa015307bc12018-03-09 13:51:08 +000047
48private:
telsoa015307bc12018-03-09 13:51:08 +000049 template <typename TensorBindingCollection>
50 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
51
Matteo Martincighe48bdff2018-09-03 13:50:50 +010052 armnn::NetworkId m_NetworkId;
53 armnn::IRuntime* m_Runtime;
54 HalModel m_Model;
telsoa015307bc12018-03-09 13:51:08 +000055 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
56 // It is specific to this class, so it is declared as static here
Matteo Martincighe48bdff2018-09-03 13:50:50 +010057 static RequestThread<HalVersion> m_RequestThread;
58 uint32_t m_RequestCount;
59 const std::string& m_RequestInputsAndOutputsDumpDir;
60 const bool m_GpuProfilingEnabled;
telsoa015307bc12018-03-09 13:51:08 +000061};
62
63}