blob: cbc8607137404042cb0c6f1648b173bffb640531 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "IExecutor.hpp"
9#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
10#include "ExecuteNetworkProgramOptions.hpp"
11#include "armnn/utility/NumericCast.hpp"
12#include "armnn/utility/Timer.hpp"
13
14#include <armnn/ArmNN.hpp>
15#include <armnn/Threadpool.hpp>
16#include <armnn/Logging.hpp>
17#include <armnn/utility/Timer.hpp>
18#include <armnn/BackendRegistry.hpp>
19#include <armnn/utility/Assert.hpp>
20#include <armnn/utility/NumericCast.hpp>
21
22#include <armnnUtils/Filesystem.hpp>
23#include <HeapProfiling.hpp>
24
25#include <fmt/format.h>
26
27#if defined(ARMNN_SERIALIZER)
28#include "armnnDeserializer/IDeserializer.hpp"
29#endif
30#if defined(ARMNN_TF_LITE_PARSER)
31#include <armnnTfLiteParser/ITfLiteParser.hpp>
32#endif
33#if defined(ARMNN_ONNX_PARSER)
34#include <armnnOnnxParser/IOnnxParser.hpp>
35#endif
36
37class ArmNNExecutor : public IExecutor
38{
39public:
40 ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
Colm Donelanf760c932024-03-25 17:54:04 +000041 ~ArmNNExecutor();
42 ArmNNExecutor(const ArmNNExecutor&) = delete; // No copy constructor.
43 ArmNNExecutor & operator=(const ArmNNExecutor&) = delete; // No Copy operator.
Teresa Charlin83b42912022-07-07 14:24:59 +010044
45 std::vector<const void* > Execute() override;
46 void PrintNetworkInfo() override;
47 void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
48
49private:
Colm Donelanf760c932024-03-25 17:54:04 +000050 ArmNNExecutor(ArmNNExecutor&&); // No move constructor.
51 ArmNNExecutor& operator=(ArmNNExecutor&&); // No move operator.
52
Mike Kelly5446a4d2023-01-20 15:51:05 +000053 /**
54 * Returns a pointer to the armnn::IRuntime* this will be shared by all ArmNNExecutors.
55 */
56 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
57 {
58 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
59 // Instantiated on first use.
60 return instance.get();
61 }
62
Teresa Charlin83b42912022-07-07 14:24:59 +010063 struct IParser;
64 struct IOInfo;
65 struct IOStorage;
66
67 using BindingPointInfo = armnn::BindingPointInfo;
68
69 std::unique_ptr<IParser> CreateParser();
70
71 void ExecuteAsync();
72 void ExecuteSync();
73 void SetupInputsAndOutputs();
74
75 IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);
76
77 void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);
78
79 armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);
80
81 struct IOStorage
82 {
83 IOStorage(size_t size)
84 {
85 m_Mem = operator new(size);
86 }
87 ~IOStorage()
88 {
89 operator delete(m_Mem);
90 }
91 IOStorage(IOStorage&& rhs)
92 {
93 this->m_Mem = rhs.m_Mem;
94 rhs.m_Mem = nullptr;
95 }
96
97 IOStorage(const IOStorage& rhs) = delete;
98 IOStorage& operator=(IOStorage& rhs) = delete;
99 IOStorage& operator=(IOStorage&& rhs) = delete;
100
101 void* m_Mem;
102 };
103
104 struct IOInfo
105 {
106 std::vector<std::string> m_InputNames;
107 std::vector<std::string> m_OutputNames;
108 std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap;
109 std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap;
110 };
111
112 IOInfo m_IOInfo;
113 std::vector<IOStorage> m_InputStorage;
114 std::vector<IOStorage> m_OutputStorage;
115 std::vector<armnn::InputTensors> m_InputTensorsVec;
116 std::vector<armnn::OutputTensors> m_OutputTensorsVec;
117 std::vector<std::vector<unsigned int>> m_ImportedInputIds;
118 std::vector<std::vector<unsigned int>> m_ImportedOutputIds;
Mike Kelly5446a4d2023-01-20 15:51:05 +0000119 armnn::IRuntime* m_Runtime;
Teresa Charlin83b42912022-07-07 14:24:59 +0100120 armnn::NetworkId m_NetworkId;
121 ExecuteNetworkParams m_Params;
122
123 struct IParser
124 {
125 virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
126 virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
127 virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;
128
129 virtual ~IParser(){};
130 };
131
132#if defined(ARMNN_SERIALIZER)
133 class ArmNNDeserializer : public IParser
134 {
135 public:
136 ArmNNDeserializer();
137
138 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
139 armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
140 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;
141
142 private:
143 armnnDeserializer::IDeserializerPtr m_Parser;
144 };
145#endif
146
147#if defined(ARMNN_TF_LITE_PARSER)
148 class TfliteParser : public IParser
149 {
150 public:
151 TfliteParser(const ExecuteNetworkParams& params);
152
153 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
154 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
155 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
156
157 private:
158 armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
159 };
160#endif
161
162#if defined(ARMNN_ONNX_PARSER)
163 class OnnxParser : public IParser
164 {
165 public:
166 OnnxParser();
167
168 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
169 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
170 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
171
172 private:
173 armnnOnnxParser::IOnnxParserPtr m_Parser;
174 };
175#endif
176};