//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "ModelToINetworkTransformer.hpp"

#include <armnn/ArmNN.hpp>

#include <BufferTracker.h>
#include <CpuExecutor.h>
#include <nnapi/IExecution.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>

#include <any>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
namespace armnn_driver
{
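
/// Timing context captured around a single execution. When the caller requests
/// MeasureTiming::YES, the driver records time points at the driver entry/exit
/// and device execution boundaries so a Timing result can be reported back.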
struct CanonicalExecutionContext
{
    ::android::nn::MeasureTiming measureTimings =
        ::android::nn::MeasureTiming::NO;
    android::nn::TimePoint driverStart;
    android::nn::TimePoint driverEnd;
    android::nn::TimePoint deviceStart;
    android::nn::TimePoint deviceEnd;
};
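
/// Arm NN backed implementation of the canonical android::nn::IPreparedModel
/// interface: wraps a network that has already been loaded into an
/// armnn::IRuntime and services execution requests against it. Derives from
/// enable_shared_from_this so in-flight executions can extend its lifetime.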
class ArmnnPreparedModel final : public IPreparedModel,
                                 public std::enable_shared_from_this<ArmnnPreparedModel>
{
public:
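    /// Creates a prepared model for a freshly compiled network. The runtime
    /// pointer is borrowed, not owned, and must outlive this object.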
    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const Model& model,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       Priority priority = Priority::MEDIUM);

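    /// Creates a prepared model restored from a cached compilation; no Model
    /// description is available on this path (see prepareModelFromCache).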
    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       Priority priority = Priority::MEDIUM,
                       const bool prepareModelFromCache = false);

    virtual ~ArmnnPreparedModel();

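    /// Synchronously runs a single inference: prepares input/output memory
    /// from the request's pools, executes the graph, and returns the output
    /// shapes together with timing information when measurement was requested.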
    ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> execute(
        const Request& request,
        MeasureTiming measureTiming,
        const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

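    /// Runs a single inference gated on sync fences: waits on every fence in
    /// waitFor before executing, then returns a fence for the result plus a
    /// callback for retrieving the fenced execution's timing.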
    GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> executeFenced(
        const Request& request,
        const std::vector<SyncFence>& waitFor,
        MeasureTiming measureTiming,
        const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration,
        const OptionalDuration& timeoutDurationAfterFence,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

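    /// Creates a reusable execution bound to a fixed request, allowing request
    /// validation and memory setup to be amortised across repeated runs.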
    GeneralResult<android::nn::SharedExecution> createReusableExecution(
        const Request& request,
        MeasureTiming measureTiming,
        const OptionalDuration& loopTimeoutDuration,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

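    /// Creates a burst object for servicing rapid sequences of executions.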
    GeneralResult<SharedBurst> configureExecutionBurst() const override;

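    /// Returns a type-erased handle to the driver's underlying resource; see
    /// the implementation for the concrete type stored in the std::any.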
    std::any getUnderlyingResource() const override;

    /// Executes the graph prepared from the request; shared by the
    /// synchronous and fenced execution paths.
    ErrorStatus ExecuteGraph(
        std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>>& pMemPools,
        armnn::InputTensors& inputTensors,
        armnn::OutputTensors& outputTensors,
        CanonicalExecutionContext callback) const;

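    /// Returns the scheduling priority this model was prepared with.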
    Priority GetModelPriority() const;

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs) const;

private:
    void Init();
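
    /// Binds each request input to an Arm NN input tensor backed by the
    /// request's mapped memory pools.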
    ErrorStatus PrepareMemoryForInputs(
        armnn::InputTensors& inputs,
        const Request& request,
        const std::vector<android::nn::RunTimePoolInfo>& memPools) const;

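    /// Binds each request output to an Arm NN output tensor and records the
    /// corresponding output shapes.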
    ErrorStatus PrepareMemoryForOutputs(
        armnn::OutputTensors& outputs,
        std::vector<OutputShape>& outputShapes,
        const Request& request,
        const std::vector<android::nn::RunTimePoolInfo>& memPools) const;

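    /// Maps the request's memory pools, then prepares both inputs and outputs.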
    ErrorStatus PrepareMemoryForIO(armnn::InputTensors& inputs,
                                   armnn::OutputTensors& outputs,
                                   std::vector<android::nn::RunTimePoolInfo>& memPools,
                                   const Request& request) const;

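    /// Writes the given tensors to m_RequestInputsAndOutputsDumpDir when a
    /// dump directory has been configured.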
    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings) const;

    armnn::NetworkId m_NetworkId;
    armnn::IRuntime* m_Runtime;

    const Model m_Model;
    const std::string& m_RequestInputsAndOutputsDumpDir;
    const bool m_GpuProfilingEnabled;
    Priority m_ModelPriority;
    const bool m_PrepareFromCache;
};

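// A minimal usage sketch (illustrative only: it assumes a network has already
// been optimised and loaded into the runtime as 'networkId', and that 'model'
// and 'request' are valid canonical NNAPI objects; none of these names are
// defined by this header):
//
//     const std::string dumpDir;  // empty: no input/output dumping
//     auto prepared = std::make_shared<ArmnnPreparedModel>(
//         networkId, runtime.get(), model, dumpDir,
//         /*gpuProfilingEnabled=*/false);
//     auto outcome = prepared->execute(request,
//                                      MeasureTiming::NO,
//                                      /*deadline=*/{},
//                                      /*loopTimeoutDuration=*/{},
//                                      /*hints=*/{},
//                                      /*extensionNameToPrefix=*/{});
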
} // namespace armnn_driver