blob: b0b29deffd5b74cf8c7186545334bb5e5f374d48 [file] [log] [blame]
Teresa Charlin83b42912022-07-07 14:24:59 +01001//
Mike Kelly5446a4d2023-01-20 15:51:05 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Teresa Charlin83b42912022-07-07 14:24:59 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "IExecutor.hpp"
9#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
10#include "ExecuteNetworkProgramOptions.hpp"
11#include "armnn/utility/NumericCast.hpp"
12#include "armnn/utility/Timer.hpp"
13
14#include <armnn/ArmNN.hpp>
15#include <armnn/Threadpool.hpp>
16#include <armnn/Logging.hpp>
17#include <armnn/utility/Timer.hpp>
18#include <armnn/BackendRegistry.hpp>
19#include <armnn/utility/Assert.hpp>
20#include <armnn/utility/NumericCast.hpp>
21
22#include <armnnUtils/Filesystem.hpp>
23#include <HeapProfiling.hpp>
24
25#include <fmt/format.h>
26
27#if defined(ARMNN_SERIALIZER)
28#include "armnnDeserializer/IDeserializer.hpp"
29#endif
30#if defined(ARMNN_TF_LITE_PARSER)
31#include <armnnTfLiteParser/ITfLiteParser.hpp>
32#endif
33#if defined(ARMNN_ONNX_PARSER)
34#include <armnnOnnxParser/IOnnxParser.hpp>
35#endif
36
37class ArmNNExecutor : public IExecutor
38{
39public:
40 ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
41
42 std::vector<const void* > Execute() override;
43 void PrintNetworkInfo() override;
44 void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
45
46private:
47
Mike Kelly5446a4d2023-01-20 15:51:05 +000048 /**
49 * Returns a pointer to the armnn::IRuntime* this will be shared by all ArmNNExecutors.
50 */
51 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
52 {
53 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
54 // Instantiated on first use.
55 return instance.get();
56 }
57
Teresa Charlin83b42912022-07-07 14:24:59 +010058 struct IParser;
59 struct IOInfo;
60 struct IOStorage;
61
62 using BindingPointInfo = armnn::BindingPointInfo;
63
64 std::unique_ptr<IParser> CreateParser();
65
66 void ExecuteAsync();
67 void ExecuteSync();
68 void SetupInputsAndOutputs();
69
70 IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);
71
72 void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);
73
74 armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);
75
76 struct IOStorage
77 {
78 IOStorage(size_t size)
79 {
80 m_Mem = operator new(size);
81 }
82 ~IOStorage()
83 {
84 operator delete(m_Mem);
85 }
86 IOStorage(IOStorage&& rhs)
87 {
88 this->m_Mem = rhs.m_Mem;
89 rhs.m_Mem = nullptr;
90 }
91
92 IOStorage(const IOStorage& rhs) = delete;
93 IOStorage& operator=(IOStorage& rhs) = delete;
94 IOStorage& operator=(IOStorage&& rhs) = delete;
95
96 void* m_Mem;
97 };
98
99 struct IOInfo
100 {
101 std::vector<std::string> m_InputNames;
102 std::vector<std::string> m_OutputNames;
103 std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap;
104 std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap;
105 };
106
107 IOInfo m_IOInfo;
108 std::vector<IOStorage> m_InputStorage;
109 std::vector<IOStorage> m_OutputStorage;
110 std::vector<armnn::InputTensors> m_InputTensorsVec;
111 std::vector<armnn::OutputTensors> m_OutputTensorsVec;
112 std::vector<std::vector<unsigned int>> m_ImportedInputIds;
113 std::vector<std::vector<unsigned int>> m_ImportedOutputIds;
Mike Kelly5446a4d2023-01-20 15:51:05 +0000114 armnn::IRuntime* m_Runtime;
Teresa Charlin83b42912022-07-07 14:24:59 +0100115 armnn::NetworkId m_NetworkId;
116 ExecuteNetworkParams m_Params;
117
118 struct IParser
119 {
120 virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
121 virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
122 virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;
123
124 virtual ~IParser(){};
125 };
126
127#if defined(ARMNN_SERIALIZER)
128 class ArmNNDeserializer : public IParser
129 {
130 public:
131 ArmNNDeserializer();
132
133 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
134 armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
135 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;
136
137 private:
138 armnnDeserializer::IDeserializerPtr m_Parser;
139 };
140#endif
141
142#if defined(ARMNN_TF_LITE_PARSER)
143 class TfliteParser : public IParser
144 {
145 public:
146 TfliteParser(const ExecuteNetworkParams& params);
147
148 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
149 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
150 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
151
152 private:
153 armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
154 };
155#endif
156
157#if defined(ARMNN_ONNX_PARSER)
158 class OnnxParser : public IParser
159 {
160 public:
161 OnnxParser();
162
163 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
164 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
165 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
166
167 private:
168 armnnOnnxParser::IOnnxParserPtr m_Parser;
169 };
170#endif
171};