blob: b286c1ee4ca8b602d5856a22b2e372a0ec96eb2f [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
7#include "armnn/Types.hpp"
8#include "armnn/NetworkFwd.hpp"
9#include "armnn/Tensor.hpp"
10#include "armnn/INetwork.hpp"
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010011#include "armnn/Optional.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010012
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
16
17namespace armnnTfLiteParser
18{
19
Jim Flynnb4d7eae2019-05-01 14:44:27 +010020using BindingPointInfo = armnn::BindingPointInfo;
telsoa01c577f2c2018-08-31 09:22:23 +010021
Kevin May7d96b162021-02-03 17:38:41 +000022class TfLiteParserImpl;
telsoa01c577f2c2018-08-31 09:22:23 +010023class ITfLiteParser;
24using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
25
26class ITfLiteParser
27{
28public:
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010029 struct TfLiteParserOptions
30 {
31 TfLiteParserOptions()
Sadik Armagand109a4d2020-07-28 10:42:13 +010032 : m_StandInLayerForUnsupported(false),
33 m_InferAndValidate(false) {}
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010034
35 bool m_StandInLayerForUnsupported;
Sadik Armagand109a4d2020-07-28 10:42:13 +010036 bool m_InferAndValidate;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010037 };
38
39 static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
40 static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +010041 static void Destroy(ITfLiteParser* parser);
42
43 /// Create the network from a flatbuffers binary file on disk
Kevin May7d96b162021-02-03 17:38:41 +000044 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010045
46 /// Create the network from a flatbuffers binary
Kevin May7d96b162021-02-03 17:38:41 +000047 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
telsoa01c577f2c2018-08-31 09:22:23 +010048
49 /// Retrieve binding info (layer id and tensor info) for the network input identified by
50 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000051 BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
52 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010053
54 /// Retrieve binding info (layer id and tensor info) for the network output identified by
55 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000056 BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
57 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010058
59 /// Return the number of subgraphs in the parsed model
Kevin May7d96b162021-02-03 17:38:41 +000060 size_t GetSubgraphCount() const;
telsoa01c577f2c2018-08-31 09:22:23 +010061
62 /// Return the input tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000063 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010064
65 /// Return the output tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000066 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010067
Kevin May7d96b162021-02-03 17:38:41 +000068private:
69 ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
70 ~ITfLiteParser();
71
72 std::unique_ptr<TfLiteParserImpl> pTfLiteParserImpl;
telsoa01c577f2c2018-08-31 09:22:23 +010073};
74
75}