blob: 7940ee4c4c05b7afa25ed38940fcc1ebb8075ea8 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef SUBGRAPH_TRAVERSER_H
17#define SUBGRAPH_TRAVERSER_H
18
#include "graph_status.h"
#include "graph_node.h"
#include "model_common.h"
#include "ops/op_factory.h"
#include "tensor.h"
#include "tosa_serialization_handler.h"

#include <cstdio>
#include <list>
#include <string>
#include <unordered_set>
#include <vector>
Eric Kunzee5e26762020-10-13 16:11:07 -070026
27namespace TosaReference
28{
29
30class SubgraphTraverser
31{
32public:
33 SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
34 ~SubgraphTraverser();
35
36 int initializeGraph();
37 int isFullyEvaluated() const;
38 int evaluateNextNode();
39 int evaluateAll();
40
Kevin Chengacb550f2021-06-29 15:32:19 -070041 GraphStatus getGraphStatus() const
42 {
43 return graph_status;
44 }
45 void setGraphStatus(GraphStatus status)
46 {
47 graph_status = status;
48 }
49
Eric Kunzee5e26762020-10-13 16:11:07 -070050 int linkTensorsAndNodes();
51 int validateGraph();
Kevin Chengcc61be32021-10-14 17:09:57 -070052 int allocateTensor();
Eric Kunzee5e26762020-10-13 16:11:07 -070053
54 int dumpGraph(FILE* out) const;
55 int dumpNextNodeList(FILE* out) const;
56 int clearAllNodeMarkings();
57
Kevin Cheng5d00c692021-10-15 20:06:00 +000058 std::string getBlockName() const
59 {
60 return block->GetName();
61 }
Eric Kunzee5e26762020-10-13 16:11:07 -070062 int getNumInputTensors() const;
63 Tensor* getInputTensor(const unsigned int idx) const;
64 Tensor* getInputTensorByName(const std::string name) const;
65 int getNumOutputTensors() const;
66 Tensor* getOutputTensor(const unsigned int idx) const;
67 Tensor* getOutputTensorByName(const std::string name) const;
68 int addToNextNodeList(GraphNode*);
69
70private:
71 int addTensor(Tensor* ct);
72 int addNode(GraphNode* cn);
73
74 Tensor* findTensorByName(const std::string& name) const;
75
76 GraphNode* getNextNode();
77
Kevin Chengacb550f2021-06-29 15:32:19 -070078 GraphStatus graph_status;
79
Eric Kunzee5e26762020-10-13 16:11:07 -070080 // pointer to serialization library and corresponding basic block
81 TosaSerializationBasicBlock* block;
82 TosaSerializationHandler* tsh;
83
84 // The definitive list of all tensors
85 std::vector<Tensor*> tensors;
86
87 // The subset of tensors that are also input tensors
88 std::vector<Tensor*> inputTensors;
89
90 // The subset of tensors that are also output tensors
91 std::vector<Tensor*> outputTensors;
92
93 // The definitive list of all nodes in the graph
94 std::vector<GraphNode*> nodes;
95
96 // The subset of node that have all of their input tensors ready, but
97 // have not yet been evaluated to produce their output tensors.
98 // With control flow, a node may appear on this list more than once during its
99 // lifetime, although the list itself should only contain unique nodes.
100 std::list<GraphNode*> nextNodeList;
101
Kevin Chengcc61be32021-10-14 17:09:57 -0700102 // tensor name set which contains all the name used by operator
103 std::unordered_set<std::string> used_tensor_name_set;
104
Eric Kunzee5e26762020-10-13 16:11:07 -0700105 // Maximum number of times to evalute a node before
106 // warning.
107 const int MAX_EVAL_COUNT = 10000;
108};
109}; // namespace TosaReference
110
111#endif