blob: ba4b36934c03e4842c66abaf8442b314ea1b6922 [file] [log] [blame]
Mike Kelly8c1701a2019-02-11 17:01:27 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Serializer.hpp"
7#include <armnn/ArmNN.hpp>
8#include <iostream>
9#include <Schema_generated.h>
10#include <flatbuffers/util.h>
11
12using namespace armnn;
13namespace fb = flatbuffers;
14namespace serializer = armnn::armnnSerializer;
15
16namespace armnnSerializer
17{
18
19serializer::DataType GetFlatBufferDataType(DataType dataType)
20{
21 switch (dataType)
22 {
23 case DataType::Float32:
24 return serializer::DataType::DataType_Float32;
25 case DataType::Float16:
26 return serializer::DataType::DataType_Float16;
27 case DataType::Signed32:
28 return serializer::DataType::DataType_Signed32;
29 case DataType::QuantisedAsymm8:
30 return serializer::DataType::DataType_QuantisedAsymm8;
31 case DataType::Boolean:
32 return serializer::DataType::DataType_Boolean;
33 default:
34 return serializer::DataType::DataType_Float16;
35 }
36}
37
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000038uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
39{
40 std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
41
42 if (m_guidMap.empty())
43 {
44 m_guidMap.insert(guidPair);
45 }
46 else if (m_guidMap.find(guid) == m_guidMap.end())
47 {
48 guidPair.second = ++m_layerId;
49 m_guidMap.insert(guidPair);
50 return m_layerId;
51 }
52 return m_layerId;
53}
54
Mike Kelly8c1701a2019-02-11 17:01:27 +000055// Build FlatBuffer for Input Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000056void SerializerVisitor::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000057{
58 // Create FlatBuffer BaseLayer
59 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
60
61 // Create FlatBuffer BindableBaseLayer
62 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
63 flatBufferInputBaseLayer,
64 id);
65
66 // Push layer Guid to outputIds.
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000067 m_inputIds.push_back(GetSerializedId(layer->GetGuid()));
Mike Kelly8c1701a2019-02-11 17:01:27 +000068
69 // Create the FlatBuffer InputLayer
70 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
71
72 // Add the AnyLayer to the FlatBufferLayers
73 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
74}
75
76// Build FlatBuffer for Output Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000077void SerializerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000078{
79 // Create FlatBuffer BaseLayer
80 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
81
82 // Create FlatBuffer BindableBaseLayer
83 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
84 flatBufferOutputBaseLayer,
85 id);
86 // Push layer Guid to outputIds.
Saoirse Stewartcb8a3212019-02-14 15:46:10 +000087 m_outputIds.push_back(GetSerializedId(layer->GetGuid()));
Mike Kelly8c1701a2019-02-11 17:01:27 +000088
89 // Create the FlatBuffer OutputLayer
90 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
91 // Add the AnyLayer to the FlatBufferLayers
92 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
93}
94
95// Build FlatBuffer for Addition Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +000096void SerializerVisitor::VisitAdditionLayer(const IConnectableLayer* layer, const char* name)
Mike Kelly8c1701a2019-02-11 17:01:27 +000097{
98 // Create FlatBuffer BaseLayer
99 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
100
101 // Create the FlatBuffer AdditionLayer
102 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
103
104 // Add the AnyLayer to the FlatBufferLayers
105 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
106}
107
Sadik Armagan5f450272019-02-12 14:31:45 +0000108// Build FlatBuffer for Multiplication Layer
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000109void SerializerVisitor::VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name)
Sadik Armagan5f450272019-02-12 14:31:45 +0000110{
111 // Create FlatBuffer BaseLayer
112 auto flatBufferMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
113
114 // Create the FlatBuffer MultiplicationLayer
115 auto flatBufferMultiplicationLayer =
116 serializer::CreateMultiplicationLayer(m_flatBufferBuilder, flatBufferMultiplicationBaseLayer);
117
118 // Add the AnyLayer to the FlatBufferLayers
119 CreateAnyLayer(flatBufferMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
120}
121
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +0000122// Build FlatBuffer for Softmax Layer
123void SerializerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer,
124 const SoftmaxDescriptor& softmaxDescriptor,
125 const char* name)
126{
127 // Create FlatBuffer BaseLayer
128 auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
129
130 // Create the FlatBuffer SoftmaxDescriptor
131 auto flatBufferSoftmaxDesc =
132 serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
133
134 // Create the FlatBuffer SoftmaxLayer
135 auto flatBufferSoftmaxLayer =
136 serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
137 flatBufferSoftmaxBaseLayer,
138 flatBufferSoftmaxDesc);
139
140 CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
141}
142
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000143fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
144 const serializer::LayerType layerType)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000145{
146 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
147 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
148
149 return serializer::CreateLayerBase(m_flatBufferBuilder,
Saoirse Stewartcb8a3212019-02-14 15:46:10 +0000150 GetSerializedId(layer->GetGuid()),
Mike Kelly8c1701a2019-02-11 17:01:27 +0000151 m_flatBufferBuilder.CreateString(layer->GetName()),
152 layerType,
153 m_flatBufferBuilder.CreateVector(inputSlots),
154 m_flatBufferBuilder.CreateVector(outputSlots));
155}
156
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000157void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000158{
159 auto anyLayer = armnn::armnnSerializer::CreateAnyLayer(m_flatBufferBuilder,
160 serializerLayer,
161 layer);
162 m_serializedLayers.push_back(anyLayer);
163}
164
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000165std::vector<fb::Offset<serializer::InputSlot>> SerializerVisitor::CreateInputSlots(const IConnectableLayer* layer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000166{
167 std::vector<fb::Offset <serializer::InputSlot>> inputSlots;
168
169 // Get the InputSlots
170 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
171 {
172 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
173
174 // Get the Connection for the InputSlot
175 const IOutputSlot* connection = inputSlot.GetConnection();
176
177 // Create FlatBuffer Connection
Saoirse Stewartcb8a3212019-02-14 15:46:10 +0000178 serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
179 connection->CalculateIndexOnOwner());
Mike Kelly8c1701a2019-02-11 17:01:27 +0000180 // Create FlatBuffer InputSlot
181 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
182 }
183 return inputSlots;
184}
185
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000186std::vector<fb::Offset<serializer::OutputSlot>> SerializerVisitor::CreateOutputSlots(const IConnectableLayer* layer)
Mike Kelly8c1701a2019-02-11 17:01:27 +0000187{
188 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
189
190 // Get the OutputSlots
191 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
192 {
193 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
194 const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
195
196 // Get the dimensions
197 std::vector<unsigned int> shape;
198 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
199 {
200 shape.push_back(tensorInfo.GetShape()[dim]);
201 }
202
203 // Create FlatBuffer TensorInfo
204 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
205 m_flatBufferBuilder.CreateVector(shape),
206 GetFlatBufferDataType(tensorInfo.GetDataType()),
207 tensorInfo.GetQuantizationScale(),
208 tensorInfo.GetQuantizationOffset());
209
210 // Create FlatBuffer Outputslot
211 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
212 slotIndex,
213 flatBufferTensorInfo));
214 }
215 return outputSlots;
216}
217
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000218
219ISerializer* ISerializer::CreateRaw()
220{
221 return new Serializer();
222}
223
224ISerializerPtr ISerializer::Create()
225{
226 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
227}
228
229void ISerializer::Destroy(ISerializer* serializer)
230{
231 delete serializer;
232}
233
234void Serializer::Serialize(const INetwork& inNetwork)
235{
236 // Iterate through to network
237 inNetwork.Accept(m_SerializerVisitor);
238 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
239
240 // Create FlatBuffer SerializedGraph
241 auto serializedGraph = serializer::CreateSerializedGraph(
242 fbBuilder,
243 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
244 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
245 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
246
247 // Serialize the graph
248 fbBuilder.Finish(serializedGraph);
249}
250
251bool Serializer::SaveSerializedToStream(std::ostream& stream)
252{
253 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
254
Nattapat Chaimanowong7b53b692019-02-12 14:38:31 +0000255 auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
256 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
Nattapat Chaimanowongac9cadc2019-02-13 15:52:41 +0000257 return !stream.bad();
258}
259
Matteo Martincighec333912019-02-13 15:12:39 +0000260} // namespace armnnSerializer