blob: b229ae7e3fedba7161def31049c892bc38ca0e09 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Serializer.hpp"

#include "SerializerUtils.hpp"

#include <armnn/ArmNN.hpp>

#include <iostream>
#include <stdexcept>
#include <string>

#include <Schema_generated.h>

#include <flatbuffers/util.h>
17
18using namespace armnn;
19namespace fb = flatbuffers;
20namespace serializer = armnn::armnnSerializer;
21
22namespace armnnSerializer
23{
24
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000025uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
26{
27 std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
28
29 if (m_guidMap.empty())
30 {
31 m_guidMap.insert(guidPair);
32 }
33 else if (m_guidMap.find(guid) == m_guidMap.end())
34 {
35 guidPair.second = ++m_layerId;
36 m_guidMap.insert(guidPair);
37 return m_layerId;
38 }
39 return m_layerId;
40}
41
Mike Kelly8c1701a2019-02-11 17:01:27 +000042// Build FlatBuffer for Input Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000043void SerializerVisitor::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000044{
45 // Create FlatBuffer BaseLayer
46 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
47
48 // Create FlatBuffer BindableBaseLayer
49 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
50 flatBufferInputBaseLayer,
51 id);
Mike Kelly8c1701a2019-02-11 17:01:27 +000052 // Push layer Guid to outputIds.
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000053 m_inputIds.push_back(GetSerializedId(layer->GetGuid()));
Mike Kelly8c1701a2019-02-11 17:01:27 +000054
55 // Create the FlatBuffer InputLayer
56 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
57
58 // Add the AnyLayer to the FlatBufferLayers
59 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
60}
61
62// Build FlatBuffer for Output Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000063void SerializerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000064{
65 // Create FlatBuffer BaseLayer
66 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
67
68 // Create FlatBuffer BindableBaseLayer
69 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
70 flatBufferOutputBaseLayer,
71 id);
72 // Push layer Guid to outputIds.
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000073 m_outputIds.push_back(GetSerializedId(layer->GetGuid()));
Mike Kelly8c1701a2019-02-11 17:01:27 +000074
75 // Create the FlatBuffer OutputLayer
76 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
77 // Add the AnyLayer to the FlatBufferLayers
78 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
79}
80
81// Build FlatBuffer for Addition Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000082void SerializerVisitor::VisitAdditionLayer(const IConnectableLayer* layer, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000083{
84 // Create FlatBuffer BaseLayer
85 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
86
87 // Create the FlatBuffer AdditionLayer
88 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
89
90 // Add the AnyLayer to the FlatBufferLayers
91 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
92}
93
Sadik Armagan5f450272019-02-12 14:31:45 +000094// Build FlatBuffer for Multiplication Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000095void SerializerVisitor::VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name)
Sadik Armagan5f450272019-02-12 14:31:45 +000096{
97 // Create FlatBuffer BaseLayer
98 auto flatBufferMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
99
100 // Create the FlatBuffer MultiplicationLayer
101 auto flatBufferMultiplicationLayer =
102 serializer::CreateMultiplicationLayer(m_flatBufferBuilder, flatBufferMultiplicationBaseLayer);
103
104 // Add the AnyLayer to the FlatBufferLayers
105 CreateAnyLayer(flatBufferMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
106}
107
Saoirse Stewart263829c2019-02-19 15:54:14 +0000108// Build FlatBuffer for Reshape Layer
109void SerializerVisitor::VisitReshapeLayer(const IConnectableLayer* layer,
110 const armnn::ReshapeDescriptor& reshapeDescriptor,
111 const char* name)
112{
113 // Create FlatBuffer BaseLayer
114 auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
115
116 std::vector<unsigned int> targetShape;
117 for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
118 {
119 targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
120 }
121
122 auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
123 m_flatBufferBuilder.CreateVector(targetShape));
124
125 // Create the FlatBuffer ReshapeLayer
126 auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
127 flatBufferReshapeDesc);
128
129 // Add the AnyLayer to the FlatBufferLayers
130 CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
131}
132
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +0000133// Build FlatBuffer for Softmax Layer
134void SerializerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer,
135 const SoftmaxDescriptor& softmaxDescriptor,
136 const char* name)
137{
138 // Create FlatBuffer BaseLayer
139 auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
140
141 // Create the FlatBuffer SoftmaxDescriptor
142 auto flatBufferSoftmaxDesc =
143 serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
144
145 // Create the FlatBuffer SoftmaxLayer
146 auto flatBufferSoftmaxLayer =
147 serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
148 flatBufferSoftmaxBaseLayer,
149 flatBufferSoftmaxDesc);
150
151 CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
152}
153
Saoirse Stewart3166c3e2019-02-18 15:24:53 +0000154void SerializerVisitor::VisitPooling2dLayer(const IConnectableLayer* layer,
155 const Pooling2dDescriptor& pooling2dDescriptor,
156 const char* name)
157{
158 auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
159 auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
160 m_flatBufferBuilder,
161 GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
162 pooling2dDescriptor.m_PadLeft,
163 pooling2dDescriptor.m_PadRight,
164 pooling2dDescriptor.m_PadTop,
165 pooling2dDescriptor.m_PadBottom,
166 pooling2dDescriptor.m_PoolWidth,
167 pooling2dDescriptor.m_PoolHeight,
168 pooling2dDescriptor.m_StrideX,
169 pooling2dDescriptor.m_StrideY,
170 GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
171 GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
172 GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
173
174 auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
175 fbPooling2dBaseLayer,
176 fbPooling2dDescriptor);
177
178 CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
179}
180
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000181fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
182 const serializer::LayerType layerType)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000183{
184 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
185 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
186
187 return serializer::CreateLayerBase(m_flatBufferBuilder,
Saoirse Stewartcb8a3212019-02-14 15:46:10 +0000188 GetSerializedId(layer->GetGuid()),
Mike Kelly8c1701a2019-02-11 17:01:27 +0000189 m_flatBufferBuilder.CreateString(layer->GetName()),
190 layerType,
191 m_flatBufferBuilder.CreateVector(inputSlots),
192 m_flatBufferBuilder.CreateVector(outputSlots));
193}
194
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000195void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000196{
197 auto anyLayer = armnn::armnnSerializer::CreateAnyLayer(m_flatBufferBuilder,
198 serializerLayer,
199 layer);
200 m_serializedLayers.push_back(anyLayer);
201}
202
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000203std::vector<fb::Offset<serializer::InputSlot>> SerializerVisitor::CreateInputSlots(const IConnectableLayer* layer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000204{
205 std::vector<fb::Offset <serializer::InputSlot>> inputSlots;
206
207 // Get the InputSlots
208 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
209 {
210 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
211
212 // Get the Connection for the InputSlot
213 const IOutputSlot* connection = inputSlot.GetConnection();
214
215 // Create FlatBuffer Connection
Saoirse Stewartcb8a3212019-02-14 15:46:10 +0000216 serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
217 connection->CalculateIndexOnOwner());
Mike Kelly8c1701a2019-02-11 17:01:27 +0000218 // Create FlatBuffer InputSlot
219 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
220 }
221 return inputSlots;
222}
223
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000224std::vector<fb::Offset<serializer::OutputSlot>> SerializerVisitor::CreateOutputSlots(const IConnectableLayer* layer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000225{
226 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
227
228 // Get the OutputSlots
229 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
230 {
231 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
232 const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
233
234 // Get the dimensions
235 std::vector<unsigned int> shape;
236 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
237 {
238 shape.push_back(tensorInfo.GetShape()[dim]);
239 }
240
241 // Create FlatBuffer TensorInfo
242 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
243 m_flatBufferBuilder.CreateVector(shape),
244 GetFlatBufferDataType(tensorInfo.GetDataType()),
245 tensorInfo.GetQuantizationScale(),
246 tensorInfo.GetQuantizationOffset());
247
248 // Create FlatBuffer Outputslot
249 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
250 slotIndex,
251 flatBufferTensorInfo));
252 }
253 return outputSlots;
254}
255
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000256
257ISerializer* ISerializer::CreateRaw()
258{
259 return new Serializer();
260}
261
262ISerializerPtr ISerializer::Create()
263{
264 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
265}
266
267void ISerializer::Destroy(ISerializer* serializer)
268{
269 delete serializer;
270}
271
272void Serializer::Serialize(const INetwork& inNetwork)
273{
274 // Iterate through to network
275 inNetwork.Accept(m_SerializerVisitor);
276 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
277
278 // Create FlatBuffer SerializedGraph
279 auto serializedGraph = serializer::CreateSerializedGraph(
280 fbBuilder,
281 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
282 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
283 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
284
285 // Serialize the graph
286 fbBuilder.Finish(serializedGraph);
287}
288
289bool Serializer::SaveSerializedToStream(std::ostream& stream)
290{
291 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
292
Nattapat Chaimanowong7b53b692019-02-12 14:38:31 +0000293 auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
294 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000295 return !stream.bad();
296}
297
Matteo Martincighec333912019-02-13 15:12:39 +0000298} // namespace armnnSerializer