blob: 36b9246ee5f3f4718ee7ad1066dd58f9ff0ee402 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include "armnn/Types.hpp"
#include "armnn/NetworkFwd.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
15
16namespace armnnTfLiteParser
17{
18
Jim Flynnb4d7eae2019-05-01 14:44:27 +010019using BindingPointInfo = armnn::BindingPointInfo;
telsoa01c577f2c2018-08-31 09:22:23 +010020
21class ITfLiteParser;
22using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
23
24class ITfLiteParser
25{
26public:
27 static ITfLiteParser* CreateRaw();
28 static ITfLiteParserPtr Create();
29 static void Destroy(ITfLiteParser* parser);
30
31 /// Create the network from a flatbuffers binary file on disk
32 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0;
33
34 /// Create the network from a flatbuffers binary
35 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) = 0;
36
37 /// Retrieve binding info (layer id and tensor info) for the network input identified by
38 /// the given layer name and subgraph id
39 virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
40 const std::string& name) const = 0;
41
42 /// Retrieve binding info (layer id and tensor info) for the network output identified by
43 /// the given layer name and subgraph id
44 virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
45 const std::string& name) const = 0;
46
47 /// Return the number of subgraphs in the parsed model
48 virtual size_t GetSubgraphCount() const = 0;
49
50 /// Return the input tensor names for a given subgraph
51 virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const = 0;
52
53 /// Return the output tensor names for a given subgraph
54 virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const = 0;
55
56protected:
57 virtual ~ITfLiteParser() {};
58};
59
60}