blob: 836bf469ccaa26bcfa6d9ddac49cea60155ca31e [file] [log] [blame]
Sadik Armagan8f397a12022-06-17 15:38:22 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DriverOptions.hpp"
9
10#include <armnn/ArmNN.hpp>
11
12#include <nnapi/IPreparedModel.h>
13#include <nnapi/Result.h>
14#include <nnapi/TypeUtils.h>
15#include <nnapi/Types.h>
16#include <nnapi/Validation.h>
17
18using namespace android::nn;
19
20namespace armnn_driver
21{
22
23class ArmnnDriverImpl
24{
25public:
26 static GeneralResult<SharedPreparedModel> PrepareArmnnModel(
27 const armnn::IRuntimePtr& runtime,
28 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
29 const DriverOptions& options,
30 const Model& model,
31 const std::vector<SharedHandle>& modelCacheHandle,
32 const std::vector<SharedHandle>& dataCacheHandle,
33 const CacheToken& token,
34 bool float32ToFloat16 = false,
35 Priority priority = Priority::MEDIUM);
36
37 static GeneralResult<SharedPreparedModel> PrepareArmnnModelFromCache(
38 const armnn::IRuntimePtr& runtime,
39 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
40 const DriverOptions& options,
41 const std::vector<SharedHandle>& modelCacheHandle,
42 const std::vector<SharedHandle>& dataCacheHandle,
43 const CacheToken& token,
44 bool float32ToFloat16 = false);
45
46 static const Capabilities& GetCapabilities(const armnn::IRuntimePtr& runtime);
47
48 static std::vector<armnn::NetworkId>& GetLoadedNetworks();
49
50 static void ClearNetworks();
51
52private:
53 static bool ValidateSharedHandle(const SharedHandle& sharedHandle);
54 static bool ValidateDataCacheHandle(const std::vector<SharedHandle>& dataCacheHandle, const size_t dataSize);
55
56 static std::vector<armnn::NetworkId> m_NetworkIDs;
57};
58
59} // namespace armnn_driver