blob: 8a509e880e19227af94a76b19f13c57e0c86561d [file] [log] [blame]
Mike Kelly8c1701a2019-02-11 17:01:27 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <armnn/ILayerVisitor.hpp>
8#include <armnn/LayerVisitorBase.hpp>
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +00009
10#include <armnnSerializer/ISerializer.hpp>
11
Mike Kelly8c1701a2019-02-11 17:01:27 +000012#include <iostream>
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000013#include <unordered_map>
14
Mike Kelly8c1701a2019-02-11 17:01:27 +000015#include <Schema_generated.h>
16
17namespace armnnSerializer
18{
19
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000020class SerializerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
Mike Kelly8c1701a2019-02-11 17:01:27 +000021{
22public:
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000023 SerializerVisitor() : m_layerId(0) {};
Matteo Martincighec333912019-02-13 15:12:39 +000024 ~SerializerVisitor() {}
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000025
26 flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder()
27 {
28 return m_flatBufferBuilder;
29 }
30
31 std::vector<unsigned int>& GetInputIds()
32 {
33 return m_inputIds;
34 }
35
36 std::vector<unsigned int>& GetOutputIds()
37 {
38 return m_outputIds;
39 }
40
41 std::vector<flatbuffers::Offset<armnn::armnnSerializer::AnyLayer>>& GetSerializedLayers()
42 {
43 return m_serializedLayers;
44 }
Mike Kelly8c1701a2019-02-11 17:01:27 +000045
46 void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
47 const char* name = nullptr) override;
48
49 void VisitInputLayer(const armnn::IConnectableLayer* layer,
50 armnn::LayerBindingId id,
51 const char* name = nullptr) override;
52
53 void VisitOutputLayer(const armnn::IConnectableLayer* layer,
54 armnn::LayerBindingId id,
55 const char* name = nullptr) override;
56
Sadik Armagan5f450272019-02-12 14:31:45 +000057 void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
58 const char* name = nullptr) override;
59
Mike Kelly8c1701a2019-02-11 17:01:27 +000060private:
61
62 /// Creates the Input Slots and Output Slots and LayerBase for the layer.
63 flatbuffers::Offset<armnn::armnnSerializer::LayerBase> CreateLayerBase(
64 const armnn::IConnectableLayer* layer,
65 const armnn::armnnSerializer::LayerType layerType);
66
67 /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers.
68 void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnn::armnnSerializer::Layer serializerLayer);
69
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000070 ///Function which maps Guid to an index
71 uint32_t GetSerializedId(unsigned int guid);
72
Mike Kelly8c1701a2019-02-11 17:01:27 +000073 /// Creates the serializer InputSlots for the layer.
74 std::vector<flatbuffers::Offset<armnn::armnnSerializer::InputSlot>> CreateInputSlots(
75 const armnn::IConnectableLayer* layer);
76
77 /// Creates the serializer OutputSlots for the layer.
78 std::vector<flatbuffers::Offset<armnn::armnnSerializer::OutputSlot>> CreateOutputSlots(
79 const armnn::IConnectableLayer* layer);
80
81 /// FlatBufferBuilder to create our layers' FlatBuffers.
82 flatbuffers::FlatBufferBuilder m_flatBufferBuilder;
83
84 /// AnyLayers required by the SerializedGraph.
85 std::vector<flatbuffers::Offset<armnn::armnnSerializer::AnyLayer>> m_serializedLayers;
86
87 /// Guids of all Input Layers required by the SerializedGraph.
88 std::vector<unsigned int> m_inputIds;
89
90 /// Guids of all Output Layers required by the SerializedGraph.
91 std::vector<unsigned int> m_outputIds;
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000092
93 /// Mapped Guids of all Layers to match our index.
94 std::unordered_map<unsigned int, uint32_t > m_guidMap;
95
96 /// layer within our FlatBuffer index.
97 uint32_t m_layerId;
Mike Kelly8c1701a2019-02-11 17:01:27 +000098};
99
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000100class Serializer : public ISerializer
101{
102public:
Matteo Martincighec333912019-02-13 15:12:39 +0000103 Serializer() {}
104 ~Serializer() {}
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000105
106 /// Serializes the network to ArmNN SerializedGraph.
107 /// @param [in] inNetwork The network to be serialized.
108 void Serialize(const armnn::INetwork& inNetwork) override;
109
110 /// Serializes the SerializedGraph to the stream.
111 /// @param [stream] the stream to save to
112 /// @return true if graph is Serialized to the Stream, false otherwise
113 bool SaveSerializedToStream(std::ostream& stream) override;
114
115private:
116
117 /// Visitor to contruct serialized network
118 SerializerVisitor m_SerializerVisitor;
119};
120
Mike Kelly8c1701a2019-02-11 17:01:27 +0000121} //namespace armnnSerializer