blob: ef6ea424029e1729f64b2023ef1f22b7afbf2d56 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef SUBGRAPH_TRAVERSER_H
17#define SUBGRAPH_TRAVERSER_H
18
#include "graph_node.h"
#include "graph_status.h"
#include "model_common.h"
#include "ops/op_factory.h"
#include "tensor.h"
#include "tosa_serialization_handler.h"
#include <cstdio>
#include <list>
#include <string>
#include <unordered_set>
#include <vector>
Eric Kunzee5e26762020-10-13 16:11:07 -070026
27namespace TosaReference
28{
29
30class SubgraphTraverser
31{
32public:
Jerry Ge9e94af82022-10-27 09:57:00 -070033 SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt);
Eric Kunzee5e26762020-10-13 16:11:07 -070034 ~SubgraphTraverser();
35
36 int initializeGraph();
37 int isFullyEvaluated() const;
38 int evaluateNextNode();
39 int evaluateAll();
40
Kevin Chengacb550f2021-06-29 15:32:19 -070041 GraphStatus getGraphStatus() const
42 {
43 return graph_status;
44 }
45 void setGraphStatus(GraphStatus status)
46 {
47 graph_status = status;
48 }
49
Eric Kunzee5e26762020-10-13 16:11:07 -070050 int linkTensorsAndNodes();
51 int validateGraph();
Jerry Gee5cabbf2023-07-17 21:33:17 +000052 int allocateInputTensors();
53 int allocateTensor(std::string name);
Eric Kunzee5e26762020-10-13 16:11:07 -070054
55 int dumpGraph(FILE* out) const;
56 int dumpNextNodeList(FILE* out) const;
57 int clearAllNodeMarkings();
58
Kevin Cheng5d00c692021-10-15 20:06:00 +000059 std::string getBlockName() const
60 {
61 return block->GetName();
62 }
Jerry Ge9e94af82022-10-27 09:57:00 -070063 std::string getRegionName() const
64 {
65 return block->GetRegionName();
66 }
Jerry Ge264f7fa2023-04-21 22:49:57 +000067 TosaSerializationHandler* getTsh() const
68 {
69 return tsh;
70 }
Eric Kunzee5e26762020-10-13 16:11:07 -070071 int getNumInputTensors() const;
72 Tensor* getInputTensor(const unsigned int idx) const;
73 Tensor* getInputTensorByName(const std::string name) const;
74 int getNumOutputTensors() const;
75 Tensor* getOutputTensor(const unsigned int idx) const;
76 Tensor* getOutputTensorByName(const std::string name) const;
77 int addToNextNodeList(GraphNode*);
78
79private:
80 int addTensor(Tensor* ct);
81 int addNode(GraphNode* cn);
82
83 Tensor* findTensorByName(const std::string& name) const;
84
85 GraphNode* getNextNode();
86
Kevin Chengacb550f2021-06-29 15:32:19 -070087 GraphStatus graph_status;
88
Jerry Ge9e94af82022-10-27 09:57:00 -070089 // pointer to the parent subgraph traversal if exists
90 // e.g., Control Flow Ops will have nested blocks (subgraph traversals)
91 SubgraphTraverser* parent_sgt;
92
Eric Kunzee5e26762020-10-13 16:11:07 -070093 // pointer to serialization library and corresponding basic block
94 TosaSerializationBasicBlock* block;
95 TosaSerializationHandler* tsh;
96
97 // The definitive list of all tensors
98 std::vector<Tensor*> tensors;
99
100 // The subset of tensors that are also input tensors
101 std::vector<Tensor*> inputTensors;
102
103 // The subset of tensors that are also output tensors
104 std::vector<Tensor*> outputTensors;
105
106 // The definitive list of all nodes in the graph
107 std::vector<GraphNode*> nodes;
108
109 // The subset of node that have all of their input tensors ready, but
110 // have not yet been evaluated to produce their output tensors.
111 // With control flow, a node may appear on this list more than once during its
112 // lifetime, although the list itself should only contain unique nodes.
113 std::list<GraphNode*> nextNodeList;
114
Kevin Chengcc61be32021-10-14 17:09:57 -0700115 // tensor name set which contains all the name used by operator
116 std::unordered_set<std::string> used_tensor_name_set;
117
Eric Kunzee5e26762020-10-13 16:11:07 -0700118 // Maximum number of times to evalute a node before
119 // warning.
120 const int MAX_EVAL_COUNT = 10000;
121};
122}; // namespace TosaReference
123
124#endif