//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "armnn/INetwork.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"

#include <schema_generated.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
13
14namespace armnnTfLiteParser
15{
16
17class TfLiteParser : public ITfLiteParser
18{
19public:
20 // Shorthands for TfLite types
21 using ModelPtr = std::unique_ptr<tflite::ModelT>;
22 using SubGraphPtr = std::unique_ptr<tflite::SubGraphT>;
23 using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
24 using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
25 using TensorPtr = std::unique_ptr<tflite::TensorT>;
26 using TensorRawPtr = const tflite::TensorT *;
27 using TensorRawPtrVector = std::vector<TensorRawPtr>;
28 using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
29 using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
30 using BufferPtr = std::unique_ptr<tflite::BufferT>;
31 using BufferRawPtr = const tflite::BufferT *;
32
33public:
34 /// Create the network from a flatbuffers binary file on disk
35 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
36
37 /// Create the network from a flatbuffers binary
38 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) override;
39
40
41 /// Retrieve binding info (layer id and tensor info) for the network input identified by
42 /// the given layer name and subgraph id
43 virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
44 const std::string& name) const override;
45
46 /// Retrieve binding info (layer id and tensor info) for the network output identified by
47 /// the given layer name and subgraph id
48 virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
49 const std::string& name) const override;
50
51 /// Return the number of subgraphs in the parsed model
52 virtual size_t GetSubgraphCount() const override;
53
54 /// Return the input tensor names for a given subgraph
55 virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const override;
56
57 /// Return the output tensor names for a given subgraph
58 virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override;
59
60 TfLiteParser();
61 virtual ~TfLiteParser() {}
62
63public:
64 // testable helpers
65 static ModelPtr LoadModelFromFile(const char * fileName);
66 static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len);
67 static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
68 static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
69 static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex);
70 static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex);
71 static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
72 static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
73
74 static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
75 static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
76 const armnn::TensorInfo & inputTensorInfo);
Sadikb94967b2018-09-19 15:30:00 +010077 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
78 const std::vector<int32_t> & targetDimsIn);
telsoa01c577f2c2018-08-31 09:22:23 +010079
80private:
81 // No copying allowed until it is wanted and properly implemented
82 TfLiteParser(const TfLiteParser &) = delete;
83 TfLiteParser & operator=(const TfLiteParser &) = delete;
84
85 /// Create the network from an already loaded flatbuffers model
86 armnn::INetworkPtr CreateNetworkFromModel();
87
88 // signature for the parser functions
89 using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex);
90
91 void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
92 void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +010093 void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +010094 void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
95 void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan58f39192018-09-17 14:14:39 +010096 void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
97 void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
Sadikb94967b2018-09-19 15:30:00 +010098 void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +010099 void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
100 void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100101
102 void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
103 void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
104 void RegisterInputSlots(size_t subgraphIndex,
105 size_t operatorIndex,
106 armnn::IConnectableLayer* layer,
107 const std::vector<unsigned int>& tensorIndexes);
108 void RegisterOutputSlots(size_t subgraphIndex,
109 size_t operatorIndex,
110 armnn::IConnectableLayer* layer,
111 const std::vector<unsigned int>& tensorIndexes);
112
113 void SetupInputLayers(size_t subgraphIndex);
114 void SetupOutputLayers(size_t subgraphIndex);
115
116 void ResetParser();
117
118 /// Attach an activation layer to the one passed as a parameter
Sadik Armagan58f39192018-09-17 14:14:39 +0100119 armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
120 unsigned int outputSlot,
121 tflite::ActivationFunctionType activationType);
telsoa01c577f2c2018-08-31 09:22:23 +0100122
123 // SupportedDataStorage's purpose is to hold data till we pass over to the network.
124 // We don't care about the content, and we want a single datatype to simplify the code.
125 struct SupportedDataStorage
126 {
127 std::unique_ptr<float[]> m_FloatData;
128 std::unique_ptr<uint8_t[]> m_Uint8Data;
129 std::unique_ptr<int32_t[]> m_Int32Data;
130
131 SupportedDataStorage(std::unique_ptr<float[]> && data);
132 SupportedDataStorage(std::unique_ptr<uint8_t[]> && data);
133 SupportedDataStorage(std::unique_ptr<int32_t[]> && data);
134 };
135
136 std::pair<armnn::ConstTensor, SupportedDataStorage> CreateConstTensor(TensorRawPtr tensorPtr,
137 armnn::TensorInfo & tensorInfo,
138 bool convertFromTfToArmnnFormat);
139
140 /// The network we're building. Gets cleared after it is passed to the user
141 armnn::INetworkPtr m_Network;
142 std::vector<OperatorParsingFunction> m_ParserFunctions;
143 ModelPtr m_Model;
144
145 /// A mapping of an output slot to each of the input slots it should be connected to
146 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
147 /// The inputSlots are from the layers that use this tensor as one of their inputs
148 struct TensorSlots
149 {
150 armnn::IOutputSlot* outputSlot;
151 std::vector<armnn::IInputSlot*> inputSlots;
152
153 TensorSlots() : outputSlot(nullptr) { }
154 };
155 typedef std::vector<TensorSlots> TensorConnections;
156 /// Connections for tensors in each subgraph
157 /// The first index is the subgraph ID, the second index is the tensor ID
158 std::vector<TensorConnections> m_SubgraphConnections;
159};
160
161}