blob: 6a6a0fafe2edb66919ce11657782e767f1f450db [file] [log] [blame]
Kevin May43a799c2019-02-08 16:31:42 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "DeserializeParser.hpp"
7
8#include <armnn/ArmNN.hpp>
9#include <armnn/Exceptions.hpp>
10
11#include <ParserHelper.hpp>
12#include <Permute.hpp>
13#include <VerificationHelpers.hpp>
14
15#include <boost/filesystem.hpp>
16#include <boost/format.hpp>
17#include <boost/core/ignore_unused.hpp>
18#include <boost/assert.hpp>
19#include <boost/format.hpp>
20#include <boost/log/trivial.hpp>
21
22// The generated code based on the Serialize schema:
23#include <Schema_generated.h>
24
25#include <fstream>
26
27using armnn::ParseException;
28using namespace armnn;
29using namespace armnn::armnnSerializer;
30
31namespace armnnDeserializeParser {
32
33namespace {
34
35const uint32_t VIRTUAL_LAYER_ID = std::numeric_limits<uint32_t>::max();
36
37 void CheckGraph(const DeserializeParser::GraphPtr& graph,
38 unsigned int layersIndex,
39 const CheckLocation& location)
40{
41 if (graph->layers() == nullptr)
42 {
43 throw ParseException(
44 boost::str(
45 boost::format("%1% was called with invalid (null) graph. "
46 "Possible reason is that the graph is not yet loaded and Unpack(ed). "
47 "layers:%2% at %3%") %
48 location.m_Function %
49 layersIndex %
50 location.FileLine()));
51 }
52 else if (layersIndex >= graph->layers()->size())
53 {
54 throw ParseException(
55 boost::str(
56 boost::format("%1% was called with an invalid layers index. "
57 "layers:%2% at %3%") %
58 location.m_Function %
59 layersIndex %
60 location.FileLine()));
61 }
62}
63
64void CheckLayers(const DeserializeParser::GraphPtr& graph,
65 unsigned int layersIndex,
66 unsigned int layerIndex,
67 const CheckLocation& location)
68{
69 if (graph->layers() == nullptr)
70 {
71 throw ParseException(
72 boost::str(
73 boost::format("%1% was called with invalid (null) graph. "
74 "Possible reason is that the graph is not yet loaded and Unpack(ed). "
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000075 "layers:%2% at %3%") %
Kevin May43a799c2019-02-08 16:31:42 +000076 location.m_Function %
77 layersIndex %
78 location.FileLine()));
79 }
80 else if (layersIndex >= graph->layers()->size())
81 {
82 throw ParseException(
83 boost::str(
84 boost::format("%1% was called with an invalid layers index. "
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000085 "layers:%2% at %3%") %
Kevin May43a799c2019-02-08 16:31:42 +000086 location.m_Function %
87 layersIndex %
88 location.FileLine()));
89 }
90 else if (layerIndex >= graph->layers()[layersIndex].size()
91 && layerIndex != VIRTUAL_LAYER_ID)
92 {
93 throw ParseException(
94 boost::str(
95 boost::format("%1% was called with an invalid layer index. "
96 "layers:%2% layer:%3% at %4%") %
97 location.m_Function %
98 layersIndex %
99 layerIndex %
100 location.FileLine()));
101 }
102}
103
104void CheckTensorPtr(DeserializeParser::TensorRawPtr rawPtr,
105 const CheckLocation& location)
106{
107 if (rawPtr == nullptr)
108 {
109 throw ParseException(
110 boost::str(
111 boost::format("%1% was called with a null tensor pointer. "
112 "at %2%") %
113 location.m_Function %
114 location.FileLine()));
115
116 }
117}
118
119#define CHECK_TENSOR_PTR(TENSOR_PTR) \
120 CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION())
121
122#define CHECK_LAYERS(GRAPH, LAYERS_INDEX, LAYER_INDEX) \
123 CheckLayers(GRAPH, LAYERS_INDEX, LAYER_INDEX, CHECK_LOCATION())
124
125#define CHECK_GRAPH(GRAPH, LAYERS_INDEX) \
126 CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION())
127}
128
129DeserializeParser::DeserializeParser()
130: m_Network(nullptr, nullptr),
131//May require LayerType_Max to be included
132m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer)
133{
134 // register supported layers
Sadik Armagan5f450272019-02-12 14:31:45 +0000135 m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd;
136 m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication;
Kevin May43a799c2019-02-08 16:31:42 +0000137}
138
139DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex)
140{
141 auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
142
143 switch(layerType)
144 {
145 case Layer::Layer_AdditionLayer:
146 return graphPtr->layers()->Get(layerIndex)->layer_as_AdditionLayer()->base();
147 case Layer::Layer_InputLayer:
148 return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
Sadik Armagan5f450272019-02-12 14:31:45 +0000149 case Layer::Layer_MultiplicationLayer:
150 return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base();
Kevin May43a799c2019-02-08 16:31:42 +0000151 case Layer::Layer_OutputLayer:
152 return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
153 case Layer::Layer_NONE:
154 default:
155 throw ParseException(boost::str(
156 boost::format("Layer must have a type %1%") %
157 Layer::Layer_NONE));
158 }
159}
160
161int32_t DeserializeParser::GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex)
162{
163 auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
164
165 if (layerType == Layer::Layer_InputLayer)
166 {
167 return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->layerBindingId();
168 }
169 else if ( layerType == Layer::Layer_OutputLayer )
170 {
171 return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->layerBindingId();
172 }
173 return 0;
174}
175
176armnn::TensorInfo ToTensorInfo(DeserializeParser::TensorRawPtr tensorPtr)
177{
178 armnn::DataType type;
179 CHECK_TENSOR_PTR(tensorPtr);
180
181 switch (tensorPtr->dataType())
182 {
183 case DataType_QuantisedAsymm8:
184 type = armnn::DataType::QuantisedAsymm8;
185 break;
186 case DataType_Float32:
187 type = armnn::DataType::Float32;
188 break;
189 case DataType_Float16:
190 type = armnn::DataType::Float16;
191 break;
192 case DataType_Boolean:
193 type = armnn::DataType::Boolean;
194 break;
195 default:
196 {
197 CheckLocation location = CHECK_LOCATION();
198 throw ParseException(
199 boost::str(
200 boost::format("Unsupported data type %1% = %2%. %3%") %
201 tensorPtr->dataType() %
202 EnumNameDataType(tensorPtr->dataType()) %
203 location.AsString()));
204 }
205 }
206 float quantizationScale = tensorPtr->quantizationScale();
207 int32_t quantizationOffset = tensorPtr->quantizationOffset();
208
209 auto dimensions = tensorPtr->dimensions();
210 unsigned int size = dimensions->size();
211 std::vector<unsigned int> outputDims(dimensions->begin(), dimensions->begin() + size);
212
213 // two statements (on purpose) for easier debugging:
214 armnn::TensorInfo result(size,
215 outputDims.data(),
216 type,
217 quantizationScale,
218 quantizationOffset);
219 return result;
220}
221
222DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphInputs(const GraphPtr& graphPtr)
223{
224
225 CHECK_GRAPH(graphPtr, 0);
226 const auto& numInputs = graphPtr->inputIds()->size();
227
228 LayerBaseRawPtrVector result(numInputs);
229
230 for (unsigned int i=0; i<numInputs; ++i)
231 {
Mike Kelly8c1701a2019-02-11 17:01:27 +0000232 uint32_t inputId = graphPtr->inputIds()->Get(i);
Kevin May43a799c2019-02-08 16:31:42 +0000233 result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(inputId));
234 }
235 return result;
236}
237
238DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(const GraphPtr& graphPtr)
239{
240 CHECK_GRAPH(graphPtr, 0);
241 const auto& numOutputs = graphPtr->outputIds()->size();
242
243 LayerBaseRawPtrVector result(numOutputs);
244
245 for (unsigned int i=0; i<numOutputs; ++i)
246 {
Mike Kelly8c1701a2019-02-11 17:01:27 +0000247 uint32_t outputId = graphPtr->outputIds()->Get(i);
Kevin May43a799c2019-02-08 16:31:42 +0000248 result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId));
249 }
250 return result;
251}
252
253DeserializeParser::TensorRawPtrVector DeserializeParser::GetInputs(const GraphPtr& graphPtr,
254 unsigned int layerIndex)
255{
256 CHECK_LAYERS(graphPtr, 0, layerIndex);
257 auto layer = GetBaseLayer(graphPtr, layerIndex);
258 const auto& numInputs = layer->inputSlots()->size();
259
260 TensorRawPtrVector result(numInputs);
261
262 for (unsigned int i=0; i<numInputs; ++i)
263 {
264 auto inputId = CHECKED_NON_NEGATIVE(static_cast<int32_t>
265 (layer->inputSlots()->Get(i)->connection()->sourceLayerIndex()));
266 result[i] = GetBaseLayer(graphPtr, inputId)->outputSlots()->Get(0)->tensorInfo();
267 }
268 return result;
269}
270
271DeserializeParser::TensorRawPtrVector DeserializeParser::GetOutputs(const GraphPtr& graphPtr,
272 unsigned int layerIndex)
273{
274 CHECK_LAYERS(graphPtr, 0, layerIndex);
275 auto layer = GetBaseLayer(graphPtr, layerIndex);
276 const auto& numOutputs = layer->outputSlots()->size();
277
278 TensorRawPtrVector result(numOutputs);
279
280 for (unsigned int i=0; i<numOutputs; ++i)
281 {
282 result[i] = layer->outputSlots()->Get(i)->tensorInfo();
283 }
284 return result;
285}
286
287void DeserializeParser::ParseUnsupportedLayer(unsigned int layerIndex)
288{
289 CHECK_LAYERS(m_Graph, 0, layerIndex);
290 const auto layerName = GetBaseLayer(m_Graph, layerIndex)->layerName()->c_str();
291 throw ParseException(
292 boost::str(
293 boost::format("Layer not supported. "
294 "layerIndex: %1% "
295 "layerName: %2% / %3%") %
296 layerIndex %
297 layerName %
298 CHECK_LOCATION().AsString()));
299}
300
301void DeserializeParser::ResetParser()
302{
303 m_Network = armnn::INetworkPtr(nullptr, nullptr);
304 m_Graph = nullptr;
305}
306
307IDeserializeParser* IDeserializeParser::CreateRaw()
308{
309 return new DeserializeParser();
310}
311
312IDeserializeParserPtr IDeserializeParser::Create()
313{
314 return IDeserializeParserPtr(CreateRaw(), &IDeserializeParser::Destroy);
315}
316
317void IDeserializeParser::Destroy(IDeserializeParser* parser)
318{
319 delete parser;
320}
321
322INetworkPtr DeserializeParser::CreateNetworkFromBinaryFile(const char* graphFile)
323{
324 ResetParser();
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +0000325 m_Graph = LoadGraphFromFile(graphFile, m_FileContent);
Kevin May43a799c2019-02-08 16:31:42 +0000326 return CreateNetworkFromGraph();
327}
328
329INetworkPtr DeserializeParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
330{
331 ResetParser();
332 m_Graph = LoadGraphFromBinary(binaryContent.data(), binaryContent.size());
333 return CreateNetworkFromGraph();
334}
335
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +0000336DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromFile(const char* fileName, std::string& fileContent)
Kevin May43a799c2019-02-08 16:31:42 +0000337{
338 if (fileName == nullptr)
339 {
340 throw InvalidArgumentException(boost::str(boost::format("Invalid (null) file name %1%") %
341 CHECK_LOCATION().AsString()));
342 }
343 boost::system::error_code errorCode;
344 boost::filesystem::path pathToFile(fileName);
345 if (!boost::filesystem::exists(pathToFile, errorCode))
346 {
347 throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
348 fileName %
349 errorCode %
350 CHECK_LOCATION().AsString()));
351 }
352 std::ifstream file(fileName, std::ios::binary);
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +0000353 fileContent = std::string((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
354
Kevin May43a799c2019-02-08 16:31:42 +0000355 return LoadGraphFromBinary(reinterpret_cast<const uint8_t*>(fileContent.c_str()), fileContent.size());
356}
357
358DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(const uint8_t* binaryContent, size_t len)
359{
360 if (binaryContent == nullptr)
361 {
362 throw InvalidArgumentException(boost::str(boost::format("Invalid (null) binary content %1%") %
363 CHECK_LOCATION().AsString()));
364 }
365 flatbuffers::Verifier verifier(binaryContent, len);
366 if (verifier.VerifyBuffer<SerializedGraph>() == false)
367 {
368 throw ParseException(
369 boost::str(boost::format("Buffer doesn't conform to the expected Armnn "
370 "flatbuffers format. size:%1% %2%") %
371 len %
372 CHECK_LOCATION().AsString()));
373 }
374 return GetSerializedGraph(binaryContent);
375}
376
377INetworkPtr DeserializeParser::CreateNetworkFromGraph()
378{
379 m_Network = INetwork::Create();
380 BOOST_ASSERT(m_Graph != nullptr);
381 unsigned int layerIndex = 0;
382 m_GraphConnections.emplace_back(m_Graph->layers()->size());
383 for (AnyLayer const* layer : *m_Graph->layers())
384 {
385 if (layer->layer_type() != Layer_InputLayer &&
386 layer->layer_type() != Layer_OutputLayer)
387 {
388 // lookup and call the parser function
389 auto& parserFunction = m_ParserFunctions[layer->layer_type()];
390 (this->*parserFunction)(layerIndex);
391 }
392 ++layerIndex;
393 }
394
395 SetupInputLayers();
396 SetupOutputLayers();
397
398 // establish the connections from the layer outputs to the inputs of the subsequent layers
399 for (size_t connectionIndex = 0; connectionIndex < m_GraphConnections[0].size(); ++connectionIndex)
400 {
401 if (m_GraphConnections[0][connectionIndex].outputSlot != nullptr)
402 {
403 for (size_t inputSlotIdx = 0;
404 inputSlotIdx < m_GraphConnections[0][connectionIndex].inputSlots.size();
405 ++inputSlotIdx)
406 {
407 m_GraphConnections[0][connectionIndex].outputSlot->Connect(
408 *(m_GraphConnections[0][connectionIndex].inputSlots[inputSlotIdx]));
409 }
410 }
411 }
412
413 return std::move(m_Network);
414}
415
416BindingPointInfo DeserializeParser::GetNetworkInputBindingInfo(unsigned int layerIndex,
417 const std::string& name) const
418{
419 CHECK_LAYERS(m_Graph, 0, layerIndex);
420 auto inputs = GetGraphInputs(m_Graph);
421
422 for (auto const& input : inputs)
423 {
424 if (input->layerName()->c_str() == name)
425 {
426 int bindingId = reinterpret_cast<armnn::LayerBindingId>(GetBindingLayerInfo(m_Graph, input->index()));
427 auto layerBase = GetBaseLayer(m_Graph,input->index())->outputSlots()->Get(layerIndex);
428 return std::make_pair(bindingId, ToTensorInfo(layerBase->tensorInfo()));
429 }
430 }
431 throw ParseException(
432 boost::str(
433 boost::format("No input binding found for layer:%1% / %2%") %
434 name %
435 CHECK_LOCATION().AsString()));
436}
437
438BindingPointInfo DeserializeParser::GetNetworkOutputBindingInfo(unsigned int layerIndex,
439 const std::string& name) const
440{
441 CHECK_LAYERS(m_Graph, 0, layerIndex);
442 auto outputs = GetGraphOutputs(m_Graph);
443
444 for (auto const& output : outputs)
445 {
446 if (output->layerName()->c_str() == name)
447 {
448 int bindingId = reinterpret_cast<armnn::LayerBindingId>(GetBindingLayerInfo(m_Graph, output->index()));
449 auto layer = GetBaseLayer(m_Graph, output->index());
450 auto sourceLayerIndex = layer->inputSlots()->Get(0)->connection()->sourceLayerIndex();
451 auto sourceLayer = GetBaseLayer(m_Graph, sourceLayerIndex);
452 return std::make_pair(bindingId, ToTensorInfo(sourceLayer->outputSlots()->Get(0)->tensorInfo()));
453 }
454 }
455 throw ParseException(
456 boost::str(
457 boost::format("No output binding found for layer:%1% / %2%") %
458 name %
459 CHECK_LOCATION().AsString()));
460}
461
462void DeserializeParser::SetupInputLayers()
463{
464 CHECK_GRAPH(m_Graph, 0);
465 auto inputs = GetGraphInputs(m_Graph);
466 for (auto const& input : inputs)
467 {
468 IConnectableLayer* layer =
Saoirse Stewart3fcef202019-02-14 14:57:37 +0000469 m_Network->AddInputLayer(GetBindingLayerInfo(m_Graph, input->index()), input->layerName()->c_str());
Kevin May43a799c2019-02-08 16:31:42 +0000470
471 auto tensorInfo = ToTensorInfo(input->outputSlots()->Get(0)->tensorInfo());
472 layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
473
474 RegisterOutputSlots(input->index(), layer);
475 }
476}
477
478void DeserializeParser::SetupOutputLayers()
479{
480 CHECK_GRAPH(m_Graph, 0);
481 auto outputs = GetGraphOutputs(m_Graph);
482 for (auto const& output : outputs)
483 {
484 IConnectableLayer* layer =
Saoirse Stewart3fcef202019-02-14 14:57:37 +0000485 m_Network->AddOutputLayer(GetBindingLayerInfo(m_Graph, output->index()), output->layerName()->c_str());
Kevin May43a799c2019-02-08 16:31:42 +0000486
487 RegisterInputSlots(output->index(), layer);
488 }
489}
490
491void DeserializeParser::RegisterOutputSlots(uint32_t layerIndex,
492 IConnectableLayer* layer)
493{
494 CHECK_LAYERS(m_Graph, 0, layerIndex);
495 BOOST_ASSERT(layer != nullptr);
496 auto parsedLayer = GetBaseLayer(m_Graph, layerIndex);
497 if (parsedLayer->outputSlots()->size() != layer->GetNumOutputSlots())
498 {
499 throw ParseException(
500 boost::str(boost::format("The number of outputslots (%1%) does not match the number expected (%2%)"
501 " for layer index: %3% %4%") %
502 parsedLayer->outputSlots()->size() %
503 layer->GetNumOutputSlots() %
504 layerIndex %
505 CHECK_LOCATION().AsString()));
506 }
507
508 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
509 {
510 armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));
511 RegisterOutputSlotOfConnection(layerIndex, slot);
512 }
513}
514
515void DeserializeParser::RegisterInputSlots(uint32_t layerIndex,
516 armnn::IConnectableLayer* layer)
517{
518 CHECK_LAYERS(m_Graph, 0, layerIndex);
519 BOOST_ASSERT(layer != nullptr);
520 auto parsedLayer = GetBaseLayer(m_Graph, layerIndex);
521 if (parsedLayer->inputSlots()->size() != layer->GetNumInputSlots())
522 {
523 throw ParseException(
524 boost::str(boost::format("The number of inputslots (%1%) does not match the number expected (%2%)"
525 " for layer index:%3% %4%") %
526 parsedLayer->inputSlots()->size() %
527 layer->GetNumInputSlots() %
528 layerIndex %
529 CHECK_LOCATION().AsString()));
530 }
531
532 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
533 {
534 armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
535 uint32_t sourceLayerIndex = parsedLayer->inputSlots()->Get(slotIndex)->connection()->sourceLayerIndex();
536 RegisterInputSlotOfConnection(sourceLayerIndex, slot);
537 }
538}
539
540void DeserializeParser::RegisterInputSlotOfConnection(uint32_t connectionIndex,
541 armnn::IInputSlot* slot)
542{
543 BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
544
545 Slots& slots = m_GraphConnections[0][connectionIndex];
546 slots.inputSlots.push_back(slot);
547}
548
549void DeserializeParser::RegisterOutputSlotOfConnection(uint32_t connectionIndex,
550 armnn::IOutputSlot* slot)
551{
552 BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
553
554 Slots& slots = m_GraphConnections[0][connectionIndex];
555
556 // assuming there is only one producer for that tensor
557 if (slots.outputSlot != nullptr)
558 {
559 throw ParseException(boost::str(
560 boost::format("Another layer has already registered itself as the producer of "
561 "connection:%1% / %2%") %
562 connectionIndex %
563 CHECK_LOCATION().AsString()));
564 }
565
566 slots.outputSlot = slot;
567}
568
569void DeserializeParser::ParseAdd(unsigned int layerIndex)
570{
571 CHECK_LAYERS(m_Graph, 0, layerIndex);
572 auto inputs = GetInputs(m_Graph, layerIndex);
573 CHECK_LOCATION();
574 CHECK_VALID_SIZE(inputs.size(), 2);
575
576 auto outputs = GetOutputs(m_Graph, layerIndex);
577 CHECK_VALID_SIZE(outputs.size(), 1);
578
579 auto layerName = boost::str(boost::format("Addition:%1%") % layerIndex);
580 IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str());
581
582 armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
583 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
584
585 RegisterInputSlots(layerIndex, layer);
586 RegisterOutputSlots(layerIndex, layer);
587}
588
Sadik Armagan5f450272019-02-12 14:31:45 +0000589void DeserializeParser::ParseMultiplication(unsigned int layerIndex)
590{
591 CHECK_LAYERS(m_Graph, 0, layerIndex);
592 auto inputs = GetInputs(m_Graph, layerIndex);
593 CHECK_LOCATION();
594 CHECK_VALID_SIZE(inputs.size(), 2);
595
596 auto outputs = GetOutputs(m_Graph, layerIndex);
597 CHECK_VALID_SIZE(outputs.size(), 1);
598
599 auto layerName = boost::str(boost::format("Multiplication:%1%") % layerIndex);
600 IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str());
601
602 armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
603 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
604
605 RegisterInputSlots(layerIndex, layer);
606 RegisterOutputSlots(layerIndex, layer);
607}
608
Kevin May43a799c2019-02-08 16:31:42 +0000609}