// blob: a68b719a668c840d6363726b3209fac222e5eb80
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "armnn/Types.hpp"
#include "armnn/NetworkFwd.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"
#include "armnn/Optional.hpp"

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

17namespace armnnTfLiteParser
18{
19
Jim Flynnb4d7eae2019-05-01 14:44:27 +010020using BindingPointInfo = armnn::BindingPointInfo;
telsoa01c577f2c2018-08-31 09:22:23 +010021
22class ITfLiteParser;
23using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
24
25class ITfLiteParser
26{
27public:
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010028 struct TfLiteParserOptions
29 {
30 TfLiteParserOptions()
Sadik Armagand109a4d2020-07-28 10:42:13 +010031 : m_StandInLayerForUnsupported(false),
32 m_InferAndValidate(false) {}
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010033
34 bool m_StandInLayerForUnsupported;
Sadik Armagand109a4d2020-07-28 10:42:13 +010035 bool m_InferAndValidate;
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010036 };
37
38 static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
39 static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +010040 static void Destroy(ITfLiteParser* parser);
41
42 /// Create the network from a flatbuffers binary file on disk
43 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0;
44
45 /// Create the network from a flatbuffers binary
46 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) = 0;
47
48 /// Retrieve binding info (layer id and tensor info) for the network input identified by
49 /// the given layer name and subgraph id
50 virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
51 const std::string& name) const = 0;
52
53 /// Retrieve binding info (layer id and tensor info) for the network output identified by
54 /// the given layer name and subgraph id
55 virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
56 const std::string& name) const = 0;
57
58 /// Return the number of subgraphs in the parsed model
59 virtual size_t GetSubgraphCount() const = 0;
60
61 /// Return the input tensor names for a given subgraph
62 virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const = 0;
63
64 /// Return the output tensor names for a given subgraph
65 virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const = 0;
66
67protected:
68 virtual ~ITfLiteParser() {};
69};
70
71}