blob: 484a5318f7d935459b1198d1ac8c69e59ead31c8 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <android-base/logging.h>
#include <nnapi/IBuffer.h>
#include <nnapi/IDevice.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/OperandTypes.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

#include "ArmnnDevice.hpp"
#include "ArmnnDriverImpl.hpp"
#include "Converter.hpp"

#include "ArmnnDriverImpl.hpp"
#include "ModelToINetworkTransformer.hpp"

#include <armnn/Version.hpp>
#include <log/log.h>

#include <cstdint>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
26namespace armnn_driver
27{
28
29//using namespace android::nn;
30
31class ArmnnDriver : public ArmnnDevice, public IDevice
32{
33public:
34
35 ArmnnDriver(DriverOptions options)
36 : ArmnnDevice(std::move(options))
37 {
38 VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
39 }
40 ~ArmnnDriver()
41 {
42 VLOG(DRIVER) << "ArmnnDriver::~ArmnnDriver()";
Sadik Armagan8f397a12022-06-17 15:38:22 +010043 }
44
45public:
46
47 const std::string& getName() const override
48 {
49 VLOG(DRIVER) << "ArmnnDriver::getName()";
50 static const std::string name = "arm-armnn-sl";
51 return name;
52 }
53
54 const std::string& getVersionString() const override
55 {
56 VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
Kevin May80a9d882022-09-16 10:33:29 +010057 static const std::string versionString = ARMNN_VERSION;
Sadik Armagan8f397a12022-06-17 15:38:22 +010058 return versionString;
59 }
60
61 Version getFeatureLevel() const override
62 {
63 VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
64 return kVersionFeatureLevel5;
65 }
66
67 DeviceType getType() const override
68 {
69 VLOG(DRIVER) << "ArmnnDriver::getType()";
70 return DeviceType::CPU;
71 }
72
73 const std::vector<Extension>& getSupportedExtensions() const override
74 {
75 VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
76 static const std::vector<Extension> extensions = {};
77 return extensions;
78 }
79
80 const Capabilities& getCapabilities() const override
81 {
82 VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()";
83 return ArmnnDriverImpl::GetCapabilities(m_Runtime);
84 }
85
86 std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
87 {
88 VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
89 unsigned int numberOfCachedModelFiles = 0;
90 for (auto& backend : m_Options.GetBackends())
91 {
92 numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
93 VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = " << std::to_string(numberOfCachedModelFiles);
94 }
95 return std::make_pair(numberOfCachedModelFiles, 1ul);
96 }
97
98 GeneralResult<void> wait() const override
99 {
100 VLOG(DRIVER) << "ArmnnDriver::wait()";
101 return {};
102 }
103
104 GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
105 {
106 VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";
107
108 std::stringstream ss;
109 ss << "ArmnnDriverImpl::getSupportedOperations()";
110 std::string fileName;
111 std::string timestamp;
112 if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty())
113 {
114 ss << " : "
115 << m_Options.GetRequestInputsAndOutputsDumpDir()
116 << "/"
117 // << GetFileTimestamp()
118 << "_getSupportedOperations.txt";
119 }
120 VLOG(DRIVER) << ss.str().c_str();
121
122 if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty())
123 {
124 //dump the marker file
125 std::ofstream fileStream;
126 fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
127 if (fileStream.good())
128 {
129 fileStream << timestamp << std::endl;
130 fileStream << timestamp << std::endl;
131 }
132 fileStream.close();
133 }
134
135 std::vector<bool> result;
136 if (!m_Runtime)
137 {
138 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
139 }
140
141 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
142 if (const auto result = validate(model); !result.ok())
143 {
144 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
145 }
146
147 // Attempt to convert the model to an ArmNN input network (INetwork).
148 ModelToINetworkTransformer modelConverter(m_Options.GetBackends(),
149 model,
150 m_Options.GetForcedUnsupportedOperations());
151
152 if (modelConverter.GetConversionResult() != ConversionResult::Success
153 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
154 {
155 return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
156 }
157
158 // Check each operation if it was converted successfully and copy the flags
159 // into the result (vector<bool>) that we need to return to Android.
160 result.reserve(model.main.operations.size());
161 for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
162 {
163 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
164 result.push_back(operationSupported);
165 }
166
167 return result;
168 }
169
170 GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
171 ExecutionPreference preference,
172 Priority priority,
173 OptionalTimePoint deadline,
174 const std::vector<SharedHandle>& modelCache,
175 const std::vector<SharedHandle>& dataCache,
176 const CacheToken& token,
177 const std::vector<android::nn::TokenValuePair>& hints,
178 const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
179 {
180 VLOG(DRIVER) << "ArmnnDriver::prepareModel()";
181
182 // Validate arguments.
183 if (const auto result = validate(model); !result.ok()) {
184 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
185 }
186 if (const auto result = validate(preference); !result.ok()) {
187 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
188 << "Invalid ExecutionPreference: " << result.error();
189 }
190 if (const auto result = validate(priority); !result.ok()) {
191 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
192 }
193
194 // Check if deadline has passed.
195 if (hasDeadlinePassed(deadline)) {
196 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
197 }
198
199 return ArmnnDriverImpl::PrepareArmnnModel(m_Runtime,
200 m_ClTunedParameters,
201 m_Options,
202 model,
203 modelCache,
204 dataCache,
205 token,
206 model.relaxComputationFloat32toFloat16 && m_Options.GetFp16Enabled(),
207 priority);
208 }
209
210 GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
211 const std::vector<SharedHandle>& modelCache,
212 const std::vector<SharedHandle>& dataCache,
213 const CacheToken& token) const override
214 {
215 VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";
216
217 // Check if deadline has passed.
218 if (hasDeadlinePassed(deadline)) {
219 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
220 }
221
222 return ArmnnDriverImpl::PrepareArmnnModelFromCache(
223 m_Runtime,
224 m_ClTunedParameters,
225 m_Options,
226 modelCache,
227 dataCache,
228 token,
229 m_Options.GetFp16Enabled());
230 }
231
232 GeneralResult<SharedBuffer> allocate(const BufferDesc&,
233 const std::vector<SharedPreparedModel>&,
234 const std::vector<BufferRole>&,
235 const std::vector<BufferRole>&) const override
236 {
237 VLOG(DRIVER) << "ArmnnDriver::allocate()";
238 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate.";
239 }
240};
241
242} // namespace armnn_driver