blob: e07075fbb593e2524cef8b61ba40f3593207bc0b [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include <armnn/DescriptorsFwd.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01008#include <armnn/LstmParams.hpp>
James Conroyee18dc82019-07-17 11:27:46 +01009#include <armnn/QuantizedLstmParams.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010#include <armnn/TensorFwd.hpp>
11#include <armnn/Types.hpp>
12
13#include <armnn/INetwork.hpp>
14
15#include <string>
16#include <vector>
Derek Lamberti84da38b2019-06-13 11:40:08 +010017#include <map>
telsoa014fcda012018-03-09 14:13:49 +000018#include <memory>
19
Derek Lamberti4a9e24b2020-01-03 16:53:38 +000020#include "Graph.hpp"
telsoa014fcda012018-03-09 14:13:49 +000021#include "Layer.hpp"
Francis Murtagh3d2b4b22021-02-15 18:23:17 +000022#include "OptimizedNetworkImpl.hpp"
telsoa014fcda012018-03-09 14:13:49 +000023
24namespace armnn
25{
26class Graph;
27
Francis Murtagh3d2b4b22021-02-15 18:23:17 +000028using NetworkImplPtr = std::unique_ptr<NetworkImpl, void(*)(NetworkImpl* network)>;
29
/// Private implementation of INetwork.
///
/// Owns a Graph (m_Graph) together with the NetworkOptions / ModelOptions it was created with.
/// The Add*Layer() methods presumably insert a new layer into m_Graph and return it as an
/// IConnectableLayer*; their definitions are not in this header, so confirm ownership and
/// error behaviour there. In every Add*Layer() overload, `name` is an optional C string that
/// defaults to nullptr.
class NetworkImpl
{
public:
    /// @param networkOptions Options for this network; defaults to a value-initialized NetworkOptions.
    NetworkImpl(NetworkOptions networkOptions = {});
    ~NetworkImpl();

    /// Returns the owned graph. Dereferences m_Graph without a null check.
    const Graph& GetGraph() const { return *m_Graph; }

    Status PrintGraph();

    // --- Graph inputs -------------------------------------------------------------------------

    IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr);

    // --- Descriptor-driven layers -------------------------------------------------------------

    IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
                                         const char* name = nullptr);

    IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
                                              const char* name = nullptr);

    IConnectableLayer* AddCastLayer(const char* name = nullptr);

    IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
                                          const char* name = nullptr);

    IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
                                      const char* name = nullptr);

    // Convolution2d: the first overload (Optional biases) is the current one; the two that follow
    // are deprecated variants that take no biases or non-optional biases.
    IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
                                             const ConstTensor& weights,
                                             const Optional<ConstTensor>& biases,
                                             const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
    IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
                                             const ConstTensor& weights,
                                             const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
    IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
                                             const ConstTensor& weights,
                                             const ConstTensor& biases,
                                             const char* name = nullptr);

    IConnectableLayer* AddDepthToSpaceLayer(const DepthToSpaceDescriptor& depthToSpaceDescriptor,
                                            const char* name = nullptr);

    // DepthwiseConvolution2d: same overload pattern as Convolution2d above; only the
    // Optional<ConstTensor> biases overload is not deprecated.
    IConnectableLayer* AddDepthwiseConvolution2dLayer(
        const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
        const ConstTensor& weights,
        const Optional<ConstTensor>& biases,
        const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
    IConnectableLayer* AddDepthwiseConvolution2dLayer(
        const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
        const ConstTensor& weights,
        const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
    IConnectableLayer* AddDepthwiseConvolution2dLayer(
        const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
        const ConstTensor& weights,
        const ConstTensor& biases,
        const char* name = nullptr);

    IConnectableLayer* AddDequantizeLayer(const char* name = nullptr);

    IConnectableLayer* AddDetectionPostProcessLayer(
        const DetectionPostProcessDescriptor& descriptor,
        const ConstTensor& anchors,
        const char* name = nullptr);

    IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
                                                const char* name = nullptr);

    IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor,
                                    const char* name = nullptr);

    // FullyConnected: two current overloads (Optional weights, or ConstTensor weights; both take
    // Optional biases) followed by two deprecated variants.
    IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                              const Optional<ConstTensor>& weights,
                                              const Optional<ConstTensor>& biases,
                                              const char* name = nullptr);

    IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                              const ConstTensor& weights,
                                              const Optional<ConstTensor>& biases,
                                              const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
    IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                              const ConstTensor& weights,
                                              const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
    IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                              const ConstTensor& weights,
                                              const ConstTensor& biases,
                                              const char* name = nullptr);

    // Gather: the descriptor-less overload is deprecated in favour of the GatherDescriptor one.
    ARMNN_DEPRECATED_MSG("This AddGatherLayer overload is deprecated")
    IConnectableLayer* AddGatherLayer(const char* name = nullptr);

    IConnectableLayer* AddGatherLayer(const GatherDescriptor& gatherDescriptor,
                                      const char* name = nullptr);

    IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
                                       const char* name = nullptr);

    IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
                                         const char* name = nullptr);

    IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
                                          const char* name = nullptr);

    IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
                                             const char* name = nullptr);

    IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr);

    IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
                                       const char* name = nullptr);

    IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
                                        const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
    IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
                                      const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
    IConnectableLayer* AddAbsLayer(const char* name = nullptr);

    IConnectableLayer* AddAdditionLayer(const char* name = nullptr);

    IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr);

    IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
                                                  const ConstTensor& mean,
                                                  const ConstTensor& variance,
                                                  const ConstTensor& beta,
                                                  const ConstTensor& gamma,
                                                  const char* name = nullptr);

    IConnectableLayer* AddRankLayer(const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
    IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
                                              const char* name = nullptr);

    IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
                                      const char* name = nullptr);

    IConnectableLayer* AddReduceLayer(const ReduceDescriptor& reduceDescriptor,
                                      const char* name = nullptr);

    IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc,
                                                     const char* name = nullptr);

    IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
                                               const char* name = nullptr);

    IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor,
                                          const char* name = nullptr);

    IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr);

    IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
                                       const char* name = nullptr);

    IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
                                              const char* name = nullptr);

    IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
                                            const char* name = nullptr);

    IConnectableLayer* AddFloorLayer(const char* name = nullptr);

    // --- Graph outputs ------------------------------------------------------------------------

    IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr);

    // --- Recurrent, arithmetic and miscellaneous layers ---------------------------------------

    IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
                                    const LstmInputParams& params,
                                    const char* name = nullptr);

    IConnectableLayer* AddDivisionLayer(const char* name = nullptr);

    IConnectableLayer* AddSubtractionLayer(const char* name = nullptr);

    IConnectableLayer* AddMaximumLayer(const char* name = nullptr);

    IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr);

    IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr);

    IConnectableLayer* AddQuantizeLayer(const char* name = nullptr);

    IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
                                            const char* name = nullptr);

    IConnectableLayer* AddMinimumLayer(const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
    IConnectableLayer* AddGreaterLayer(const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
    IConnectableLayer* AddEqualLayer(const char* name = nullptr);

    ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
    IConnectableLayer* AddRsqrtLayer(const char* name = nullptr);

    IConnectableLayer* AddMergeLayer(const char* name = nullptr);

    IConnectableLayer* AddSwitchLayer(const char* name = nullptr);

    IConnectableLayer* AddPreluLayer(const char* name = nullptr);

    IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
                                                      const ConstTensor& weights,
                                                      const Optional<ConstTensor>& biases,
                                                      const char* name = nullptr);

    IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor,
                                         const char* name = nullptr);

    IConnectableLayer* AddShapeLayer(const char* name = nullptr);

    IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
                                     const char* name = nullptr);

    IConnectableLayer* AddStandInLayer(const StandInDescriptor& descriptor,
                                       const char* name = nullptr);

    IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor,
                                     const LstmInputParams& params,
                                     const char* name = nullptr);

    IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
                                             const char* name = nullptr);

    IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor,
                                             const char* name = nullptr);

    // --- Traversal ----------------------------------------------------------------------------
    // Both are const; presumably they walk the layers of m_Graph with the given
    // visitor/strategy (definitions not visible here - confirm).

    void Accept(ILayerVisitor& visitor) const;

    void ExecuteStrategy(IStrategy& strategy) const;

private:
    // Shared implementations, presumably called by the public overload sets of the same name
    // (definitions not visible here). Unlike the public overloads, `name` has no default.
    IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
                                                  const Optional<ConstTensor>& weights,
                                                  const Optional<ConstTensor>& biases,
                                                  const char* name);

    IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
                                                 const ConstTensor& weights,
                                                 const Optional<ConstTensor>& biases,
                                                 const char* name);

    IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
        const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
        const ConstTensor& weights,
        const Optional<ConstTensor>& biases,
        const char* name);

    // NOTE(review): presumably derives a shape-inference setting from m_NetworkOptions - confirm
    // against the definition.
    bool GetShapeInferenceMethod();

    // NOTE: data members are initialized in declaration order. The constructor's definition is
    // not visible here, so check it before reordering m_NetworkOptions / m_Graph / m_ModelOptions.
    NetworkOptions m_NetworkOptions;

    std::unique_ptr<Graph> m_Graph;
    ModelOptions m_ModelOptions;
};
298
/// Status returned by SelectTensorHandleStrategy() and AssignBackends(), stored as two
/// independent flags.
///
/// A result may carry a warning, an error, both, or neither. IsError() reports the error flag
/// regardless of the warning flag.
struct OptimizationResult
{
    bool m_Warning = false; ///< Set when a warning was recorded.
    bool m_Error   = false; ///< Set when an error was recorded.

    /// Creates a result with neither flag set.
    OptimizationResult() = default;

    /// Creates a result with the given warning and error flags.
    OptimizationResult(bool warning, bool error)
        : m_Warning{warning}
        , m_Error{error}
    {}

    /// True only when neither a warning nor an error was recorded.
    bool IsOk() const { return !(m_Warning || m_Error); }

    /// True when a warning was recorded and no error was.
    bool IsWarningOnly() const { return !m_Error && m_Warning; }

    /// True whenever an error was recorded.
    bool IsError() const { return m_Error; }
};
318
319using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
320
321BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
322 struct BackendSettings& backendSettings);
323
324OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
325 BackendsMap& backends,
326 TensorHandleFactoryRegistry& registry,
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100327 bool importEnabled,
Derek Lamberti84da38b2019-06-13 11:40:08 +0100328 Optional<std::vector<std::string>&> errMessages);
329
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000330OptimizationResult AssignBackends(OptimizedNetworkImpl* optNetObjPtr,
Derek Lamberti4a9e24b2020-01-03 16:53:38 +0000331 BackendSettings& backendSettings,
332 Graph::Iterator& firstLayer,
333 Graph::Iterator& lastLayer,
334 Optional<std::vector<std::string>&> errMessages);
335
telsoa014fcda012018-03-09 14:13:49 +0000336} // namespace armnn