//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "caffe/proto/caffe.pb.h"

#include "CaffeParser.hpp"
15
16
17
18namespace armnnCaffeParser
19{
20
21class NetParameterInfo;
22class LayerParameterInfo;
23
24
25class RecordByRecordCaffeParser : public CaffeParserBase
26{
27public:
28
29 /// Create the network from a protobuf binary file on disk
30 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
31 const char* graphFile,
32 const std::map<std::string, armnn::TensorShape>& inputShapes,
33 const std::vector<std::string>& requestedOutputs) override;
34
35 RecordByRecordCaffeParser();
36
37private:
38 void ProcessLayers(const NetParameterInfo& netParameterInfo,
39 std::vector<LayerParameterInfo>& layerInfo,
40 const std::vector<std::string>& m_RequestedOutputs,
41 std::vector<const LayerParameterInfo*>& sortedNodes);
42 armnn::INetworkPtr LoadLayers(std::ifstream& ifs,
43 std::vector<const LayerParameterInfo *>& sortedNodes,
44 const NetParameterInfo& netParameterInfo);
45 std::vector<const LayerParameterInfo*> GetInputs(
46 const LayerParameterInfo& layerParam);
47
48 std::map<std::string, const LayerParameterInfo*> m_CaffeLayersByTopName;
49 std::vector<std::string> m_RequestedOutputs;
50};
51
52} // namespace armnnCaffeParser
53