telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 4 | // |
| 5 | #pragma once |
| 6 | |
| 7 | #include "armnn/INetwork.hpp" |
| 8 | #include "armnnTfLiteParser/ITfLiteParser.hpp" |
Nattapat Chaimanowong | b66504b | 2018-10-17 15:19:14 +0100 | [diff] [blame] | 9 | #include "armnn/Types.hpp" |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 10 | |
| 11 | #include <schema_generated.h> |
| 12 | #include <functional> |
Aron Virginas-Tar | c975f92 | 2019-10-23 17:38:17 +0100 | [diff] [blame] | 13 | #include <unordered_map> |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 14 | #include <vector> |
| 15 | |
| 16 | namespace armnnTfLiteParser |
| 17 | { |
| 18 | |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 19 | class TfLiteParserImpl |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 20 | { |
| 21 | public: |
| 22 | // Shorthands for TfLite types |
| 23 | using ModelPtr = std::unique_ptr<tflite::ModelT>; |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 24 | using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 25 | using OperatorPtr = std::unique_ptr<tflite::OperatorT>; |
| 26 | using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>; |
| 27 | using TensorPtr = std::unique_ptr<tflite::TensorT>; |
| 28 | using TensorRawPtr = const tflite::TensorT *; |
| 29 | using TensorRawPtrVector = std::vector<TensorRawPtr>; |
| 30 | using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>; |
| 31 | using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>; |
| 32 | using BufferPtr = std::unique_ptr<tflite::BufferT>; |
| 33 | using BufferRawPtr = const tflite::BufferT *; |
| 34 | |
| 35 | public: |
| 36 | /// Create the network from a flatbuffers binary file on disk |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 37 | armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 38 | |
| 39 | /// Create the network from a flatbuffers binary |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 40 | armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 41 | |
| 42 | |
| 43 | /// Retrieve binding info (layer id and tensor info) for the network input identified by |
| 44 | /// the given layer name and subgraph id |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 45 | BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, |
| 46 | const std::string& name) const; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 47 | |
| 48 | /// Retrieve binding info (layer id and tensor info) for the network output identified by |
| 49 | /// the given layer name and subgraph id |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 50 | BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, |
| 51 | const std::string& name) const; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 52 | |
| 53 | /// Return the number of subgraphs in the parsed model |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 54 | size_t GetSubgraphCount() const; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 55 | |
| 56 | /// Return the input tensor names for a given subgraph |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 57 | std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 58 | |
| 59 | /// Return the output tensor names for a given subgraph |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 60 | std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 61 | |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 62 | TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional()); |
| 63 | ~TfLiteParserImpl() = default; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 64 | |
| 65 | public: |
| 66 | // testable helpers |
Finn Williams | b49ed18 | 2021-06-29 15:50:08 +0100 | [diff] [blame] | 67 | armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent); |
| 68 | |
| 69 | armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model); |
| 70 | |
Teresa Charlin | 3ab8548 | 2021-06-08 16:59:29 +0100 | [diff] [blame] | 71 | static ModelPtr LoadModelFromFile(const char* fileName); |
| 72 | static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len); |
| 73 | static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); |
| 74 | static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); |
| 75 | static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex); |
| 76 | static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 77 | static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); |
| 78 | static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); |
| 79 | |
| 80 | static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex); |
Teresa Charlin | 3ab8548 | 2021-06-08 16:59:29 +0100 | [diff] [blame] | 81 | static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims, |
| 82 | const armnn::TensorInfo& inputTensorInfo); |
| 83 | static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, |
| 84 | const std::vector<int32_t>& targetDimsIn); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 85 | |
Matthew Sloyan | ac001ee | 2021-02-03 10:43:04 +0000 | [diff] [blame] | 86 | /// Retrieve version in X.Y.Z form |
| 87 | static const std::string GetVersion(); |
| 88 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 89 | private: |
Finn Williams | d4fa545 | 2021-03-01 12:31:41 +0000 | [diff] [blame] | 90 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 91 | // No copying allowed until it is wanted and properly implemented |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 92 | TfLiteParserImpl(const TfLiteParserImpl &) = delete; |
| 93 | TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 94 | |
| 95 | /// Create the network from an already loaded flatbuffers model |
| 96 | armnn::INetworkPtr CreateNetworkFromModel(); |
| 97 | |
| 98 | // signature for the parser functions |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 99 | using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 100 | |
Aron Virginas-Tar | c975f92 | 2019-10-23 17:38:17 +0100 | [diff] [blame] | 101 | void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 102 | void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex); |
Aron Virginas-Tar | c975f92 | 2019-10-23 17:38:17 +0100 | [diff] [blame] | 103 | |
Matthew Sloyan | ed7fce4 | 2021-04-15 20:46:24 +0100 | [diff] [blame] | 104 | void ParseAbs(size_t subgraphIndex, size_t operatorIndex); |
Finn Williams | c42c384 | 2019-01-22 14:18:11 +0000 | [diff] [blame] | 105 | void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType); |
Nina Drozd | 200e380 | 2019-04-15 09:47:39 +0100 | [diff] [blame] | 106 | void ParseAdd(size_t subgraphIndex, size_t operatorIndex); |
Matthew Sloyan | 28f177c | 2021-04-09 14:38:52 +0100 | [diff] [blame] | 107 | void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction); |
| 108 | void ParseArgMin(size_t subgraphIndex, size_t operatorIndex); |
| 109 | void ParseArgMax(size_t subgraphIndex, size_t operatorIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 110 | void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | db947e2 | 2019-02-08 18:52:21 -0200 | [diff] [blame] | 111 | void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex); |
mathad01 | b392e98 | 2021-04-07 12:07:30 +0100 | [diff] [blame] | 112 | void ParseCast(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 2d0eb86 | 2021-07-11 14:10:15 -0300 | [diff] [blame] | 113 | void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation); |
Sadik Armagan | 479045b | 2018-10-01 11:51:37 +0100 | [diff] [blame] | 114 | void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 115 | void ParseConv2D(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 2686849 | 2021-01-22 14:25:31 +0000 | [diff] [blame] | 116 | void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 117 | void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex); |
Finn Williams | ed66d14 | 2019-12-06 09:55:55 +0000 | [diff] [blame] | 118 | void ParseDequantize(size_t subgraphIndex, size_t operatorIndex); |
keidav01 | 1b3e2ea | 2019-02-21 10:07:37 +0000 | [diff] [blame] | 119 | void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex); |
Matthew Sloyan | 28f177c | 2021-04-09 14:38:52 +0100 | [diff] [blame] | 120 | void ParseDiv(size_t subgraphIndex, size_t operatorIndex); |
Matthew Sloyan | ed7fce4 | 2021-04-15 20:46:24 +0100 | [diff] [blame] | 121 | void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation); |
Matthew Sloyan | 7515d07 | 2020-12-16 12:50:01 +0000 | [diff] [blame] | 122 | void ParseElu(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 2d0eb86 | 2021-07-11 14:10:15 -0300 | [diff] [blame] | 123 | void ParseEqual(size_t subgraphIndex, size_t operatorIndex); |
Derek Lamberti | f017699 | 2020-04-28 13:37:49 +0100 | [diff] [blame] | 124 | void ParseExp(size_t subgraphIndex, size_t operatorIndex); |
Teresa Charlin | 3ab8548 | 2021-06-08 16:59:29 +0100 | [diff] [blame] | 125 | void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 8853c1f | 2018-10-22 09:04:18 +0100 | [diff] [blame] | 126 | void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 2686849 | 2021-01-22 14:25:31 +0000 | [diff] [blame] | 127 | void ParseGather(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 2d0eb86 | 2021-07-11 14:10:15 -0300 | [diff] [blame] | 128 | void ParseGreater(size_t subgraphIndex, size_t operatorIndex); |
| 129 | void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex); |
Jan Eilers | 2f746b3 | 2020-07-28 14:00:06 +0100 | [diff] [blame] | 130 | void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 12239e7 | 2020-05-27 11:06:17 +0100 | [diff] [blame] | 131 | void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 2d0eb86 | 2021-07-11 14:10:15 -0300 | [diff] [blame] | 132 | void ParseLess(size_t subgraphIndex, size_t operatorIndex); |
| 133 | void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex); |
Matthew Sloyan | ed7fce4 | 2021-04-15 20:46:24 +0100 | [diff] [blame] | 134 | void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex); |
Finn Williams | c42c384 | 2019-01-22 14:18:11 +0000 | [diff] [blame] | 135 | void ParseLogistic(size_t subgraphIndex, size_t operatorIndex); |
Matthew Jackson | 28c9457 | 2019-07-18 10:47:03 +0100 | [diff] [blame] | 136 | void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex); |
Nattapat Chaimanowong | b66504b | 2018-10-17 15:19:14 +0100 | [diff] [blame] | 137 | void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | b8d805e | 2019-02-12 22:57:13 -0200 | [diff] [blame] | 138 | void ParseMaximum(size_t subgraphIndex, size_t operatorIndex); |
Nina Drozd | 200e380 | 2019-04-15 09:47:39 +0100 | [diff] [blame] | 139 | void ParseMean(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 8f6d7a7 | 2019-02-12 22:58:18 -0200 | [diff] [blame] | 140 | void ParseMinimum(size_t subgraphIndex, size_t operatorIndex); |
Nina Drozd | 200e380 | 2019-04-15 09:47:39 +0100 | [diff] [blame] | 141 | void ParseMul(size_t subgraphIndex, size_t operatorIndex); |
Darshan Patel | 83fcf98 | 2020-05-26 22:22:42 +0530 | [diff] [blame] | 142 | void ParseNeg(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 2d0eb86 | 2021-07-11 14:10:15 -0300 | [diff] [blame] | 143 | void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex); |
Matthew Jackson | bcca1f4 | 2019-07-16 11:39:21 +0100 | [diff] [blame] | 144 | void ParsePack(size_t subgraphIndex, size_t operatorIndex); |
Nina Drozd | 200e380 | 2019-04-15 09:47:39 +0100 | [diff] [blame] | 145 | void ParsePad(size_t subgraphIndex, size_t operatorIndex); |
| 146 | void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm); |
Narumol Prangnawarat | bfaee6b | 2021-05-24 18:50:24 +0100 | [diff] [blame] | 147 | void ParsePrelu(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 66dedc7 | 2019-12-10 16:32:07 +0000 | [diff] [blame] | 148 | void ParseQuantize(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | a274748 | 2021-02-09 10:28:54 +0000 | [diff] [blame] | 149 | void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation); |
| 150 | void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex); |
| 151 | void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 58f3919 | 2018-09-17 14:14:39 +0100 | [diff] [blame] | 152 | void ParseRelu(size_t subgraphIndex, size_t operatorIndex); |
| 153 | void ParseRelu6(size_t subgraphIndex, size_t operatorIndex); |
Sadik | b94967b | 2018-09-19 15:30:00 +0100 | [diff] [blame] | 154 | void ParseReshape(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | a3b31f0 | 2019-12-05 09:08:53 +0000 | [diff] [blame] | 155 | void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod); |
Bruno Goncalves | 3f58ddb | 2019-02-07 18:40:11 -0200 | [diff] [blame] | 156 | void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | a3b31f0 | 2019-12-05 09:08:53 +0000 | [diff] [blame] | 157 | void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex); |
Matthew Sloyan | ed7fce4 | 2021-04-15 20:46:24 +0100 | [diff] [blame] | 158 | void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex); |
Keith Davis | 0176fd8 | 2021-06-01 17:36:32 +0100 | [diff] [blame] | 159 | void ParseShape(size_t subgraphIndex, size_t operatorIndex); |
josh minor | ba424d2 | 2019-11-13 10:55:17 -0600 | [diff] [blame] | 160 | void ParseSlice(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 479045b | 2018-10-01 11:51:37 +0100 | [diff] [blame] | 161 | void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | baded14 | 2019-02-08 19:02:48 -0200 | [diff] [blame] | 162 | void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex); |
Nina Drozd | 200e380 | 2019-04-15 09:47:39 +0100 | [diff] [blame] | 163 | void ParseSplit(size_t subgraphIndex, size_t operatorIndex); |
Derek Lamberti | f017699 | 2020-04-28 13:37:49 +0100 | [diff] [blame] | 164 | void ParseSplitV(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 479045b | 2018-10-01 11:51:37 +0100 | [diff] [blame] | 165 | void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | 451d95b | 2019-02-12 22:59:22 -0200 | [diff] [blame] | 166 | void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex); |
Bruno Goncalves | bbeae26 | 2019-02-07 18:37:39 -0200 | [diff] [blame] | 167 | void ParseSub(size_t subgraphIndex, size_t operatorIndex); |
Sadik Armagan | 0c3ea5b | 2021-02-03 09:29:30 +0000 | [diff] [blame] | 168 | void ParseSum(size_t subgraphIndex, size_t operatorIndex); |
Nina Drozd | 9985176 | 2019-04-09 09:37:38 +0100 | [diff] [blame] | 169 | void ParseTanH(size_t subgraphIndex, size_t operatorIndex); |
Keith Davis | 4cd29a0 | 2019-09-09 14:49:20 +0100 | [diff] [blame] | 170 | void ParseTranspose(size_t subgraphIndex, size_t operatorIndex); |
Matthew Jackson | 74bf7da | 2019-08-16 16:51:42 +0100 | [diff] [blame] | 171 | void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex); |
Nina Drozd | 200e380 | 2019-04-15 09:47:39 +0100 | [diff] [blame] | 172 | void ParseUnpack(size_t subgraphIndex, size_t operatorIndex); |
Nattapat Chaimanowong | b66504b | 2018-10-17 15:19:14 +0100 | [diff] [blame] | 173 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 174 | void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot); |
| 175 | void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot); |
| 176 | void RegisterInputSlots(size_t subgraphIndex, |
| 177 | size_t operatorIndex, |
| 178 | armnn::IConnectableLayer* layer, |
Finn Williams | d4fa545 | 2021-03-01 12:31:41 +0000 | [diff] [blame] | 179 | const std::vector<unsigned int>& tensorIndexes, |
| 180 | unsigned int startingSlotIndex = 0); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 181 | void RegisterOutputSlots(size_t subgraphIndex, |
| 182 | size_t operatorIndex, |
| 183 | armnn::IConnectableLayer* layer, |
| 184 | const std::vector<unsigned int>& tensorIndexes); |
| 185 | |
| 186 | void SetupInputLayers(size_t subgraphIndex); |
| 187 | void SetupOutputLayers(size_t subgraphIndex); |
Bruno Goncalves | 3d7efe9 | 2018-12-27 14:21:43 -0200 | [diff] [blame] | 188 | void SetupConstantLayers(size_t subgraphIndex); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 189 | |
| 190 | void ResetParser(); |
| 191 | |
Bruno Goncalves | 9c761a6 | 2018-12-27 14:20:35 -0200 | [diff] [blame] | 192 | void AddBroadcastReshapeLayer(size_t subgraphIndex, |
| 193 | size_t operatorIndex, |
| 194 | armnn::IConnectableLayer* layer); |
| 195 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 196 | /// Attach an activation layer to the one passed as a parameter |
Sadik Armagan | 58f3919 | 2018-09-17 14:14:39 +0100 | [diff] [blame] | 197 | armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer, |
| 198 | unsigned int outputSlot, |
| 199 | tflite::ActivationFunctionType activationType); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 200 | |
| 201 | // SupportedDataStorage's purpose is to hold data till we pass over to the network. |
| 202 | // We don't care about the content, and we want a single datatype to simplify the code. |
| 203 | struct SupportedDataStorage |
| 204 | { |
Matteo Martincigh | 747ef82 | 2018-12-18 09:26:39 +0000 | [diff] [blame] | 205 | public: |
| 206 | // Convenience constructors |
| 207 | SupportedDataStorage(std::unique_ptr<float[]>&& data); |
| 208 | SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data); |
Keith Davis | d305e1a | 2020-01-22 11:57:54 +0000 | [diff] [blame] | 209 | SupportedDataStorage(std::unique_ptr<int8_t[]>&& data); |
Matteo Martincigh | 747ef82 | 2018-12-18 09:26:39 +0000 | [diff] [blame] | 210 | SupportedDataStorage(std::unique_ptr<int32_t[]>&& data); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 211 | |
Matteo Martincigh | 747ef82 | 2018-12-18 09:26:39 +0000 | [diff] [blame] | 212 | private: |
| 213 | // Pointers to the data buffers |
| 214 | std::unique_ptr<float[]> m_FloatData; |
| 215 | std::unique_ptr<uint8_t[]> m_Uint8Data; |
Keith Davis | d305e1a | 2020-01-22 11:57:54 +0000 | [diff] [blame] | 216 | std::unique_ptr<int8_t[]> m_Int8Data; |
Matteo Martincigh | 747ef82 | 2018-12-18 09:26:39 +0000 | [diff] [blame] | 217 | std::unique_ptr<int32_t[]> m_Int32Data; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 218 | }; |
| 219 | |
Finn Williams | d4fa545 | 2021-03-01 12:31:41 +0000 | [diff] [blame] | 220 | bool IsConstTensor(TensorRawPtr tensorPtr); |
| 221 | armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, |
| 222 | armnn::TensorInfo& tensorInfo); |
| 223 | std::pair<armnn::ConstTensor, SupportedDataStorage> |
| 224 | CreateConstTensorPermuted(TensorRawPtr tensorPtr, |
| 225 | armnn::TensorInfo& tensorInfo, |
| 226 | armnn::Optional<armnn::PermutationVector&> permutationVector); |
Matteo Martincigh | 747ef82 | 2018-12-18 09:26:39 +0000 | [diff] [blame] | 227 | |
| 228 | template<typename T> |
Kevin May | 7d96b16 | 2021-02-03 17:38:41 +0000 | [diff] [blame] | 229 | std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage> |
| 230 | CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr, |
| 231 | TfLiteParserImpl::TensorRawPtr tensorPtr, |
Matteo Martincigh | 747ef82 | 2018-12-18 09:26:39 +0000 | [diff] [blame] | 232 | armnn::TensorInfo& tensorInfo, |
| 233 | armnn::Optional<armnn::PermutationVector&> permutationVector); |
| 234 | |
Aron Virginas-Tar | c975f92 | 2019-10-23 17:38:17 +0100 | [diff] [blame] | 235 | // Settings for configuring the TfLiteParser |
| 236 | armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options; |
| 237 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 238 | /// The network we're building. Gets cleared after it is passed to the user |
| 239 | armnn::INetworkPtr m_Network; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 240 | ModelPtr m_Model; |
| 241 | |
Aron Virginas-Tar | c975f92 | 2019-10-23 17:38:17 +0100 | [diff] [blame] | 242 | std::vector<OperatorParsingFunction> m_ParserFunctions; |
| 243 | std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions; |
| 244 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 245 | /// A mapping of an output slot to each of the input slots it should be connected to |
| 246 | /// The outputSlot is from the layer that creates this tensor as one of its ouputs |
| 247 | /// The inputSlots are from the layers that use this tensor as one of their inputs |
| 248 | struct TensorSlots |
| 249 | { |
| 250 | armnn::IOutputSlot* outputSlot; |
| 251 | std::vector<armnn::IInputSlot*> inputSlots; |
| 252 | |
| 253 | TensorSlots() : outputSlot(nullptr) { } |
| 254 | }; |
| 255 | typedef std::vector<TensorSlots> TensorConnections; |
| 256 | /// Connections for tensors in each subgraph |
| 257 | /// The first index is the subgraph ID, the second index is the tensor ID |
| 258 | std::vector<TensorConnections> m_SubgraphConnections; |
Narumol Prangnawarat | 4628d05 | 2019-02-25 17:26:05 +0000 | [diff] [blame] | 259 | |
| 260 | /// This is used in case that the model does not speciry the output. |
| 261 | /// The shape can be calculated from the options. |
| 262 | std::vector<std::vector<unsigned int>> m_OverridenOutputShapes; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 263 | }; |
| 264 | |
| 265 | } |