//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "armnn/Types.hpp"
#include "armnn/NetworkFwd.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"
#include "armnn/Optional.hpp"

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

17namespace armnnTfLiteParser
18{
19
Jim Flynnb4d7eae2019-05-01 14:44:27 +010020using BindingPointInfo = armnn::BindingPointInfo;
telsoa01c577f2c2018-08-31 09:22:23 +010021
Kevin May7d96b162021-02-03 17:38:41 +000022class TfLiteParserImpl;
telsoa01c577f2c2018-08-31 09:22:23 +010023class ITfLiteParser;
24using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
25
26class ITfLiteParser
27{
28public:
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010029 struct TfLiteParserOptions
30 {
31 TfLiteParserOptions()
Mike Kelly80512b02022-05-16 23:10:42 +010032 : m_AllowExpandedDims(false),
33 m_StandInLayerForUnsupported(false),
Sadik Armagand109a4d2020-07-28 10:42:13 +010034 m_InferAndValidate(false) {}
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010035
Mike Kelly80512b02022-05-16 23:10:42 +010036 bool m_AllowExpandedDims;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010037 bool m_StandInLayerForUnsupported;
Sadik Armagand109a4d2020-07-28 10:42:13 +010038 bool m_InferAndValidate;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010039 };
40
41 static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
42 static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +010043 static void Destroy(ITfLiteParser* parser);
44
45 /// Create the network from a flatbuffers binary file on disk
Kevin May7d96b162021-02-03 17:38:41 +000046 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010047
48 /// Create the network from a flatbuffers binary
Kevin May7d96b162021-02-03 17:38:41 +000049 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
telsoa01c577f2c2018-08-31 09:22:23 +010050
51 /// Retrieve binding info (layer id and tensor info) for the network input identified by
52 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000053 BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
54 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010055
56 /// Retrieve binding info (layer id and tensor info) for the network output identified by
57 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000058 BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
59 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010060
61 /// Return the number of subgraphs in the parsed model
Kevin May7d96b162021-02-03 17:38:41 +000062 size_t GetSubgraphCount() const;
telsoa01c577f2c2018-08-31 09:22:23 +010063
64 /// Return the input tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000065 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010066
67 /// Return the output tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000068 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010069
Kevin May7d96b162021-02-03 17:38:41 +000070private:
71 ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
72 ~ITfLiteParser();
73
74 std::unique_ptr<TfLiteParserImpl> pTfLiteParserImpl;
telsoa01c577f2c2018-08-31 09:22:23 +010075};
76
77}