blob: 3c4b32b768f89c8eba2c3b277cdcea202a46e082 [file] [log] [blame]
telsoa015307bc12018-03-09 13:51:08 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa015307bc12018-03-09 13:51:08 +00004//
5
6#pragma once
7
surmeh01deb3bdb2018-07-05 12:06:04 +01008#include "ArmnnDriver.hpp"
Matteo Martincighe48bdff2018-09-03 13:50:50 +01009#include "ArmnnDriverImpl.hpp"
arovir01b0717b52018-09-05 17:03:25 +010010#include "RequestThread.hpp"
surmeh01deb3bdb2018-07-05 12:06:04 +010011
telsoa01ce3e84a2018-08-31 09:31:35 +010012#include <NeuralNetworks.h>
13#include <armnn/ArmNN.hpp>
14
telsoa015307bc12018-03-09 13:51:08 +000015#include <string>
16#include <vector>
17
18namespace armnn_driver
19{
20
Matteo Martincighe48bdff2018-09-03 13:50:50 +010021template <typename HalVersion>
telsoa015307bc12018-03-09 13:51:08 +000022class ArmnnPreparedModel : public IPreparedModel
23{
24public:
Matteo Martincighe48bdff2018-09-03 13:50:50 +010025 using HalModel = typename HalVersion::Model;
26
telsoa015307bc12018-03-09 13:51:08 +000027 ArmnnPreparedModel(armnn::NetworkId networkId,
28 armnn::IRuntime* runtime,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010029 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +010030 const std::string& requestInputsAndOutputsDumpDir,
31 const bool gpuProfilingEnabled);
telsoa015307bc12018-03-09 13:51:08 +000032
33 virtual ~ArmnnPreparedModel();
34
35 virtual Return<ErrorStatus> execute(const Request& request,
36 const ::android::sp<IExecutionCallback>& callback) override;
37
38 /// execute the graph prepared from the request
39 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
40 std::shared_ptr<armnn::InputTensors>& pInputTensors,
41 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
42 const ::android::sp<IExecutionCallback>& callback);
43
44 /// Executes this model with dummy inputs (e.g. all zeroes).
45 void ExecuteWithDummyInputs();
46
47private:
telsoa015307bc12018-03-09 13:51:08 +000048 template <typename TensorBindingCollection>
49 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
50
Matteo Martincighe48bdff2018-09-03 13:50:50 +010051 armnn::NetworkId m_NetworkId;
52 armnn::IRuntime* m_Runtime;
53 HalModel m_Model;
telsoa015307bc12018-03-09 13:51:08 +000054 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
55 // It is specific to this class, so it is declared as static here
Matteo Martincighe48bdff2018-09-03 13:50:50 +010056 static RequestThread<HalVersion> m_RequestThread;
57 uint32_t m_RequestCount;
58 const std::string& m_RequestInputsAndOutputsDumpDir;
59 const bool m_GpuProfilingEnabled;
telsoa015307bc12018-03-09 13:51:08 +000060};
61
62}