blob: 57baf0e28c7372f2ea6f2c3386372fc3cb2585cc [file] [log] [blame]
Mike Kelly8c1701a2019-02-11 17:01:27 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "Serializer.hpp"
7#include <armnn/ArmNN.hpp>
8#include <iostream>
9#include <Schema_generated.h>
10#include <flatbuffers/util.h>
11
12using namespace armnn;
13namespace fb = flatbuffers;
14namespace serializer = armnn::armnnSerializer;
15
16namespace armnnSerializer
17{
18
19serializer::DataType GetFlatBufferDataType(DataType dataType)
20{
21 switch (dataType)
22 {
23 case DataType::Float32:
24 return serializer::DataType::DataType_Float32;
25 case DataType::Float16:
26 return serializer::DataType::DataType_Float16;
27 case DataType::Signed32:
28 return serializer::DataType::DataType_Signed32;
29 case DataType::QuantisedAsymm8:
30 return serializer::DataType::DataType_QuantisedAsymm8;
31 case DataType::Boolean:
32 return serializer::DataType::DataType_Boolean;
33 default:
34 return serializer::DataType::DataType_Float16;
35 }
36}
37
38// Build FlatBuffer for Input Layer
39void Serializer::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
40{
41 // Create FlatBuffer BaseLayer
42 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
43
44 // Create FlatBuffer BindableBaseLayer
45 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
46 flatBufferInputBaseLayer,
47 id);
48
49 // Push layer Guid to outputIds.
50 m_inputIds.push_back(layer->GetGuid());
51
52 // Create the FlatBuffer InputLayer
53 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
54
55 // Add the AnyLayer to the FlatBufferLayers
56 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
57}
58
59// Build FlatBuffer for Output Layer
60void Serializer::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
61{
62 // Create FlatBuffer BaseLayer
63 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
64
65 // Create FlatBuffer BindableBaseLayer
66 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
67 flatBufferOutputBaseLayer,
68 id);
69 // Push layer Guid to outputIds.
70 m_outputIds.push_back(layer->GetGuid());
71
72 // Create the FlatBuffer OutputLayer
73 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
74 // Add the AnyLayer to the FlatBufferLayers
75 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
76}
77
78// Build FlatBuffer for Addition Layer
79void Serializer::VisitAdditionLayer(const IConnectableLayer* layer, const char* name)
80{
81 // Create FlatBuffer BaseLayer
82 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
83
84 // Create the FlatBuffer AdditionLayer
85 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
86
87 // Add the AnyLayer to the FlatBufferLayers
88 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
89}
90
91void Serializer::Serialize(const INetwork& inNetwork)
92{
93 // Iterate through to network
94 inNetwork.Accept(*this);
95
96 // Create FlatBuffer SerializedGraph
97 auto serializedGraph = serializer::CreateSerializedGraph(m_flatBufferBuilder,
98 m_flatBufferBuilder.CreateVector(m_serializedLayers),
99 m_flatBufferBuilder.CreateVector(m_inputIds),
100 m_flatBufferBuilder.CreateVector(m_outputIds));
101
102 // Serialize the graph
103 m_flatBufferBuilder.Finish(serializedGraph);
104}
105
106bool Serializer::SaveSerializedToStream(std::ostream& stream)
107{
108 stream.write(reinterpret_cast<const char*>(m_flatBufferBuilder.GetBufferPointer()), m_flatBufferBuilder.GetSize());
109 return !stream.bad();
110}
111
112fb::Offset<serializer::LayerBase> Serializer::CreateLayerBase(const IConnectableLayer* layer,
113 const serializer::LayerType layerType)
114{
115 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
116 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
117
118 return serializer::CreateLayerBase(m_flatBufferBuilder,
119 layer->GetGuid(),
120 m_flatBufferBuilder.CreateString(layer->GetName()),
121 layerType,
122 m_flatBufferBuilder.CreateVector(inputSlots),
123 m_flatBufferBuilder.CreateVector(outputSlots));
124}
125
126void Serializer::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
127{
128 auto anyLayer = armnn::armnnSerializer::CreateAnyLayer(m_flatBufferBuilder,
129 serializerLayer,
130 layer);
131 m_serializedLayers.push_back(anyLayer);
132}
133
134std::vector<fb::Offset<serializer::InputSlot>> Serializer::CreateInputSlots(const IConnectableLayer* layer)
135{
136 std::vector<fb::Offset <serializer::InputSlot>> inputSlots;
137
138 // Get the InputSlots
139 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
140 {
141 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
142
143 // Get the Connection for the InputSlot
144 const IOutputSlot* connection = inputSlot.GetConnection();
145
146 // Create FlatBuffer Connection
147 serializer::Connection conn(connection->GetOwningLayerGuid(), connection->CalculateIndexOnOwner());
148 // Create FlatBuffer InputSlot
149 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
150 }
151 return inputSlots;
152}
153
154std::vector<fb::Offset<serializer::OutputSlot>> Serializer::CreateOutputSlots(const IConnectableLayer* layer)
155{
156 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
157
158 // Get the OutputSlots
159 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
160 {
161 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
162 const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
163
164 // Get the dimensions
165 std::vector<unsigned int> shape;
166 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
167 {
168 shape.push_back(tensorInfo.GetShape()[dim]);
169 }
170
171 // Create FlatBuffer TensorInfo
172 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
173 m_flatBufferBuilder.CreateVector(shape),
174 GetFlatBufferDataType(tensorInfo.GetDataType()),
175 tensorInfo.GetQuantizationScale(),
176 tensorInfo.GetQuantizationOffset());
177
178 // Create FlatBuffer Outputslot
179 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
180 slotIndex,
181 flatBufferTensorInfo));
182 }
183 return outputSlots;
184}
185
186} //namespace armnnSerializer