blob: 15dc888e21b31ab6cf1220822a24346d09602640 [file] [log] [blame]
narpra016f37f832018-12-21 18:30:00 +00001//
Sadik Armagana097d2a2021-11-24 15:47:28 +00002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
narpra016f37f832018-12-21 18:30:00 +00003// SPDX-License-Identifier: MIT
4//
5
#include "GraphUtils.hpp"

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <algorithm>
#include <string>
9
narpra016f37f832018-12-21 18:30:00 +000010bool GraphHasNamedLayer(const armnn::Graph& graph, const std::string& name)
11{
12 for (auto&& layer : graph)
13 {
14 if (layer->GetName() == name)
15 {
16 return true;
17 }
18 }
19 return false;
20}
21
22armnn::Layer* GetFirstLayerWithName(armnn::Graph& graph, const std::string& name)
23{
24 for (auto&& layer : graph)
25 {
26 if (layer->GetNameStr() == name)
27 {
28 return layer;
29 }
30 }
31 return nullptr;
32}
33
34bool CheckNumberOfInputSlot(armnn::Layer* layer, unsigned int num)
35{
36 return layer->GetNumInputSlots() == num;
37}
38
39bool CheckNumberOfOutputSlot(armnn::Layer* layer, unsigned int num)
40{
41 return layer->GetNumOutputSlots() == num;
42}
43
44bool IsConnected(armnn::Layer* srcLayer, armnn::Layer* destLayer,
45 unsigned int srcSlot, unsigned int destSlot,
46 const armnn::TensorInfo& expectedTensorInfo)
47{
48 const armnn::IOutputSlot& outputSlot = srcLayer->GetOutputSlot(srcSlot);
49 const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
50 if (expectedTensorInfo != tensorInfo)
51 {
52 return false;
53 }
54 const unsigned int numConnections = outputSlot.GetNumConnections();
55 for (unsigned int c = 0; c < numConnections; ++c)
56 {
Jan Eilersbb446e52020-04-02 13:56:54 +010057 auto inputSlot = armnn::PolymorphicDowncast<const armnn::InputSlot*>(outputSlot.GetConnection(c));
narpra016f37f832018-12-21 18:30:00 +000058 if (inputSlot->GetOwningLayer().GetNameStr() == destLayer->GetNameStr() &&
59 inputSlot->GetSlotIndex() == destSlot)
60 {
61 return true;
62 }
63 }
64 return false;
65}
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +010066
67/// Checks that first comes before second in the order.
68bool CheckOrder(const armnn::Graph& graph, const armnn::Layer* first, const armnn::Layer* second)
69{
70 graph.Print();
71
72 const auto& order = graph.TopologicalSort();
73
74 auto firstPos = std::find(order.begin(), order.end(), first);
75 auto secondPos = std::find(firstPos, order.end(), second);
76
77 return (secondPos != order.end());
78}