blob: eb7bccaa1d34ec9de49917d5c89669913b52a5c1 [file] [log] [blame]
Kevin May43a799c2019-02-08 16:31:42 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "DeserializeParser.hpp"
7
8#include <armnn/ArmNN.hpp>
9#include <armnn/Exceptions.hpp>
10
11#include <ParserHelper.hpp>
12#include <Permute.hpp>
13#include <VerificationHelpers.hpp>
14
15#include <boost/filesystem.hpp>
16#include <boost/format.hpp>
17#include <boost/core/ignore_unused.hpp>
18#include <boost/assert.hpp>
19#include <boost/format.hpp>
20#include <boost/log/trivial.hpp>
21
22// The generated code based on the Serialize schema:
23#include <Schema_generated.h>
24
25#include <fstream>
26
27using armnn::ParseException;
28using namespace armnn;
29using namespace armnn::armnnSerializer;
30
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000031namespace armnnDeserializeParser
32{
Kevin May43a799c2019-02-08 16:31:42 +000033
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000034namespace
35{
Kevin May43a799c2019-02-08 16:31:42 +000036
// Sentinel layer index (the largest uint32_t). CheckLayers accepts it even
// though it is never a valid position in the graph's layers vector.
constexpr uint32_t VIRTUAL_LAYER_ID = std::numeric_limits<uint32_t>::max();
38
39 void CheckGraph(const DeserializeParser::GraphPtr& graph,
40 unsigned int layersIndex,
41 const CheckLocation& location)
42{
43 if (graph->layers() == nullptr)
44 {
45 throw ParseException(
46 boost::str(
47 boost::format("%1% was called with invalid (null) graph. "
48 "Possible reason is that the graph is not yet loaded and Unpack(ed). "
49 "layers:%2% at %3%") %
50 location.m_Function %
51 layersIndex %
52 location.FileLine()));
53 }
54 else if (layersIndex >= graph->layers()->size())
55 {
56 throw ParseException(
57 boost::str(
58 boost::format("%1% was called with an invalid layers index. "
59 "layers:%2% at %3%") %
60 location.m_Function %
61 layersIndex %
62 location.FileLine()));
63 }
64}
65
66void CheckLayers(const DeserializeParser::GraphPtr& graph,
67 unsigned int layersIndex,
68 unsigned int layerIndex,
69 const CheckLocation& location)
70{
71 if (graph->layers() == nullptr)
72 {
73 throw ParseException(
74 boost::str(
75 boost::format("%1% was called with invalid (null) graph. "
76 "Possible reason is that the graph is not yet loaded and Unpack(ed). "
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000077 "layers:%2% at %3%") %
Kevin May43a799c2019-02-08 16:31:42 +000078 location.m_Function %
79 layersIndex %
80 location.FileLine()));
81 }
82 else if (layersIndex >= graph->layers()->size())
83 {
84 throw ParseException(
85 boost::str(
86 boost::format("%1% was called with an invalid layers index. "
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000087 "layers:%2% at %3%") %
Kevin May43a799c2019-02-08 16:31:42 +000088 location.m_Function %
89 layersIndex %
90 location.FileLine()));
91 }
92 else if (layerIndex >= graph->layers()[layersIndex].size()
93 && layerIndex != VIRTUAL_LAYER_ID)
94 {
95 throw ParseException(
96 boost::str(
97 boost::format("%1% was called with an invalid layer index. "
98 "layers:%2% layer:%3% at %4%") %
99 location.m_Function %
100 layersIndex %
101 layerIndex %
102 location.FileLine()));
103 }
104}
105
106void CheckTensorPtr(DeserializeParser::TensorRawPtr rawPtr,
107 const CheckLocation& location)
108{
109 if (rawPtr == nullptr)
110 {
111 throw ParseException(
112 boost::str(
113 boost::format("%1% was called with a null tensor pointer. "
114 "at %2%") %
115 location.m_Function %
116 location.FileLine()));
117
118 }
119}
120
// Convenience wrappers around CheckGraph / CheckLayers / CheckTensorPtr that
// pass a CHECK_LOCATION() built at the macro's point of use.
#define CHECK_GRAPH(GRAPH, LAYERS_INDEX) \
    CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION())

#define CHECK_LAYERS(GRAPH, LAYERS_INDEX, LAYER_INDEX) \
    CheckLayers(GRAPH, LAYERS_INDEX, LAYER_INDEX, CHECK_LOCATION())

#define CHECK_TENSOR_PTR(TENSOR_PTR) \
    CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION())
129}
130
131DeserializeParser::DeserializeParser()
132: m_Network(nullptr, nullptr),
133//May require LayerType_Max to be included
134m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer)
135{
136 // register supported layers
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +0000137 m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd;
138 m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication;
139 m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax;
Kevin May43a799c2019-02-08 16:31:42 +0000140}
141
142DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex)
143{
144 auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
145
146 switch(layerType)
147 {
148 case Layer::Layer_AdditionLayer:
149 return graphPtr->layers()->Get(layerIndex)->layer_as_AdditionLayer()->base();
150 case Layer::Layer_InputLayer:
151 return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
Sadik Armagan5f450272019-02-12 14:31:45 +0000152 case Layer::Layer_MultiplicationLayer:
153 return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base();
Kevin May43a799c2019-02-08 16:31:42 +0000154 case Layer::Layer_OutputLayer:
155 return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +0000156 case Layer::Layer_SoftmaxLayer:
157 return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
Kevin May43a799c2019-02-08 16:31:42 +0000158 case Layer::Layer_NONE:
159 default:
160 throw ParseException(boost::str(
161 boost::format("Layer must have a type %1%") %
162 Layer::Layer_NONE));
163 }
164}
165
166int32_t DeserializeParser::GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex)
167{
168 auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
169
170 if (layerType == Layer::Layer_InputLayer)
171 {
172 return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->layerBindingId();
173 }
174 else if ( layerType == Layer::Layer_OutputLayer )
175 {
176 return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->layerBindingId();
177 }
178 return 0;
179}
180
181armnn::TensorInfo ToTensorInfo(DeserializeParser::TensorRawPtr tensorPtr)
182{
183 armnn::DataType type;
184 CHECK_TENSOR_PTR(tensorPtr);
185
186 switch (tensorPtr->dataType())
187 {
188 case DataType_QuantisedAsymm8:
189 type = armnn::DataType::QuantisedAsymm8;
190 break;
191 case DataType_Float32:
192 type = armnn::DataType::Float32;
193 break;
194 case DataType_Float16:
195 type = armnn::DataType::Float16;
196 break;
197 case DataType_Boolean:
198 type = armnn::DataType::Boolean;
199 break;
200 default:
201 {
202 CheckLocation location = CHECK_LOCATION();
203 throw ParseException(
204 boost::str(
205 boost::format("Unsupported data type %1% = %2%. %3%") %
206 tensorPtr->dataType() %
207 EnumNameDataType(tensorPtr->dataType()) %
208 location.AsString()));
209 }
210 }
211 float quantizationScale = tensorPtr->quantizationScale();
212 int32_t quantizationOffset = tensorPtr->quantizationOffset();
213
214 auto dimensions = tensorPtr->dimensions();
215 unsigned int size = dimensions->size();
216 std::vector<unsigned int> outputDims(dimensions->begin(), dimensions->begin() + size);
217
218 // two statements (on purpose) for easier debugging:
219 armnn::TensorInfo result(size,
220 outputDims.data(),
221 type,
222 quantizationScale,
223 quantizationOffset);
224 return result;
225}
226
227DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphInputs(const GraphPtr& graphPtr)
228{
229
230 CHECK_GRAPH(graphPtr, 0);
231 const auto& numInputs = graphPtr->inputIds()->size();
232
233 LayerBaseRawPtrVector result(numInputs);
234
235 for (unsigned int i=0; i<numInputs; ++i)
236 {
Mike Kelly8c1701a2019-02-11 17:01:27 +0000237 uint32_t inputId = graphPtr->inputIds()->Get(i);
Kevin May43a799c2019-02-08 16:31:42 +0000238 result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(inputId));
239 }
240 return result;
241}
242
243DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(const GraphPtr& graphPtr)
244{
245 CHECK_GRAPH(graphPtr, 0);
246 const auto& numOutputs = graphPtr->outputIds()->size();
247
248 LayerBaseRawPtrVector result(numOutputs);
249
250 for (unsigned int i=0; i<numOutputs; ++i)
251 {
Mike Kelly8c1701a2019-02-11 17:01:27 +0000252 uint32_t outputId = graphPtr->outputIds()->Get(i);
Kevin May43a799c2019-02-08 16:31:42 +0000253 result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId));
254 }
255 return result;
256}
257
258DeserializeParser::TensorRawPtrVector DeserializeParser::GetInputs(const GraphPtr& graphPtr,
259 unsigned int layerIndex)
260{
261 CHECK_LAYERS(graphPtr, 0, layerIndex);
262 auto layer = GetBaseLayer(graphPtr, layerIndex);
263 const auto& numInputs = layer->inputSlots()->size();
264
265 TensorRawPtrVector result(numInputs);
266
267 for (unsigned int i=0; i<numInputs; ++i)
268 {
269 auto inputId = CHECKED_NON_NEGATIVE(static_cast<int32_t>
270 (layer->inputSlots()->Get(i)->connection()->sourceLayerIndex()));
271 result[i] = GetBaseLayer(graphPtr, inputId)->outputSlots()->Get(0)->tensorInfo();
272 }
273 return result;
274}
275
276DeserializeParser::TensorRawPtrVector DeserializeParser::GetOutputs(const GraphPtr& graphPtr,
277 unsigned int layerIndex)
278{
279 CHECK_LAYERS(graphPtr, 0, layerIndex);
280 auto layer = GetBaseLayer(graphPtr, layerIndex);
281 const auto& numOutputs = layer->outputSlots()->size();
282
283 TensorRawPtrVector result(numOutputs);
284
285 for (unsigned int i=0; i<numOutputs; ++i)
286 {
287 result[i] = layer->outputSlots()->Get(i)->tensorInfo();
288 }
289 return result;
290}
291
292void DeserializeParser::ParseUnsupportedLayer(unsigned int layerIndex)
293{
294 CHECK_LAYERS(m_Graph, 0, layerIndex);
295 const auto layerName = GetBaseLayer(m_Graph, layerIndex)->layerName()->c_str();
296 throw ParseException(
297 boost::str(
298 boost::format("Layer not supported. "
299 "layerIndex: %1% "
300 "layerName: %2% / %3%") %
301 layerIndex %
302 layerName %
303 CHECK_LOCATION().AsString()));
304}
305
306void DeserializeParser::ResetParser()
307{
308 m_Network = armnn::INetworkPtr(nullptr, nullptr);
309 m_Graph = nullptr;
310}
311
312IDeserializeParser* IDeserializeParser::CreateRaw()
313{
314 return new DeserializeParser();
315}
316
317IDeserializeParserPtr IDeserializeParser::Create()
318{
319 return IDeserializeParserPtr(CreateRaw(), &IDeserializeParser::Destroy);
320}
321
322void IDeserializeParser::Destroy(IDeserializeParser* parser)
323{
324 delete parser;
325}
326
327INetworkPtr DeserializeParser::CreateNetworkFromBinaryFile(const char* graphFile)
328{
329 ResetParser();
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +0000330 m_Graph = LoadGraphFromFile(graphFile, m_FileContent);
Kevin May43a799c2019-02-08 16:31:42 +0000331 return CreateNetworkFromGraph();
332}
333
334INetworkPtr DeserializeParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
335{
336 ResetParser();
337 m_Graph = LoadGraphFromBinary(binaryContent.data(), binaryContent.size());
338 return CreateNetworkFromGraph();
339}
340
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +0000341DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromFile(const char* fileName, std::string& fileContent)
Kevin May43a799c2019-02-08 16:31:42 +0000342{
343 if (fileName == nullptr)
344 {
345 throw InvalidArgumentException(boost::str(boost::format("Invalid (null) file name %1%") %
346 CHECK_LOCATION().AsString()));
347 }
348 boost::system::error_code errorCode;
349 boost::filesystem::path pathToFile(fileName);
350 if (!boost::filesystem::exists(pathToFile, errorCode))
351 {
352 throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
353 fileName %
354 errorCode %
355 CHECK_LOCATION().AsString()));
356 }
357 std::ifstream file(fileName, std::ios::binary);
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +0000358 fileContent = std::string((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
359
Kevin May43a799c2019-02-08 16:31:42 +0000360 return LoadGraphFromBinary(reinterpret_cast<const uint8_t*>(fileContent.c_str()), fileContent.size());
361}
362
363DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(const uint8_t* binaryContent, size_t len)
364{
365 if (binaryContent == nullptr)
366 {
367 throw InvalidArgumentException(boost::str(boost::format("Invalid (null) binary content %1%") %
368 CHECK_LOCATION().AsString()));
369 }
370 flatbuffers::Verifier verifier(binaryContent, len);
371 if (verifier.VerifyBuffer<SerializedGraph>() == false)
372 {
373 throw ParseException(
374 boost::str(boost::format("Buffer doesn't conform to the expected Armnn "
375 "flatbuffers format. size:%1% %2%") %
376 len %
377 CHECK_LOCATION().AsString()));
378 }
379 return GetSerializedGraph(binaryContent);
380}
381
382INetworkPtr DeserializeParser::CreateNetworkFromGraph()
383{
384 m_Network = INetwork::Create();
385 BOOST_ASSERT(m_Graph != nullptr);
386 unsigned int layerIndex = 0;
387 m_GraphConnections.emplace_back(m_Graph->layers()->size());
388 for (AnyLayer const* layer : *m_Graph->layers())
389 {
390 if (layer->layer_type() != Layer_InputLayer &&
391 layer->layer_type() != Layer_OutputLayer)
392 {
393 // lookup and call the parser function
394 auto& parserFunction = m_ParserFunctions[layer->layer_type()];
395 (this->*parserFunction)(layerIndex);
396 }
397 ++layerIndex;
398 }
399
400 SetupInputLayers();
401 SetupOutputLayers();
402
403 // establish the connections from the layer outputs to the inputs of the subsequent layers
404 for (size_t connectionIndex = 0; connectionIndex < m_GraphConnections[0].size(); ++connectionIndex)
405 {
406 if (m_GraphConnections[0][connectionIndex].outputSlot != nullptr)
407 {
408 for (size_t inputSlotIdx = 0;
409 inputSlotIdx < m_GraphConnections[0][connectionIndex].inputSlots.size();
410 ++inputSlotIdx)
411 {
412 m_GraphConnections[0][connectionIndex].outputSlot->Connect(
413 *(m_GraphConnections[0][connectionIndex].inputSlots[inputSlotIdx]));
414 }
415 }
416 }
417
418 return std::move(m_Network);
419}
420
421BindingPointInfo DeserializeParser::GetNetworkInputBindingInfo(unsigned int layerIndex,
422 const std::string& name) const
423{
424 CHECK_LAYERS(m_Graph, 0, layerIndex);
425 auto inputs = GetGraphInputs(m_Graph);
426
427 for (auto const& input : inputs)
428 {
429 if (input->layerName()->c_str() == name)
430 {
431 int bindingId = reinterpret_cast<armnn::LayerBindingId>(GetBindingLayerInfo(m_Graph, input->index()));
432 auto layerBase = GetBaseLayer(m_Graph,input->index())->outputSlots()->Get(layerIndex);
433 return std::make_pair(bindingId, ToTensorInfo(layerBase->tensorInfo()));
434 }
435 }
436 throw ParseException(
437 boost::str(
438 boost::format("No input binding found for layer:%1% / %2%") %
439 name %
440 CHECK_LOCATION().AsString()));
441}
442
443BindingPointInfo DeserializeParser::GetNetworkOutputBindingInfo(unsigned int layerIndex,
444 const std::string& name) const
445{
446 CHECK_LAYERS(m_Graph, 0, layerIndex);
447 auto outputs = GetGraphOutputs(m_Graph);
448
449 for (auto const& output : outputs)
450 {
451 if (output->layerName()->c_str() == name)
452 {
453 int bindingId = reinterpret_cast<armnn::LayerBindingId>(GetBindingLayerInfo(m_Graph, output->index()));
454 auto layer = GetBaseLayer(m_Graph, output->index());
455 auto sourceLayerIndex = layer->inputSlots()->Get(0)->connection()->sourceLayerIndex();
456 auto sourceLayer = GetBaseLayer(m_Graph, sourceLayerIndex);
457 return std::make_pair(bindingId, ToTensorInfo(sourceLayer->outputSlots()->Get(0)->tensorInfo()));
458 }
459 }
460 throw ParseException(
461 boost::str(
462 boost::format("No output binding found for layer:%1% / %2%") %
463 name %
464 CHECK_LOCATION().AsString()));
465}
466
467void DeserializeParser::SetupInputLayers()
468{
469 CHECK_GRAPH(m_Graph, 0);
470 auto inputs = GetGraphInputs(m_Graph);
471 for (auto const& input : inputs)
472 {
473 IConnectableLayer* layer =
Saoirse Stewart3fcef202019-02-14 14:57:37 +0000474 m_Network->AddInputLayer(GetBindingLayerInfo(m_Graph, input->index()), input->layerName()->c_str());
Kevin May43a799c2019-02-08 16:31:42 +0000475
476 auto tensorInfo = ToTensorInfo(input->outputSlots()->Get(0)->tensorInfo());
477 layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
478
479 RegisterOutputSlots(input->index(), layer);
480 }
481}
482
483void DeserializeParser::SetupOutputLayers()
484{
485 CHECK_GRAPH(m_Graph, 0);
486 auto outputs = GetGraphOutputs(m_Graph);
487 for (auto const& output : outputs)
488 {
489 IConnectableLayer* layer =
Saoirse Stewart3fcef202019-02-14 14:57:37 +0000490 m_Network->AddOutputLayer(GetBindingLayerInfo(m_Graph, output->index()), output->layerName()->c_str());
Kevin May43a799c2019-02-08 16:31:42 +0000491
492 RegisterInputSlots(output->index(), layer);
493 }
494}
495
496void DeserializeParser::RegisterOutputSlots(uint32_t layerIndex,
497 IConnectableLayer* layer)
498{
499 CHECK_LAYERS(m_Graph, 0, layerIndex);
500 BOOST_ASSERT(layer != nullptr);
501 auto parsedLayer = GetBaseLayer(m_Graph, layerIndex);
502 if (parsedLayer->outputSlots()->size() != layer->GetNumOutputSlots())
503 {
504 throw ParseException(
505 boost::str(boost::format("The number of outputslots (%1%) does not match the number expected (%2%)"
506 " for layer index: %3% %4%") %
507 parsedLayer->outputSlots()->size() %
508 layer->GetNumOutputSlots() %
509 layerIndex %
510 CHECK_LOCATION().AsString()));
511 }
512
513 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
514 {
515 armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));
516 RegisterOutputSlotOfConnection(layerIndex, slot);
517 }
518}
519
520void DeserializeParser::RegisterInputSlots(uint32_t layerIndex,
521 armnn::IConnectableLayer* layer)
522{
523 CHECK_LAYERS(m_Graph, 0, layerIndex);
524 BOOST_ASSERT(layer != nullptr);
525 auto parsedLayer = GetBaseLayer(m_Graph, layerIndex);
526 if (parsedLayer->inputSlots()->size() != layer->GetNumInputSlots())
527 {
528 throw ParseException(
529 boost::str(boost::format("The number of inputslots (%1%) does not match the number expected (%2%)"
530 " for layer index:%3% %4%") %
531 parsedLayer->inputSlots()->size() %
532 layer->GetNumInputSlots() %
533 layerIndex %
534 CHECK_LOCATION().AsString()));
535 }
536
537 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
538 {
539 armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
540 uint32_t sourceLayerIndex = parsedLayer->inputSlots()->Get(slotIndex)->connection()->sourceLayerIndex();
541 RegisterInputSlotOfConnection(sourceLayerIndex, slot);
542 }
543}
544
545void DeserializeParser::RegisterInputSlotOfConnection(uint32_t connectionIndex,
546 armnn::IInputSlot* slot)
547{
548 BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
549
550 Slots& slots = m_GraphConnections[0][connectionIndex];
551 slots.inputSlots.push_back(slot);
552}
553
554void DeserializeParser::RegisterOutputSlotOfConnection(uint32_t connectionIndex,
555 armnn::IOutputSlot* slot)
556{
557 BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
558
559 Slots& slots = m_GraphConnections[0][connectionIndex];
560
561 // assuming there is only one producer for that tensor
562 if (slots.outputSlot != nullptr)
563 {
564 throw ParseException(boost::str(
565 boost::format("Another layer has already registered itself as the producer of "
566 "connection:%1% / %2%") %
567 connectionIndex %
568 CHECK_LOCATION().AsString()));
569 }
570
571 slots.outputSlot = slot;
572}
573
574void DeserializeParser::ParseAdd(unsigned int layerIndex)
575{
576 CHECK_LAYERS(m_Graph, 0, layerIndex);
577 auto inputs = GetInputs(m_Graph, layerIndex);
578 CHECK_LOCATION();
579 CHECK_VALID_SIZE(inputs.size(), 2);
580
581 auto outputs = GetOutputs(m_Graph, layerIndex);
582 CHECK_VALID_SIZE(outputs.size(), 1);
583
584 auto layerName = boost::str(boost::format("Addition:%1%") % layerIndex);
585 IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str());
586
587 armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
588 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
589
590 RegisterInputSlots(layerIndex, layer);
591 RegisterOutputSlots(layerIndex, layer);
592}
593
Sadik Armagan5f450272019-02-12 14:31:45 +0000594void DeserializeParser::ParseMultiplication(unsigned int layerIndex)
595{
596 CHECK_LAYERS(m_Graph, 0, layerIndex);
597 auto inputs = GetInputs(m_Graph, layerIndex);
598 CHECK_LOCATION();
599 CHECK_VALID_SIZE(inputs.size(), 2);
600
601 auto outputs = GetOutputs(m_Graph, layerIndex);
602 CHECK_VALID_SIZE(outputs.size(), 1);
603
604 auto layerName = boost::str(boost::format("Multiplication:%1%") % layerIndex);
605 IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str());
606
607 armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
608 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
609
610 RegisterInputSlots(layerIndex, layer);
611 RegisterOutputSlots(layerIndex, layer);
612}
613
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +0000614void DeserializeParser::ParseSoftmax(unsigned int layerIndex)
615{
616 CHECK_LAYERS(m_Graph, 0, layerIndex);
617
618 DeserializeParser::TensorRawPtrVector inputs = GetInputs(m_Graph, layerIndex);
619 CHECK_VALID_SIZE(inputs.size(), 1);
620
621 DeserializeParser::TensorRawPtrVector outputs = GetOutputs(m_Graph, layerIndex);
622 CHECK_VALID_SIZE(outputs.size(), 1);
623
624 armnn::SoftmaxDescriptor descriptor;
625 descriptor.m_Beta = m_Graph->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->descriptor()->beta();
626
627 const std::string layerName = boost::str(boost::format("Softmax:%1%") % layerIndex);
628 IConnectableLayer* layer = m_Network->AddSoftmaxLayer(descriptor, layerName.c_str());
629
630 armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
631 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
632
633 RegisterInputSlots(layerIndex, layer);
634 RegisterOutputSlots(layerIndex, layer);
Kevin May43a799c2019-02-08 16:31:42 +0000635}
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +0000636
637} // namespace armnnDeserializeParser