blob: a7f004c184b84b0705560d54a08c938f80dd6fdf [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "RequestThread.hpp"
9
surmeh01deb3bdb2018-07-05 12:06:04 +010010#include "ArmnnDriver.hpp"
Matteo Martincighe48bdff2018-09-03 13:50:50 +010011#include "ArmnnDriverImpl.hpp"
surmeh01deb3bdb2018-07-05 12:06:04 +010012
telsoa01ce3e84a2018-08-31 09:31:35 +010013#include <NeuralNetworks.h>
14#include <armnn/ArmNN.hpp>
15
telsoa015307bc12018-03-09 13:51:08 +000016#include <string>
17#include <vector>
18
19namespace armnn_driver
20{
21
Matteo Martincighe48bdff2018-09-03 13:50:50 +010022template <typename HalVersion>
telsoa015307bc12018-03-09 13:51:08 +000023class ArmnnPreparedModel : public IPreparedModel
24{
25public:
Matteo Martincighe48bdff2018-09-03 13:50:50 +010026 using HalModel = typename HalVersion::Model;
27
telsoa015307bc12018-03-09 13:51:08 +000028 ArmnnPreparedModel(armnn::NetworkId networkId,
29 armnn::IRuntime* runtime,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010030 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +010031 const std::string& requestInputsAndOutputsDumpDir,
32 const bool gpuProfilingEnabled);
telsoa015307bc12018-03-09 13:51:08 +000033
34 virtual ~ArmnnPreparedModel();
35
36 virtual Return<ErrorStatus> execute(const Request& request,
37 const ::android::sp<IExecutionCallback>& callback) override;
38
39 /// execute the graph prepared from the request
40 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
41 std::shared_ptr<armnn::InputTensors>& pInputTensors,
42 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
43 const ::android::sp<IExecutionCallback>& callback);
44
45 /// Executes this model with dummy inputs (e.g. all zeroes).
46 void ExecuteWithDummyInputs();
47
48private:
telsoa015307bc12018-03-09 13:51:08 +000049 template <typename TensorBindingCollection>
50 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
51
Matteo Martincighe48bdff2018-09-03 13:50:50 +010052 armnn::NetworkId m_NetworkId;
53 armnn::IRuntime* m_Runtime;
54 HalModel m_Model;
telsoa015307bc12018-03-09 13:51:08 +000055 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
56 // It is specific to this class, so it is declared as static here
Matteo Martincighe48bdff2018-09-03 13:50:50 +010057 static RequestThread<HalVersion> m_RequestThread;
58 uint32_t m_RequestCount;
59 const std::string& m_RequestInputsAndOutputsDumpDir;
60 const bool m_GpuProfilingEnabled;
telsoa015307bc12018-03-09 13:51:08 +000061};
62
63}