blob: ef4be6923308e77097bdc6c7396b7ab1d56a5731 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "Serializer.hpp"
7#include <armnn/ArmNN.hpp>
8#include <iostream>
9#include <Schema_generated.h>
10#include <flatbuffers/util.h>
11
12using namespace armnn;
13namespace fb = flatbuffers;
14namespace serializer = armnn::armnnSerializer;
15
16namespace armnnSerializer
17{
18
19serializer::DataType GetFlatBufferDataType(DataType dataType)
20{
21 switch (dataType)
22 {
23 case DataType::Float32:
24 return serializer::DataType::DataType_Float32;
25 case DataType::Float16:
26 return serializer::DataType::DataType_Float16;
27 case DataType::Signed32:
28 return serializer::DataType::DataType_Signed32;
29 case DataType::QuantisedAsymm8:
30 return serializer::DataType::DataType_QuantisedAsymm8;
31 case DataType::Boolean:
32 return serializer::DataType::DataType_Boolean;
33 default:
34 return serializer::DataType::DataType_Float16;
35 }
36}
37
38// Build FlatBuffer for Input Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000039void SerializerVisitor::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000040{
41 // Create FlatBuffer BaseLayer
42 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
43
44 // Create FlatBuffer BindableBaseLayer
45 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
46 flatBufferInputBaseLayer,
47 id);
48
49 // Push layer Guid to outputIds.
50 m_inputIds.push_back(layer->GetGuid());
51
52 // Create the FlatBuffer InputLayer
53 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
54
55 // Add the AnyLayer to the FlatBufferLayers
56 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
57}
58
59// Build FlatBuffer for Output Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000060void SerializerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000061{
62 // Create FlatBuffer BaseLayer
63 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
64
65 // Create FlatBuffer BindableBaseLayer
66 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
67 flatBufferOutputBaseLayer,
68 id);
69 // Push layer Guid to outputIds.
70 m_outputIds.push_back(layer->GetGuid());
71
72 // Create the FlatBuffer OutputLayer
73 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
74 // Add the AnyLayer to the FlatBufferLayers
75 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
76}
77
78// Build FlatBuffer for Addition Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000079void SerializerVisitor::VisitAdditionLayer(const IConnectableLayer* layer, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000080{
81 // Create FlatBuffer BaseLayer
82 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
83
84 // Create the FlatBuffer AdditionLayer
85 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
86
87 // Add the AnyLayer to the FlatBufferLayers
88 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
89}
90
Sadik Armagan5f450272019-02-12 14:31:45 +000091// Build FlatBuffer for Multiplication Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000092void SerializerVisitor::VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name)
Sadik Armagan5f450272019-02-12 14:31:45 +000093{
94 // Create FlatBuffer BaseLayer
95 auto flatBufferMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
96
97 // Create the FlatBuffer MultiplicationLayer
98 auto flatBufferMultiplicationLayer =
99 serializer::CreateMultiplicationLayer(m_flatBufferBuilder, flatBufferMultiplicationBaseLayer);
100
101 // Add the AnyLayer to the FlatBufferLayers
102 CreateAnyLayer(flatBufferMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
103}
104
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000105fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
106 const serializer::LayerType layerType)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000107{
108 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
109 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
110
111 return serializer::CreateLayerBase(m_flatBufferBuilder,
112 layer->GetGuid(),
113 m_flatBufferBuilder.CreateString(layer->GetName()),
114 layerType,
115 m_flatBufferBuilder.CreateVector(inputSlots),
116 m_flatBufferBuilder.CreateVector(outputSlots));
117}
118
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000119void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000120{
121 auto anyLayer = armnn::armnnSerializer::CreateAnyLayer(m_flatBufferBuilder,
122 serializerLayer,
123 layer);
124 m_serializedLayers.push_back(anyLayer);
125}
126
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000127std::vector<fb::Offset<serializer::InputSlot>> SerializerVisitor::CreateInputSlots(const IConnectableLayer* layer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000128{
129 std::vector<fb::Offset <serializer::InputSlot>> inputSlots;
130
131 // Get the InputSlots
132 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
133 {
134 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
135
136 // Get the Connection for the InputSlot
137 const IOutputSlot* connection = inputSlot.GetConnection();
138
139 // Create FlatBuffer Connection
140 serializer::Connection conn(connection->GetOwningLayerGuid(), connection->CalculateIndexOnOwner());
141 // Create FlatBuffer InputSlot
142 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
143 }
144 return inputSlots;
145}
146
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000147std::vector<fb::Offset<serializer::OutputSlot>> SerializerVisitor::CreateOutputSlots(const IConnectableLayer* layer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000148{
149 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
150
151 // Get the OutputSlots
152 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
153 {
154 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
155 const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
156
157 // Get the dimensions
158 std::vector<unsigned int> shape;
159 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
160 {
161 shape.push_back(tensorInfo.GetShape()[dim]);
162 }
163
164 // Create FlatBuffer TensorInfo
165 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
166 m_flatBufferBuilder.CreateVector(shape),
167 GetFlatBufferDataType(tensorInfo.GetDataType()),
168 tensorInfo.GetQuantizationScale(),
169 tensorInfo.GetQuantizationOffset());
170
171 // Create FlatBuffer Outputslot
172 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
173 slotIndex,
174 flatBufferTensorInfo));
175 }
176 return outputSlots;
177}
178
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000179
180ISerializer* ISerializer::CreateRaw()
181{
182 return new Serializer();
183}
184
185ISerializerPtr ISerializer::Create()
186{
187 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
188}
189
190void ISerializer::Destroy(ISerializer* serializer)
191{
192 delete serializer;
193}
194
195void Serializer::Serialize(const INetwork& inNetwork)
196{
197 // Iterate through to network
198 inNetwork.Accept(m_SerializerVisitor);
199 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
200
201 // Create FlatBuffer SerializedGraph
202 auto serializedGraph = serializer::CreateSerializedGraph(
203 fbBuilder,
204 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
205 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
206 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
207
208 // Serialize the graph
209 fbBuilder.Finish(serializedGraph);
210}
211
212bool Serializer::SaveSerializedToStream(std::ostream& stream)
213{
214 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
215
216 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), fbBuilder.GetSize());
217 return !stream.bad();
218}
219
} // namespace armnnSerializer