blob: c4adc9e120354cafed583d7437eb706eda4f2d30 [file] [log] [blame]
Teresa Charlin83b42912022-07-07 14:24:59 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "IExecutor.hpp"
9#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
10#include "ExecuteNetworkProgramOptions.hpp"
11#include "armnn/utility/NumericCast.hpp"
12#include "armnn/utility/Timer.hpp"
13
14#include <armnn/ArmNN.hpp>
15#include <armnn/Threadpool.hpp>
16#include <armnn/Logging.hpp>
17#include <armnn/utility/Timer.hpp>
18#include <armnn/BackendRegistry.hpp>
19#include <armnn/utility/Assert.hpp>
20#include <armnn/utility/NumericCast.hpp>
21
22#include <armnnUtils/Filesystem.hpp>
23#include <HeapProfiling.hpp>
24
25#include <fmt/format.h>
26
27#if defined(ARMNN_SERIALIZER)
28#include "armnnDeserializer/IDeserializer.hpp"
29#endif
30#if defined(ARMNN_TF_LITE_PARSER)
31#include <armnnTfLiteParser/ITfLiteParser.hpp>
32#endif
33#if defined(ARMNN_ONNX_PARSER)
34#include <armnnOnnxParser/IOnnxParser.hpp>
35#endif
36
37class ArmNNExecutor : public IExecutor
38{
39public:
40 ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
41
42 std::vector<const void* > Execute() override;
43 void PrintNetworkInfo() override;
44 void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
45
46private:
47
48 struct IParser;
49 struct IOInfo;
50 struct IOStorage;
51
52 using BindingPointInfo = armnn::BindingPointInfo;
53
54 std::unique_ptr<IParser> CreateParser();
55
56 void ExecuteAsync();
57 void ExecuteSync();
58 void SetupInputsAndOutputs();
59
60 IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);
61
62 void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);
63
64 armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);
65
66 struct IOStorage
67 {
68 IOStorage(size_t size)
69 {
70 m_Mem = operator new(size);
71 }
72 ~IOStorage()
73 {
74 operator delete(m_Mem);
75 }
76 IOStorage(IOStorage&& rhs)
77 {
78 this->m_Mem = rhs.m_Mem;
79 rhs.m_Mem = nullptr;
80 }
81
82 IOStorage(const IOStorage& rhs) = delete;
83 IOStorage& operator=(IOStorage& rhs) = delete;
84 IOStorage& operator=(IOStorage&& rhs) = delete;
85
86 void* m_Mem;
87 };
88
89 struct IOInfo
90 {
91 std::vector<std::string> m_InputNames;
92 std::vector<std::string> m_OutputNames;
93 std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap;
94 std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap;
95 };
96
97 IOInfo m_IOInfo;
98 std::vector<IOStorage> m_InputStorage;
99 std::vector<IOStorage> m_OutputStorage;
100 std::vector<armnn::InputTensors> m_InputTensorsVec;
101 std::vector<armnn::OutputTensors> m_OutputTensorsVec;
102 std::vector<std::vector<unsigned int>> m_ImportedInputIds;
103 std::vector<std::vector<unsigned int>> m_ImportedOutputIds;
104 std::shared_ptr<armnn::IRuntime> m_Runtime;
105 armnn::NetworkId m_NetworkId;
106 ExecuteNetworkParams m_Params;
107
108 struct IParser
109 {
110 virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
111 virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
112 virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;
113
114 virtual ~IParser(){};
115 };
116
117#if defined(ARMNN_SERIALIZER)
118 class ArmNNDeserializer : public IParser
119 {
120 public:
121 ArmNNDeserializer();
122
123 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
124 armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
125 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;
126
127 private:
128 armnnDeserializer::IDeserializerPtr m_Parser;
129 };
130#endif
131
132#if defined(ARMNN_TF_LITE_PARSER)
133 class TfliteParser : public IParser
134 {
135 public:
136 TfliteParser(const ExecuteNetworkParams& params);
137
138 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
139 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
140 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
141
142 private:
143 armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
144 };
145#endif
146
147#if defined(ARMNN_ONNX_PARSER)
148 class OnnxParser : public IParser
149 {
150 public:
151 OnnxParser();
152
153 armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
154 armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
155 armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
156
157 private:
158 armnnOnnxParser::IOnnxParserPtr m_Parser;
159 };
160#endif
161};