//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <android-base/logging.h>
#include <nnapi/IBuffer.h>
#include <nnapi/IDevice.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/OperandTypes.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

#include "ArmnnDevice.hpp"
#include "ArmnnDriverImpl.hpp"
#include "Converter.hpp"
#include "ModelToINetworkTransformer.hpp"

#include <log/log.h>

#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace armnn_driver
{

//using namespace android::nn;

class ArmnnDriver : public ArmnnDevice, public IDevice
{
public:

    ArmnnDriver(DriverOptions options)
        : ArmnnDevice(std::move(options))
    {
        VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
    }

    ~ArmnnDriver()
    {
        VLOG(DRIVER) << "ArmnnDriver::~ArmnnDriver()";
    }

public:

    const std::string& getName() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getName()";
        static const std::string name = "arm-armnn-sl";
        return name;
    }

    const std::string& getVersionString() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
        static const std::string versionString = "ArmNN";
        return versionString;
    }

    Version getFeatureLevel() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
        return kVersionFeatureLevel5;
    }

    DeviceType getType() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getType()";
        return DeviceType::CPU;
    }

    const std::vector<Extension>& getSupportedExtensions() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
        static const std::vector<Extension> extensions = {};
        return extensions;
    }

    const Capabilities& getCapabilities() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getCapabilities()";
        return ArmnnDriverImpl::GetCapabilities(m_Runtime);
    }

    std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
        unsigned int numberOfCachedModelFiles = 0;
        for (auto& backend : m_Options.GetBackends())
        {
            numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
            VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = "
                         << std::to_string(numberOfCachedModelFiles);
        }
        return std::make_pair(numberOfCachedModelFiles, 1ul);
    }

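    // Illustrative sketch (not part of the driver): how a caller might consume the
    // (model cache file count, data cache file count) pair returned above when sizing
    // the cache handle vectors passed to prepareModel(). The variable names below are
    // hypothetical.
    //
    //     const auto [numModelCacheFiles, numDataCacheFiles] = driver.getNumberOfCacheFilesNeeded();
    //     std::vector<SharedHandle> modelCache(numModelCacheFiles); // summed across the configured backends
    //     std::vector<SharedHandle> dataCache(numDataCacheFiles);   // this driver always reports 1
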
    GeneralResult<void> wait() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::wait()";
        return {};
    }

    GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";

        if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty())
        {
            // Build the path of the marker file recording this call in the dump directory.
            // The timestamp in the file name (GetFileTimestamp()) is currently disabled.
            std::stringstream ss;
            ss << m_Options.GetRequestInputsAndOutputsDumpDir()
               << "/"
               // << GetFileTimestamp()
               << "_getSupportedOperations.txt";
            const std::string fileName = ss.str();
            VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations() : " << fileName;

            // Dump the marker file. Its contents are a placeholder timestamp (empty while
            // GetFileTimestamp() is disabled above); the file's existence is the marker.
            std::string timestamp;
            std::ofstream fileStream;
            fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
            if (fileStream.good())
            {
                fileStream << timestamp << std::endl;
            }
            fileStream.close();
        }

        if (!m_Runtime)
        {
            return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
        }

        // Run general model validation; if this doesn't pass we shouldn't analyse the model anyway.
        if (const auto validationResult = validate(model); !validationResult.ok())
        {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
        }

        // Attempt to convert the model to an ArmNN input network (INetwork).
        ModelToINetworkTransformer modelConverter(m_Options.GetBackends(),
                                                  model,
                                                  m_Options.GetForcedUnsupportedOperations());

        if (modelConverter.GetConversionResult() != ConversionResult::Success
            && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
        {
            return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
        }

        // Check whether each operation was converted successfully and copy the flags
        // into the result (vector<bool>) that we need to return to Android.
        std::vector<bool> result;
        result.reserve(model.main.operations.size());
        for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
        {
            result.push_back(modelConverter.IsOperationSupported(operationIdx));
        }

        return result;
    }

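    // Illustrative sketch (not part of the driver): the returned vector is index-aligned
    // with model.main.operations, so a caller can check support per operation. Variable
    // names below are hypothetical.
    //
    //     const auto supported = driver.getSupportedOperations(model);
    //     if (supported.has_value())
    //     {
    //         for (uint32_t i = 0; i < supported->size(); ++i)
    //         {
    //             VLOG(DRIVER) << "operation " << i
    //                          << (supported.value()[i] ? " supported" : " not supported");
    //         }
    //     }
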
    GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
                                                    ExecutionPreference preference,
                                                    Priority priority,
                                                    OptionalTimePoint deadline,
                                                    const std::vector<SharedHandle>& modelCache,
                                                    const std::vector<SharedHandle>& dataCache,
                                                    const CacheToken& token,
                                                    const std::vector<android::nn::TokenValuePair>& hints,
                                                    const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::prepareModel()";

        // Validate arguments.
        if (const auto result = validate(model); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
        }
        if (const auto result = validate(preference); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
                   << "Invalid ExecutionPreference: " << result.error();
        }
        if (const auto result = validate(priority); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
        }

        // Check if deadline has passed.
        if (hasDeadlinePassed(deadline)) {
            return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
        }

        return ArmnnDriverImpl::PrepareArmnnModel(m_Runtime,
                                                  m_ClTunedParameters,
                                                  m_Options,
                                                  model,
                                                  modelCache,
                                                  dataCache,
                                                  token,
                                                  model.relaxComputationFloat32toFloat16 && m_Options.GetFp16Enabled(),
                                                  priority);
    }

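    // Illustrative sketch (not part of the driver): unwrapping the GeneralResult returned
    // by prepareModel(). It assumes 'model', 'modelCache', 'dataCache' and 'token' have
    // been prepared by the caller as in the canonical NNAPI interface used above.
    //
    //     auto prepared = driver.prepareModel(model, ExecutionPreference::FAST_SINGLE_ANSWER,
    //                                         Priority::MEDIUM, {}, modelCache, dataCache,
    //                                         token, {}, {});
    //     if (!prepared.has_value())
    //     {
    //         VLOG(DRIVER) << "prepareModel failed: " << prepared.error().message;
    //         return;
    //     }
    //     SharedPreparedModel preparedModel = prepared.value();
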
    GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
                                                             const std::vector<SharedHandle>& modelCache,
                                                             const std::vector<SharedHandle>& dataCache,
                                                             const CacheToken& token) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";

        // Check if deadline has passed.
        if (hasDeadlinePassed(deadline)) {
            return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
        }

        return ArmnnDriverImpl::PrepareArmnnModelFromCache(
            m_Runtime,
            m_ClTunedParameters,
            m_Options,
            modelCache,
            dataCache,
            token,
            m_Options.GetFp16Enabled());
    }

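    // Illustrative note: prepareModelFromCache() expects the same CacheToken and cache
    // handles that an earlier prepareModel() call populated, so the serialized network can
    // be reloaded without re-running model conversion. A hypothetical caller:
    //
    //     auto restored = driver.prepareModelFromCache({}, modelCache, dataCache, token);
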
    GeneralResult<SharedBuffer> allocate(const BufferDesc&,
                                         const std::vector<SharedPreparedModel>&,
                                         const std::vector<BufferRole>&,
                                         const std::vector<BufferRole>&) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::allocate()";
        return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate() is not supported.";
    }
};

} // namespace armnn_driver
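
// Usage sketch (illustrative only, kept as a comment so this header stays unchanged in what
// it compiles). It assumes DriverOptions exposes a constructor taking a backend list, as in
// the classic android-nn-driver; check DriverOptions.hpp for the exact signatures available
// in this support library.
//
//     armnn_driver::DriverOptions options(std::vector<armnn::BackendId>{"CpuAcc"});
//     armnn_driver::ArmnnDriver driver(std::move(options));
//     VLOG(DRIVER) << "Loaded driver: " << driver.getName()
//                  << ", type: " << static_cast<int>(driver.getType());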