blob: b799dbe380ab1abda154620bcbae1a0f27b5fef0 [file] [log] [blame]
Mike Kelly8c1701a2019-02-11 17:01:27 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
#include <armnn/ILayerVisitor.hpp>
#include <armnn/LayerVisitorBase.hpp>

#include <armnnSerializer/ISerializer.hpp>

#include <Schema_generated.h>

#include <iostream>
#include <vector>
14
15namespace armnnSerializer
16{
17
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000018class SerializerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
Mike Kelly8c1701a2019-02-11 17:01:27 +000019{
20public:
Matteo Martincighec333912019-02-13 15:12:39 +000021 SerializerVisitor() {}
22 ~SerializerVisitor() {}
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000023
24 flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder()
25 {
26 return m_flatBufferBuilder;
27 }
28
29 std::vector<unsigned int>& GetInputIds()
30 {
31 return m_inputIds;
32 }
33
34 std::vector<unsigned int>& GetOutputIds()
35 {
36 return m_outputIds;
37 }
38
39 std::vector<flatbuffers::Offset<armnn::armnnSerializer::AnyLayer>>& GetSerializedLayers()
40 {
41 return m_serializedLayers;
42 }
Mike Kelly8c1701a2019-02-11 17:01:27 +000043
44 void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
45 const char* name = nullptr) override;
46
47 void VisitInputLayer(const armnn::IConnectableLayer* layer,
48 armnn::LayerBindingId id,
49 const char* name = nullptr) override;
50
51 void VisitOutputLayer(const armnn::IConnectableLayer* layer,
52 armnn::LayerBindingId id,
53 const char* name = nullptr) override;
54
Sadik Armagan5f450272019-02-12 14:31:45 +000055 void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
56 const char* name = nullptr) override;
57
Mike Kelly8c1701a2019-02-11 17:01:27 +000058private:
59
60 /// Creates the Input Slots and Output Slots and LayerBase for the layer.
61 flatbuffers::Offset<armnn::armnnSerializer::LayerBase> CreateLayerBase(
62 const armnn::IConnectableLayer* layer,
63 const armnn::armnnSerializer::LayerType layerType);
64
65 /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers.
66 void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnn::armnnSerializer::Layer serializerLayer);
67
68 /// Creates the serializer InputSlots for the layer.
69 std::vector<flatbuffers::Offset<armnn::armnnSerializer::InputSlot>> CreateInputSlots(
70 const armnn::IConnectableLayer* layer);
71
72 /// Creates the serializer OutputSlots for the layer.
73 std::vector<flatbuffers::Offset<armnn::armnnSerializer::OutputSlot>> CreateOutputSlots(
74 const armnn::IConnectableLayer* layer);
75
76 /// FlatBufferBuilder to create our layers' FlatBuffers.
77 flatbuffers::FlatBufferBuilder m_flatBufferBuilder;
78
79 /// AnyLayers required by the SerializedGraph.
80 std::vector<flatbuffers::Offset<armnn::armnnSerializer::AnyLayer>> m_serializedLayers;
81
82 /// Guids of all Input Layers required by the SerializedGraph.
83 std::vector<unsigned int> m_inputIds;
84
85 /// Guids of all Output Layers required by the SerializedGraph.
86 std::vector<unsigned int> m_outputIds;
87};
88
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000089class Serializer : public ISerializer
90{
91public:
Matteo Martincighec333912019-02-13 15:12:39 +000092 Serializer() {}
93 ~Serializer() {}
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000094
95 /// Serializes the network to ArmNN SerializedGraph.
96 /// @param [in] inNetwork The network to be serialized.
97 void Serialize(const armnn::INetwork& inNetwork) override;
98
99 /// Serializes the SerializedGraph to the stream.
100 /// @param [stream] the stream to save to
101 /// @return true if graph is Serialized to the Stream, false otherwise
102 bool SaveSerializedToStream(std::ostream& stream) override;
103
104private:
105
106 /// Visitor to contruct serialized network
107 SerializerVisitor m_SerializerVisitor;
108};
109
Mike Kelly8c1701a2019-02-11 17:01:27 +0000110} //namespace armnnSerializer