blob: de1eae76355af6d18a5c10724cc79413ad6c1047 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
7#include "armnn/Types.hpp"
8#include "armnn/NetworkFwd.hpp"
9#include "armnn/Tensor.hpp"
10#include "armnn/INetwork.hpp"
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010011#include "armnn/Optional.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010012
13#include <memory>
14#include <map>
15#include <vector>
16
17namespace armnnTfLiteParser
18{
19
Jim Flynnb4d7eae2019-05-01 14:44:27 +010020using BindingPointInfo = armnn::BindingPointInfo;
telsoa01c577f2c2018-08-31 09:22:23 +010021
22class ITfLiteParser;
23using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
24
25class ITfLiteParser
26{
27public:
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010028 struct TfLiteParserOptions
29 {
30 TfLiteParserOptions()
31 : m_StandInLayerForUnsupported(false) {}
32
33 bool m_StandInLayerForUnsupported;
34 };
35
36 static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
37 static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +010038 static void Destroy(ITfLiteParser* parser);
39
40 /// Create the network from a flatbuffers binary file on disk
41 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0;
42
43 /// Create the network from a flatbuffers binary
44 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) = 0;
45
46 /// Retrieve binding info (layer id and tensor info) for the network input identified by
47 /// the given layer name and subgraph id
48 virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
49 const std::string& name) const = 0;
50
51 /// Retrieve binding info (layer id and tensor info) for the network output identified by
52 /// the given layer name and subgraph id
53 virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
54 const std::string& name) const = 0;
55
56 /// Return the number of subgraphs in the parsed model
57 virtual size_t GetSubgraphCount() const = 0;
58
59 /// Return the input tensor names for a given subgraph
60 virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const = 0;
61
62 /// Return the output tensor names for a given subgraph
63 virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const = 0;
64
65protected:
66 virtual ~ITfLiteParser() {};
67};
68
69}