//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <android-base/logging.h>
#include <nnapi/IBuffer.h>
#include <nnapi/IDevice.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/OperandTypes.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

#include "ArmnnDevice.hpp"
#include "ArmnnDriverImpl.hpp"
#include "Converter.hpp"
#include "ModelToINetworkTransformer.hpp"

#include <log/log.h>

// Standard library headers used directly in this file.
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace armnn_driver
{

//using namespace android::nn;

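// ArmnnDriver implements the NNAPI support-library IDevice interface on top of Arm NN:
// it reports device capabilities, checks which operations of a model the configured
// backends can run, and prepares models for execution (optionally from cache).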
class ArmnnDriver : public ArmnnDevice, public IDevice
{
public:

    ArmnnDriver(DriverOptions options)
        : ArmnnDevice(std::move(options))
    {
        VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
    }
    ~ArmnnDriver()
    {
        VLOG(DRIVER) << "ArmnnDriver::~ArmnnDriver()";
        // Unload the networks
        for (auto& netId : ArmnnDriverImpl::GetLoadedNetworks())
        {
            m_Runtime->UnloadNetwork(netId);
        }
        ArmnnDriverImpl::ClearNetworks();
    }

public:

    const std::string& getName() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getName()";
        static const std::string name = "arm-armnn-sl";
        return name;
    }

    const std::string& getVersionString() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
        static const std::string versionString = "ArmNN";
        return versionString;
    }

    Version getFeatureLevel() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
        return kVersionFeatureLevel5;
    }

    DeviceType getType() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getType()";
        return DeviceType::CPU;
    }

    const std::vector<Extension>& getSupportedExtensions() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
        static const std::vector<Extension> extensions = {};
        return extensions;
    }

    const Capabilities& getCapabilities() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getCapabilities()";
        return ArmnnDriverImpl::GetCapabilities(m_Runtime);
    }

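    // Reports how many model-cache files the selected backends require in total, plus a
    // single data-cache file. The NNAPI runtime uses these counts to size the cache
    // handle vectors it passes to prepareModel and prepareModelFromCache.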
    std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
        unsigned int numberOfCachedModelFiles = 0;
        for (auto& backend : m_Options.GetBackends())
        {
            numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
            VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = " << std::to_string(numberOfCachedModelFiles);
        }
        return std::make_pair(numberOfCachedModelFiles, 1ul);
    }

    GeneralResult<void> wait() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::wait()";
        return {};
    }

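    // Determines, for each operation in the model, whether the configured backends can run
    // it. The model is validated first and then converted to an Arm NN INetwork; each
    // operation's support flag is taken from the outcome of that conversion.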
    GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";

        std::stringstream ss;
        ss << "ArmnnDriverImpl::getSupportedOperations()";
        std::string fileName;
        std::string timestamp;
        if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty())
        {
            ss << " : "
               << m_Options.GetRequestInputsAndOutputsDumpDir()
               << "/"
               // << GetFileTimestamp()
               << "_getSupportedOperations.txt";
        }
        VLOG(DRIVER) << ss.str().c_str();

        if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty())
        {
            // Dump a marker file into the requested dump directory.
            std::ofstream fileStream;
            fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
            if (fileStream.good())
            {
                fileStream << timestamp << std::endl;
            }
            fileStream.close();
        }

        std::vector<bool> result;
        if (!m_Runtime)
        {
            return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
        }

        // Run general model validation; if this doesn't pass we shouldn't analyse the model anyway.
        if (const auto result = validate(model); !result.ok())
        {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
        }

        // Attempt to convert the model to an Arm NN input network (INetwork).
        ModelToINetworkTransformer modelConverter(m_Options.GetBackends(),
                                                  model,
                                                  m_Options.GetForcedUnsupportedOperations());

        if (modelConverter.GetConversionResult() != ConversionResult::Success
            && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
        {
            return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
        }

        // Check whether each operation was converted successfully and copy the flags
        // into the result (vector<bool>) that we need to return to Android.
        result.reserve(model.main.operations.size());
        for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
        {
            bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
            result.push_back(operationSupported);
        }

        return result;
    }

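    // Validates the arguments and the deadline, then hands the model to
    // ArmnnDriverImpl::PrepareArmnnModel, which optimises it for the selected backends and
    // returns a prepared model ready for execution. FP16 relaxation is applied only when
    // both the model requests it and it is enabled in the driver options.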
    GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
                                                    ExecutionPreference preference,
                                                    Priority priority,
                                                    OptionalTimePoint deadline,
                                                    const std::vector<SharedHandle>& modelCache,
                                                    const std::vector<SharedHandle>& dataCache,
                                                    const CacheToken& token,
                                                    const std::vector<android::nn::TokenValuePair>& hints,
                                                    const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::prepareModel()";

        // Validate arguments.
        if (const auto result = validate(model); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
        }
        if (const auto result = validate(preference); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
                   << "Invalid ExecutionPreference: " << result.error();
        }
        if (const auto result = validate(priority); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
        }

        // Check if deadline has passed.
        if (hasDeadlinePassed(deadline)) {
            return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
        }

        return ArmnnDriverImpl::PrepareArmnnModel(m_Runtime,
                                                  m_ClTunedParameters,
                                                  m_Options,
                                                  model,
                                                  modelCache,
                                                  dataCache,
                                                  token,
                                                  model.relaxComputationFloat32toFloat16 && m_Options.GetFp16Enabled(),
                                                  priority);
    }

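    // Recreates a prepared model from previously written cache files, skipping the model
    // conversion and compilation work performed by prepareModel().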
    GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
                                                             const std::vector<SharedHandle>& modelCache,
                                                             const std::vector<SharedHandle>& dataCache,
                                                             const CacheToken& token) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";

        // Check if deadline has passed.
        if (hasDeadlinePassed(deadline)) {
            return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
        }

        return ArmnnDriverImpl::PrepareArmnnModelFromCache(
            m_Runtime,
            m_ClTunedParameters,
            m_Options,
            modelCache,
            dataCache,
            token,
            m_Options.GetFp16Enabled());
    }

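    // Driver-managed buffers are not supported; allocation requests are always rejected.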
    GeneralResult<SharedBuffer> allocate(const BufferDesc&,
                                         const std::vector<SharedPreparedModel>&,
                                         const std::vector<BufferRole>&,
                                         const std::vector<BufferRole>&) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::allocate()";
        return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate.";
    }
};

} // namespace armnn_driver
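
// A minimal usage sketch (illustrative only): the surrounding service code and the exact
// DriverOptions constructor shown here are assumptions, not part of this header.
//
//     using namespace armnn_driver;
//     DriverOptions options(armnn::Compute::CpuAcc);         // hypothetical backend selection
//     auto device = std::make_shared<ArmnnDriver>(options);  // IDevice exposed to the NNAPI runtime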