blob: 6cb06604d2b4eba327096a782903b121669d3b62 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <android-base/logging.h>
9#include <nnapi/IBuffer.h>
10#include <nnapi/IDevice.h>
11#include <nnapi/IPreparedModel.h>
12#include <nnapi/OperandTypes.h>
13#include <nnapi/Result.h>
14#include <nnapi/Types.h>
15#include <nnapi/Validation.h>
16
17#include "ArmnnDevice.hpp"
18#include "ArmnnDriverImpl.hpp"
19#include "Converter.hpp"
20
21#include "ArmnnDriverImpl.hpp"
22#include "ModelToINetworkTransformer.hpp"
23
Kevin May80a9d882022-09-16 10:33:29 +010024#include <armnn/Version.hpp>
Sadik Armagan8f397a12022-06-17 15:38:22 +010025#include <log/log.h>
26namespace armnn_driver
27{
28
29//using namespace android::nn;
30
Kevin May05509cb2023-01-09 16:06:45 +000031class ArmnnDriver : public IDevice
Sadik Armagan8f397a12022-06-17 15:38:22 +010032{
Kevin May05509cb2023-01-09 16:06:45 +000033private:
34 std::unique_ptr<ArmnnDevice> m_Device;
Sadik Armagan8f397a12022-06-17 15:38:22 +010035public:
Sadik Armagan8f397a12022-06-17 15:38:22 +010036 ArmnnDriver(DriverOptions options)
Sadik Armagan8f397a12022-06-17 15:38:22 +010037 {
Kevin May05509cb2023-01-09 16:06:45 +000038 try
39 {
40 VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
41 m_Device = std::unique_ptr<ArmnnDevice>(new ArmnnDevice(std::move(options)));
42 }
43 catch (armnn::InvalidArgumentException& ex)
44 {
45 VLOG(DRIVER) << "ArmnnDevice failed to initialise: " << ex.what();
46 }
47 catch (...)
48 {
49 VLOG(DRIVER) << "ArmnnDevice failed to initialise with an unknown error";
50 }
Sadik Armagan8f397a12022-06-17 15:38:22 +010051 }
52
53public:
54
55 const std::string& getName() const override
56 {
57 VLOG(DRIVER) << "ArmnnDriver::getName()";
58 static const std::string name = "arm-armnn-sl";
59 return name;
60 }
61
62 const std::string& getVersionString() const override
63 {
64 VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
Kevin May80a9d882022-09-16 10:33:29 +010065 static const std::string versionString = ARMNN_VERSION;
Sadik Armagan8f397a12022-06-17 15:38:22 +010066 return versionString;
67 }
68
69 Version getFeatureLevel() const override
70 {
71 VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
Kevin May9636a9b2022-09-21 15:41:41 +010072 return kVersionFeatureLevel6;
Sadik Armagan8f397a12022-06-17 15:38:22 +010073 }
74
75 DeviceType getType() const override
76 {
77 VLOG(DRIVER) << "ArmnnDriver::getType()";
78 return DeviceType::CPU;
79 }
80
81 const std::vector<Extension>& getSupportedExtensions() const override
82 {
83 VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
84 static const std::vector<Extension> extensions = {};
85 return extensions;
86 }
87
88 const Capabilities& getCapabilities() const override
89 {
90 VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()";
Kevin May05509cb2023-01-09 16:06:45 +000091 return ArmnnDriverImpl::GetCapabilities(m_Device->m_Runtime);
Sadik Armagan8f397a12022-06-17 15:38:22 +010092 }
93
94 std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
95 {
96 VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
97 unsigned int numberOfCachedModelFiles = 0;
Kevin May05509cb2023-01-09 16:06:45 +000098 for (auto& backend : m_Device->m_Options.GetBackends())
Sadik Armagan8f397a12022-06-17 15:38:22 +010099 {
100 numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
Kevin May05509cb2023-01-09 16:06:45 +0000101 VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = "
102 << std::to_string(numberOfCachedModelFiles);
Sadik Armagan8f397a12022-06-17 15:38:22 +0100103 }
104 return std::make_pair(numberOfCachedModelFiles, 1ul);
105 }
106
107 GeneralResult<void> wait() const override
108 {
109 VLOG(DRIVER) << "ArmnnDriver::wait()";
110 return {};
111 }
112
113 GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
114 {
115 VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";
Kevin May05509cb2023-01-09 16:06:45 +0000116 if (m_Device.get() == nullptr)
117 {
118 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
119 }
Sadik Armagan8f397a12022-06-17 15:38:22 +0100120
121 std::stringstream ss;
122 ss << "ArmnnDriverImpl::getSupportedOperations()";
123 std::string fileName;
124 std::string timestamp;
Kevin May05509cb2023-01-09 16:06:45 +0000125 if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty())
Sadik Armagan8f397a12022-06-17 15:38:22 +0100126 {
127 ss << " : "
Kevin May05509cb2023-01-09 16:06:45 +0000128 << m_Device->m_Options.GetRequestInputsAndOutputsDumpDir()
Sadik Armagan8f397a12022-06-17 15:38:22 +0100129 << "/"
130 // << GetFileTimestamp()
131 << "_getSupportedOperations.txt";
132 }
133 VLOG(DRIVER) << ss.str().c_str();
134
Kevin May05509cb2023-01-09 16:06:45 +0000135 if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty())
Sadik Armagan8f397a12022-06-17 15:38:22 +0100136 {
137 //dump the marker file
138 std::ofstream fileStream;
139 fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
140 if (fileStream.good())
141 {
142 fileStream << timestamp << std::endl;
143 fileStream << timestamp << std::endl;
144 }
145 fileStream.close();
146 }
147
148 std::vector<bool> result;
Kevin May05509cb2023-01-09 16:06:45 +0000149 if (!m_Device->m_Runtime)
Sadik Armagan8f397a12022-06-17 15:38:22 +0100150 {
151 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
152 }
153
154 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
155 if (const auto result = validate(model); !result.ok())
156 {
157 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
158 }
159
160 // Attempt to convert the model to an ArmNN input network (INetwork).
Kevin May05509cb2023-01-09 16:06:45 +0000161 ModelToINetworkTransformer modelConverter(m_Device->m_Options.GetBackends(),
Sadik Armagan8f397a12022-06-17 15:38:22 +0100162 model,
Kevin May05509cb2023-01-09 16:06:45 +0000163 m_Device->m_Options.GetForcedUnsupportedOperations());
Sadik Armagan8f397a12022-06-17 15:38:22 +0100164
165 if (modelConverter.GetConversionResult() != ConversionResult::Success
166 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
167 {
168 return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
169 }
170
171 // Check each operation if it was converted successfully and copy the flags
172 // into the result (vector<bool>) that we need to return to Android.
173 result.reserve(model.main.operations.size());
174 for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
175 {
176 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
177 result.push_back(operationSupported);
178 }
179
180 return result;
181 }
182
183 GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
184 ExecutionPreference preference,
185 Priority priority,
186 OptionalTimePoint deadline,
187 const std::vector<SharedHandle>& modelCache,
188 const std::vector<SharedHandle>& dataCache,
189 const CacheToken& token,
190 const std::vector<android::nn::TokenValuePair>& hints,
191 const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
192 {
193 VLOG(DRIVER) << "ArmnnDriver::prepareModel()";
194
Kevin May05509cb2023-01-09 16:06:45 +0000195 if (m_Device.get() == nullptr)
196 {
197 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
198 }
Sadik Armagan8f397a12022-06-17 15:38:22 +0100199 // Validate arguments.
200 if (const auto result = validate(model); !result.ok()) {
201 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
202 }
203 if (const auto result = validate(preference); !result.ok()) {
204 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
205 << "Invalid ExecutionPreference: " << result.error();
206 }
207 if (const auto result = validate(priority); !result.ok()) {
208 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
209 }
210
211 // Check if deadline has passed.
212 if (hasDeadlinePassed(deadline)) {
213 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
214 }
215
Kevin May05509cb2023-01-09 16:06:45 +0000216 return ArmnnDriverImpl::PrepareArmnnModel(m_Device->m_Runtime,
217 m_Device->m_ClTunedParameters,
218 m_Device->m_Options,
219 model,
220 modelCache,
221 dataCache,
222 token,
223 model.relaxComputationFloat32toFloat16 && m_Device->m_Options.GetFp16Enabled(),
224 priority);
Sadik Armagan8f397a12022-06-17 15:38:22 +0100225 }
226
227 GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
228 const std::vector<SharedHandle>& modelCache,
229 const std::vector<SharedHandle>& dataCache,
230 const CacheToken& token) const override
231 {
232 VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";
Kevin May05509cb2023-01-09 16:06:45 +0000233 if (m_Device.get() == nullptr)
234 {
235 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
236 }
Sadik Armagan8f397a12022-06-17 15:38:22 +0100237 // Check if deadline has passed.
238 if (hasDeadlinePassed(deadline)) {
239 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
240 }
241
242 return ArmnnDriverImpl::PrepareArmnnModelFromCache(
Kevin May05509cb2023-01-09 16:06:45 +0000243 m_Device->m_Runtime,
244 m_Device->m_ClTunedParameters,
245 m_Device->m_Options,
Sadik Armagan8f397a12022-06-17 15:38:22 +0100246 modelCache,
247 dataCache,
248 token,
Kevin May05509cb2023-01-09 16:06:45 +0000249 m_Device->m_Options.GetFp16Enabled());
Sadik Armagan8f397a12022-06-17 15:38:22 +0100250 }
251
252 GeneralResult<SharedBuffer> allocate(const BufferDesc&,
253 const std::vector<SharedPreparedModel>&,
254 const std::vector<BufferRole>&,
255 const std::vector<BufferRole>&) const override
256 {
257 VLOG(DRIVER) << "ArmnnDriver::allocate()";
258 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate.";
259 }
260};
261
262} // namespace armnn_driver