blob: f566131ce162230fd8841cf172ca9682706f6e76 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Teresa Charlince7f51f2024-03-05 15:33:10 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
#include <armnn/Descriptors.hpp>
#include "armnn/INetwork.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnn/Types.hpp"

#include <schema_generated.h>
#include <tensorflow/lite/version.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
18
// Feature-detection macro: defined when the TensorFlow Lite headers report a version
// strictly newer than 2.3, i.e. any 2.x with x >= 4, or any major version >= 3.
// Both TF_*_VERSION macros are integers, so ">= 3" / ">= 4" are the same tests as "> 2" / "> 3".
// If the version macros are missing, the preprocessor evaluates them as 0 and the macro stays undefined.
// NOTE(review): only ARMNN_POST_TFLITE_2_3 is defined here. ARMNN_POST_TFLITE_2_4, which guards
// ParseConv3D in TfLiteParserImpl, is not defined in this header; presumably it comes from
// the build system — confirm.
#if (TF_MAJOR_VERSION >= 3) || ((TF_MAJOR_VERSION == 2) && (TF_MINOR_VERSION >= 4))
#define ARMNN_POST_TFLITE_2_3
#endif
22
telsoa01c577f2c2018-08-31 09:22:23 +010023namespace armnnTfLiteParser
24{
25
Kevin May7d96b162021-02-03 17:38:41 +000026class TfLiteParserImpl
telsoa01c577f2c2018-08-31 09:22:23 +010027{
28public:
29 // Shorthands for TfLite types
30 using ModelPtr = std::unique_ptr<tflite::ModelT>;
Derek Lambertiff05cc52019-04-26 13:05:17 +010031 using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
telsoa01c577f2c2018-08-31 09:22:23 +010032 using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
33 using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
34 using TensorPtr = std::unique_ptr<tflite::TensorT>;
35 using TensorRawPtr = const tflite::TensorT *;
36 using TensorRawPtrVector = std::vector<TensorRawPtr>;
37 using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
38 using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
39 using BufferPtr = std::unique_ptr<tflite::BufferT>;
40 using BufferRawPtr = const tflite::BufferT *;
41
42public:
43 /// Create the network from a flatbuffers binary file on disk
Kevin May7d96b162021-02-03 17:38:41 +000044 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010045
46 /// Create the network from a flatbuffers binary
Kevin May7d96b162021-02-03 17:38:41 +000047 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
telsoa01c577f2c2018-08-31 09:22:23 +010048
49
50 /// Retrieve binding info (layer id and tensor info) for the network input identified by
51 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000052 BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
53 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010054
55 /// Retrieve binding info (layer id and tensor info) for the network output identified by
56 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000057 BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
58 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010059
60 /// Return the number of subgraphs in the parsed model
Kevin May7d96b162021-02-03 17:38:41 +000061 size_t GetSubgraphCount() const;
telsoa01c577f2c2018-08-31 09:22:23 +010062
63 /// Return the input tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000064 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010065
66 /// Return the output tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000067 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010068
Kevin May7d96b162021-02-03 17:38:41 +000069 TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
70 ~TfLiteParserImpl() = default;
telsoa01c577f2c2018-08-31 09:22:23 +010071
72public:
73 // testable helpers
Finn Williamsb49ed182021-06-29 15:50:08 +010074 armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent);
75
76 armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model);
77
Teresa Charlin3ab85482021-06-08 16:59:29 +010078 static ModelPtr LoadModelFromFile(const char* fileName);
79 static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len);
80 static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
81 static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
82 static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex);
83 static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex);
telsoa01c577f2c2018-08-31 09:22:23 +010084 static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
85 static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
86
87 static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
Teresa Charlin3ab85482021-06-08 16:59:29 +010088 static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
89 const armnn::TensorInfo& inputTensorInfo);
90 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
91 const std::vector<int32_t>& targetDimsIn);
telsoa01c577f2c2018-08-31 09:22:23 +010092
Matthew Sloyanac001ee2021-02-03 10:43:04 +000093 /// Retrieve version in X.Y.Z form
94 static const std::string GetVersion();
95
telsoa01c577f2c2018-08-31 09:22:23 +010096private:
Finn Williamsd4fa5452021-03-01 12:31:41 +000097
telsoa01c577f2c2018-08-31 09:22:23 +010098 // No copying allowed until it is wanted and properly implemented
Kevin May7d96b162021-02-03 17:38:41 +000099 TfLiteParserImpl(const TfLiteParserImpl &) = delete;
100 TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete;
telsoa01c577f2c2018-08-31 09:22:23 +0100101
102 /// Create the network from an already loaded flatbuffers model
103 armnn::INetworkPtr CreateNetworkFromModel();
104
105 // signature for the parser functions
Kevin May7d96b162021-02-03 17:38:41 +0000106 using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100107
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100108 void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100109 void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100110
Matthew Sloyaned7fce42021-04-15 20:46:24 +0100111 void ParseAbs(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsc42c3842019-01-22 14:18:11 +0000112 void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
Nina Drozd200e3802019-04-15 09:47:39 +0100113 void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyan28f177c2021-04-09 14:38:52 +0100114 void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction);
115 void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
116 void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100117 void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
Samuel Yapfd3ba5a2022-08-24 17:04:34 +0100118 void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesdb947e22019-02-08 18:52:21 -0200119 void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
Idriss Chaouch564c13d2023-09-01 17:58:38 +0100120 void ParseBroadcastTo(size_t subgraphIndex, size_t operatorIndex);
mathad01b392e982021-04-07 12:07:30 +0100121 void ParseCast(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin93f0ad02023-03-23 15:28:02 +0000122 void ParseCeil(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves2d0eb862021-07-11 14:10:15 -0300123 void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
Sadik Armagan479045b2018-10-01 11:51:37 +0100124 void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100125 void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyan4d217c02021-10-07 11:48:58 +0100126 // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed.
Cathal Corbett80b4ef02022-05-25 11:21:11 +0100127 #if defined(ARMNN_POST_TFLITE_2_4)
Matthew Sloyaneb5f8102021-10-05 17:31:42 +0100128 void ParseConv3D(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyan4d217c02021-10-07 11:48:58 +0100129 #endif
Sadik Armagan26868492021-01-22 14:25:31 +0000130 void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100131 void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsed66d142019-12-06 09:55:55 +0000132 void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
keidav011b3e2ea2019-02-21 10:07:37 +0000133 void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyan28f177c2021-04-09 14:38:52 +0100134 void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyaned7fce42021-04-15 20:46:24 +0100135 void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
Matthew Sloyan7515d072020-12-16 12:50:01 +0000136 void ParseElu(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves2d0eb862021-07-11 14:10:15 -0300137 void ParseEqual(size_t subgraphIndex, size_t operatorIndex);
Derek Lambertif0176992020-04-28 13:37:49 +0100138 void ParseExp(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin3ab85482021-06-08 16:59:29 +0100139 void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlincdbd40b2022-02-25 13:21:55 +0000140 void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100141 void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan26868492021-01-22 14:25:31 +0000142 void ParseGather(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin91a53ea2022-04-25 15:47:29 +0100143 void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin077cddb2023-09-15 15:19:21 +0100144 void ParseGelu(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves2d0eb862021-07-11 14:10:15 -0300145 void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
146 void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
Jan Eilers2f746b32020-07-28 14:00:06 +0100147 void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan12239e72020-05-27 11:06:17 +0100148 void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves2d0eb862021-07-11 14:10:15 -0300149 void ParseLess(size_t subgraphIndex, size_t operatorIndex);
150 void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin28aa6692022-07-12 11:18:44 +0100151 void ParseLog(size_t subgraphIndex, size_t operatorIndex);
Mike Kelly31dce2b2021-09-01 21:22:37 +0100152 void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyaned7fce42021-04-15 20:46:24 +0100153 void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsc42c3842019-01-22 14:18:11 +0000154 void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlinfd33a692022-06-29 15:35:57 +0100155 void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex);
Matthew Jackson28c94572019-07-18 10:47:03 +0100156 void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +0100157 void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesb8d805e2019-02-12 22:57:13 -0200158 void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100159 void ParseMean(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves8f6d7a72019-02-12 22:58:18 -0200160 void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyanaf3a4ef2021-10-22 15:48:12 +0100161 void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100162 void ParseMul(size_t subgraphIndex, size_t operatorIndex);
Darshan Patel83fcf982020-05-26 22:22:42 +0530163 void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves2d0eb862021-07-11 14:10:15 -0300164 void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex);
Matthew Jacksonbcca1f42019-07-16 11:39:21 +0100165 void ParsePack(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100166 void ParsePad(size_t subgraphIndex, size_t operatorIndex);
167 void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
John Mcloughlin0ec00872023-05-15 17:03:49 +0100168 void ParsePower(size_t subgraphIndex, size_t operatorIndex);
Narumol Prangnawaratbfaee6b2021-05-24 18:50:24 +0100169 void ParsePrelu(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan66dedc72019-12-10 16:32:07 +0000170 void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagana2747482021-02-09 10:28:54 +0000171 void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation);
172 void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex);
173 void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin4e3e8312021-08-05 12:34:37 +0100174 void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan58f39192018-09-17 14:14:39 +0100175 void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
176 void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
Sadikb94967b2018-09-19 15:30:00 +0100177 void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagana3b31f02019-12-05 09:08:53 +0000178 void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
Bruno Goncalves3f58ddb2019-02-07 18:40:11 -0200179 void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagana3b31f02019-12-05 09:08:53 +0000180 void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
Tianle Chenge5a30ff2023-07-03 11:24:12 +0100181 void ParseReverseV2(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyaned7fce42021-04-15 20:46:24 +0100182 void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlince7f51f2024-03-05 15:33:10 +0000183 void ParseScatterNd(size_t subgraphIndex, size_t operatorIndex);
Keith Davis0176fd82021-06-01 17:36:32 +0100184 void ParseShape(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin28aa6692022-07-12 11:18:44 +0100185 void ParseSin(size_t subgraphIndex, size_t operatorIndex);
josh minorba424d22019-11-13 10:55:17 -0600186 void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100187 void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlinf0fce5b2022-05-04 17:24:43 +0100188 void ParseSqrt(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesbaded142019-02-08 19:02:48 -0200189 void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin2a764ad2023-02-24 18:17:31 +0000190 void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100191 void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
Derek Lambertif0176992020-04-28 13:37:49 +0100192 void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100193 void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin6963b332023-07-11 11:35:41 +0100194 void ParseSquare(size_t subgraphIndex, size_t operatorIndex);
John Mcloughlin0ec00872023-05-15 17:03:49 +0100195 void ParseSquaredDifference(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200196 void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesbbeae262019-02-07 18:37:39 -0200197 void ParseSub(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000198 void ParseSum(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd99851762019-04-09 09:37:38 +0100199 void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
Teresa Charlin777008b2023-07-26 10:07:55 +0100200 void ParseTile(size_t subgraphIndex, size_t operatorIndex);
Keith Davis4cd29a02019-09-09 14:49:20 +0100201 void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
Matthew Jackson74bf7da2019-08-16 16:51:42 +0100202 void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
Mike Kelly5880b912022-01-28 16:18:54 +0000203 void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100204 void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +0100205
telsoa01c577f2c2018-08-31 09:22:23 +0100206 void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
207 void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
208 void RegisterInputSlots(size_t subgraphIndex,
209 size_t operatorIndex,
210 armnn::IConnectableLayer* layer,
Finn Williamsd4fa5452021-03-01 12:31:41 +0000211 const std::vector<unsigned int>& tensorIndexes,
212 unsigned int startingSlotIndex = 0);
telsoa01c577f2c2018-08-31 09:22:23 +0100213 void RegisterOutputSlots(size_t subgraphIndex,
214 size_t operatorIndex,
215 armnn::IConnectableLayer* layer,
216 const std::vector<unsigned int>& tensorIndexes);
217
Mike Kelly377fb212023-01-10 15:55:28 +0000218 void SetupInputLayerTensorInfos(size_t subgraphIndex);
219 void SetupConstantLayerTensorInfos(size_t subgraphIndex);
220
telsoa01c577f2c2018-08-31 09:22:23 +0100221 void SetupInputLayers(size_t subgraphIndex);
222 void SetupOutputLayers(size_t subgraphIndex);
Bruno Goncalves3d7efe92018-12-27 14:21:43 -0200223 void SetupConstantLayers(size_t subgraphIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100224
225 void ResetParser();
226
Bruno Goncalves9c761a62018-12-27 14:20:35 -0200227 void AddBroadcastReshapeLayer(size_t subgraphIndex,
228 size_t operatorIndex,
229 armnn::IConnectableLayer* layer);
230
Mike Kelly04d82292023-01-19 18:29:40 +0000231 /// Attach an reshape layer to the one passed as a parameter
232 armnn::IConnectableLayer* AddReshapeLayer(armnn::IConnectableLayer* layer,
233 unsigned int outputSlot,
234 std::string reshapeLayerName,
235 armnn::TensorInfo outputShape);
236
telsoa01c577f2c2018-08-31 09:22:23 +0100237 /// Attach an activation layer to the one passed as a parameter
Sadik Armagan58f39192018-09-17 14:14:39 +0100238 armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
239 unsigned int outputSlot,
240 tflite::ActivationFunctionType activationType);
telsoa01c577f2c2018-08-31 09:22:23 +0100241
Teresa Charlincdbd40b2022-02-25 13:21:55 +0000242 /// Attach a floor layer to the one passed as a parameter
243 armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot);
244
telsoa01c577f2c2018-08-31 09:22:23 +0100245 // SupportedDataStorage's purpose is to hold data till we pass over to the network.
246 // We don't care about the content, and we want a single datatype to simplify the code.
247 struct SupportedDataStorage
248 {
Matteo Martincigh747ef822018-12-18 09:26:39 +0000249 public:
250 // Convenience constructors
251 SupportedDataStorage(std::unique_ptr<float[]>&& data);
252 SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
Keith Davisd305e1a2020-01-22 11:57:54 +0000253 SupportedDataStorage(std::unique_ptr<int8_t[]>&& data);
Matteo Martincigh747ef822018-12-18 09:26:39 +0000254 SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
telsoa01c577f2c2018-08-31 09:22:23 +0100255
Matteo Martincigh747ef822018-12-18 09:26:39 +0000256 private:
257 // Pointers to the data buffers
258 std::unique_ptr<float[]> m_FloatData;
259 std::unique_ptr<uint8_t[]> m_Uint8Data;
Keith Davisd305e1a2020-01-22 11:57:54 +0000260 std::unique_ptr<int8_t[]> m_Int8Data;
Matteo Martincigh747ef822018-12-18 09:26:39 +0000261 std::unique_ptr<int32_t[]> m_Int32Data;
telsoa01c577f2c2018-08-31 09:22:23 +0100262 };
263
Mike Kelly5880b912022-01-28 16:18:54 +0000264 bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);
Mike Kelly0506ef02023-01-03 16:29:44 +0000265
Finn Williamsd4fa5452021-03-01 12:31:41 +0000266 bool IsConstTensor(TensorRawPtr tensorPtr);
Mike Kelly0506ef02023-01-03 16:29:44 +0000267
268 bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
269 armnn::DataType inputDataType,
270 armnn::DataType filterDataType);
271
Finn Williamsd4fa5452021-03-01 12:31:41 +0000272 armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
273 armnn::TensorInfo& tensorInfo);
Mike Kelly5880b912022-01-28 16:18:54 +0000274
Finn Williamsd4fa5452021-03-01 12:31:41 +0000275 std::pair<armnn::ConstTensor, SupportedDataStorage>
276 CreateConstTensorPermuted(TensorRawPtr tensorPtr,
277 armnn::TensorInfo& tensorInfo,
278 armnn::Optional<armnn::PermutationVector&> permutationVector);
Mike Kelly0506ef02023-01-03 16:29:44 +0000279
Mike Kelly5880b912022-01-28 16:18:54 +0000280 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
281 CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
282 armnn::TensorInfo& tensorInfo,
283 armnn::DataType inputDataType);
Matteo Martincigh747ef822018-12-18 09:26:39 +0000284
285 template<typename T>
Kevin May7d96b162021-02-03 17:38:41 +0000286 std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
287 CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
288 TfLiteParserImpl::TensorRawPtr tensorPtr,
Matteo Martincigh747ef822018-12-18 09:26:39 +0000289 armnn::TensorInfo& tensorInfo,
290 armnn::Optional<armnn::PermutationVector&> permutationVector);
Mike Kelly0506ef02023-01-03 16:29:44 +0000291
Mike Kelly5880b912022-01-28 16:18:54 +0000292 std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
293 CreateConstTensorPtr(TensorRawPtr tensorPtr,
294 armnn::TensorInfo& inputTensorInfo);
Matteo Martincigh747ef822018-12-18 09:26:39 +0000295
Mike Kelly377fb212023-01-10 15:55:28 +0000296 armnn::TensorInfo InputTensorInfo(size_t subgraphIndex,
297 size_t operatorIndex,
298 int input);
299
300 armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex,
301 size_t operatorIndex,
302 armnn::IConnectableLayer* layer,
303 int output,
304 std::vector<int> inputs);
305
306 armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex,
307 size_t operatorIndex,
308 armnn::IConnectableLayer* layer,
309 int output = 0,
310 std::vector<armnn::TensorShape> inputShapes = {});
311
312 /// Settings for configuring the TfLiteParser
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100313 armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
314
telsoa01c577f2c2018-08-31 09:22:23 +0100315 /// The network we're building. Gets cleared after it is passed to the user
316 armnn::INetworkPtr m_Network;
telsoa01c577f2c2018-08-31 09:22:23 +0100317 ModelPtr m_Model;
318
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100319 std::vector<OperatorParsingFunction> m_ParserFunctions;
320 std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
321
telsoa01c577f2c2018-08-31 09:22:23 +0100322 /// A mapping of an output slot to each of the input slots it should be connected to
323 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
324 /// The inputSlots are from the layers that use this tensor as one of their inputs
325 struct TensorSlots
326 {
327 armnn::IOutputSlot* outputSlot;
328 std::vector<armnn::IInputSlot*> inputSlots;
329
330 TensorSlots() : outputSlot(nullptr) { }
331 };
332 typedef std::vector<TensorSlots> TensorConnections;
333 /// Connections for tensors in each subgraph
334 /// The first index is the subgraph ID, the second index is the tensor ID
335 std::vector<TensorConnections> m_SubgraphConnections;
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000336
Mike Kelly5880b912022-01-28 16:18:54 +0000337 /// This is used in case that the model does not specify the output.
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000338 /// The shape can be calculated from the options.
Mike Kelly377fb212023-01-10 15:55:28 +0000339 std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes;
Mike Kelly5880b912022-01-28 16:18:54 +0000340
341 std::vector<unsigned int> m_ConstantsToDequantize;
342 std::vector<unsigned int> m_ConstantsToBeCreated;
Mike Kelly377fb212023-01-10 15:55:28 +0000343 std::map<size_t, armnn::TensorInfo> m_TensorInfos;
telsoa01c577f2c2018-08-31 09:22:23 +0100344};
345
346}