blob: acb672ad1fd887b44bee9552ff06ce221626deb1 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "Serializer.hpp"
#include <armnn/ArmNN.hpp>
#include <iostream>
#include <string>
#include <Schema_generated.h>
#include <flatbuffers/util.h>
11
12using namespace armnn;
13namespace fb = flatbuffers;
14namespace serializer = armnn::armnnSerializer;
15
16namespace armnnSerializer
17{
18
19serializer::DataType GetFlatBufferDataType(DataType dataType)
20{
21 switch (dataType)
22 {
23 case DataType::Float32:
24 return serializer::DataType::DataType_Float32;
25 case DataType::Float16:
26 return serializer::DataType::DataType_Float16;
27 case DataType::Signed32:
28 return serializer::DataType::DataType_Signed32;
29 case DataType::QuantisedAsymm8:
30 return serializer::DataType::DataType_QuantisedAsymm8;
31 case DataType::Boolean:
32 return serializer::DataType::DataType_Boolean;
33 default:
34 return serializer::DataType::DataType_Float16;
35 }
36}
37
38// Build FlatBuffer for Input Layer
39void Serializer::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
40{
41 // Create FlatBuffer BaseLayer
42 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
43
44 // Create FlatBuffer BindableBaseLayer
45 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
46 flatBufferInputBaseLayer,
47 id);
48
49 // Push layer Guid to outputIds.
50 m_inputIds.push_back(layer->GetGuid());
51
52 // Create the FlatBuffer InputLayer
53 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
54
55 // Add the AnyLayer to the FlatBufferLayers
56 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
57}
58
59// Build FlatBuffer for Output Layer
60void Serializer::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
61{
62 // Create FlatBuffer BaseLayer
63 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
64
65 // Create FlatBuffer BindableBaseLayer
66 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
67 flatBufferOutputBaseLayer,
68 id);
69 // Push layer Guid to outputIds.
70 m_outputIds.push_back(layer->GetGuid());
71
72 // Create the FlatBuffer OutputLayer
73 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
74 // Add the AnyLayer to the FlatBufferLayers
75 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
76}
77
78// Build FlatBuffer for Addition Layer
79void Serializer::VisitAdditionLayer(const IConnectableLayer* layer, const char* name)
80{
81 // Create FlatBuffer BaseLayer
82 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
83
84 // Create the FlatBuffer AdditionLayer
85 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
86
87 // Add the AnyLayer to the FlatBufferLayers
88 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
89}
90
Sadik Armagan5f450272019-02-12 14:31:45 +000091// Build FlatBuffer for Multiplication Layer
92void Serializer::VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name)
93{
94 // Create FlatBuffer BaseLayer
95 auto flatBufferMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
96
97 // Create the FlatBuffer MultiplicationLayer
98 auto flatBufferMultiplicationLayer =
99 serializer::CreateMultiplicationLayer(m_flatBufferBuilder, flatBufferMultiplicationBaseLayer);
100
101 // Add the AnyLayer to the FlatBufferLayers
102 CreateAnyLayer(flatBufferMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
103}
104
Mike Kelly8c1701a2019-02-11 17:01:27 +0000105void Serializer::Serialize(const INetwork& inNetwork)
106{
107 // Iterate through to network
108 inNetwork.Accept(*this);
109
110 // Create FlatBuffer SerializedGraph
111 auto serializedGraph = serializer::CreateSerializedGraph(m_flatBufferBuilder,
112 m_flatBufferBuilder.CreateVector(m_serializedLayers),
113 m_flatBufferBuilder.CreateVector(m_inputIds),
114 m_flatBufferBuilder.CreateVector(m_outputIds));
115
116 // Serialize the graph
117 m_flatBufferBuilder.Finish(serializedGraph);
118}
119
120bool Serializer::SaveSerializedToStream(std::ostream& stream)
121{
122 stream.write(reinterpret_cast<const char*>(m_flatBufferBuilder.GetBufferPointer()), m_flatBufferBuilder.GetSize());
123 return !stream.bad();
124}
125
126fb::Offset<serializer::LayerBase> Serializer::CreateLayerBase(const IConnectableLayer* layer,
127 const serializer::LayerType layerType)
128{
129 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
130 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
131
132 return serializer::CreateLayerBase(m_flatBufferBuilder,
133 layer->GetGuid(),
134 m_flatBufferBuilder.CreateString(layer->GetName()),
135 layerType,
136 m_flatBufferBuilder.CreateVector(inputSlots),
137 m_flatBufferBuilder.CreateVector(outputSlots));
138}
139
140void Serializer::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
141{
142 auto anyLayer = armnn::armnnSerializer::CreateAnyLayer(m_flatBufferBuilder,
143 serializerLayer,
144 layer);
145 m_serializedLayers.push_back(anyLayer);
146}
147
148std::vector<fb::Offset<serializer::InputSlot>> Serializer::CreateInputSlots(const IConnectableLayer* layer)
149{
150 std::vector<fb::Offset <serializer::InputSlot>> inputSlots;
151
152 // Get the InputSlots
153 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
154 {
155 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
156
157 // Get the Connection for the InputSlot
158 const IOutputSlot* connection = inputSlot.GetConnection();
159
160 // Create FlatBuffer Connection
161 serializer::Connection conn(connection->GetOwningLayerGuid(), connection->CalculateIndexOnOwner());
162 // Create FlatBuffer InputSlot
163 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
164 }
165 return inputSlots;
166}
167
168std::vector<fb::Offset<serializer::OutputSlot>> Serializer::CreateOutputSlots(const IConnectableLayer* layer)
169{
170 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
171
172 // Get the OutputSlots
173 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
174 {
175 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
176 const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
177
178 // Get the dimensions
179 std::vector<unsigned int> shape;
180 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
181 {
182 shape.push_back(tensorInfo.GetShape()[dim]);
183 }
184
185 // Create FlatBuffer TensorInfo
186 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
187 m_flatBufferBuilder.CreateVector(shape),
188 GetFlatBufferDataType(tensorInfo.GetDataType()),
189 tensorInfo.GetQuantizationScale(),
190 tensorInfo.GetQuantizationOffset());
191
192 // Create FlatBuffer Outputslot
193 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
194 slotIndex,
195 flatBufferTensorInfo));
196 }
197 return outputSlots;
198}
199
200} //namespace armnnSerializer