blob: 91e4cb39bfa2e2e677853ca489d9ae44e825a911 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "armnn/Types.hpp"
8#include "armnn/Tensor.hpp"
9#include "armnn/INetwork.hpp"
10
11#include <map>
12#include <memory>
13#include <unordered_map>
14#include <vector>
15
16namespace armnnTfParser
17{
18
Jim Flynnb4d7eae2019-05-01 14:44:27 +010019using BindingPointInfo = armnn::BindingPointInfo;
surmeh01bceff2f2018-03-29 16:29:27 +010020
21class ITfParser;
22using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>;
23
telsoa01c577f2c2018-08-31 09:22:23 +010024/// Parses a directed acyclic graph from a tensorflow protobuf file.
surmeh01bceff2f2018-03-29 16:29:27 +010025class ITfParser
26{
27public:
28 static ITfParser* CreateRaw();
29 static ITfParserPtr Create();
30 static void Destroy(ITfParser* parser);
31
telsoa01c577f2c2018-08-31 09:22:23 +010032 /// Create the network from a protobuf text file on the disk.
Kevin May7d96b162021-02-03 17:38:41 +000033 armnn::INetworkPtr CreateNetworkFromTextFile(
surmeh01bceff2f2018-03-29 16:29:27 +010034 const char* graphFile,
35 const std::map<std::string, armnn::TensorShape>& inputShapes,
Kevin May7d96b162021-02-03 17:38:41 +000036 const std::vector<std::string>& requestedOutputs);
surmeh01bceff2f2018-03-29 16:29:27 +010037
telsoa01c577f2c2018-08-31 09:22:23 +010038 /// Create the network from a protobuf binary file on the disk.
Kevin May7d96b162021-02-03 17:38:41 +000039 armnn::INetworkPtr CreateNetworkFromBinaryFile(
surmeh01bceff2f2018-03-29 16:29:27 +010040 const char* graphFile,
41 const std::map<std::string, armnn::TensorShape>& inputShapes,
Kevin May7d96b162021-02-03 17:38:41 +000042 const std::vector<std::string>& requestedOutputs);
surmeh01bceff2f2018-03-29 16:29:27 +010043
telsoa01c577f2c2018-08-31 09:22:23 +010044 /// Create the network directly from protobuf text in a string. Useful for debugging/testing.
Kevin May7d96b162021-02-03 17:38:41 +000045 armnn::INetworkPtr CreateNetworkFromString(
surmeh01bceff2f2018-03-29 16:29:27 +010046 const char* protoText,
47 const std::map<std::string, armnn::TensorShape>& inputShapes,
Kevin May7d96b162021-02-03 17:38:41 +000048 const std::vector<std::string>& requestedOutputs);
surmeh01bceff2f2018-03-29 16:29:27 +010049
telsoa01c577f2c2018-08-31 09:22:23 +010050 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.
Kevin May7d96b162021-02-03 17:38:41 +000051 BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const;
surmeh01bceff2f2018-03-29 16:29:27 +010052
telsoa01c577f2c2018-08-31 09:22:23 +010053 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.
Kevin May7d96b162021-02-03 17:38:41 +000054 BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const;
surmeh01bceff2f2018-03-29 16:29:27 +010055
Kevin May7d96b162021-02-03 17:38:41 +000056private:
57 template <typename T>
58 friend class ParsedConstTfOperation;
59 friend class ParsedMatMulTfOperation;
60 friend class ParsedMulTfOperation;
61 friend class ParsedTfOperation;
62 friend class SingleLayerParsedTfOperation;
63 friend class DeferredSingleLayerParsedTfOperation;
64 friend class ParsedIdentityTfOperation;
65
66 template <template<typename> class OperatorType, typename T>
67 friend struct MakeTfOperation;
68
69
70 ITfParser();
71 ~ITfParser();
72
73 struct TfParserImpl;
74 std::unique_ptr<TfParserImpl> pTfParserImpl;
surmeh01bceff2f2018-03-29 16:29:27 +010075};
76
77}