blob: d6b0e8d7f27823b506fd4bd60282c4a1e22a8a71 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef SUBGRAPH_TRAVERSER_H
17#define SUBGRAPH_TRAVERSER_H
18
#include "graph_node.h"
#include "graph_status.h"
#include "model_common.h"
#include "ops/op_factory.h"
#include "tensor.h"
#include "tosa_serialization_handler.h"
#include <cstdio>
#include <list>
#include <string>
#include <unordered_set>
#include <vector>
Eric Kunzee5e26762020-10-13 16:11:07 -070026
27namespace TosaReference
28{
29
30class SubgraphTraverser
31{
32public:
Jerry Ge9e94af82022-10-27 09:57:00 -070033 SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt);
Eric Kunzee5e26762020-10-13 16:11:07 -070034 ~SubgraphTraverser();
35
36 int initializeGraph();
37 int isFullyEvaluated() const;
38 int evaluateNextNode();
39 int evaluateAll();
40
Kevin Chengacb550f2021-06-29 15:32:19 -070041 GraphStatus getGraphStatus() const
42 {
43 return graph_status;
44 }
45 void setGraphStatus(GraphStatus status)
46 {
47 graph_status = status;
48 }
49
Eric Kunzee5e26762020-10-13 16:11:07 -070050 int linkTensorsAndNodes();
51 int validateGraph();
Jerry Gee5cabbf2023-07-17 21:33:17 +000052 int allocateInputTensors();
53 int allocateTensor(std::string name);
Eric Kunzee5e26762020-10-13 16:11:07 -070054
55 int dumpGraph(FILE* out) const;
56 int dumpNextNodeList(FILE* out) const;
57 int clearAllNodeMarkings();
58
Kevin Cheng5d00c692021-10-15 20:06:00 +000059 std::string getBlockName() const
60 {
61 return block->GetName();
62 }
Jerry Ge9e94af82022-10-27 09:57:00 -070063 std::string getRegionName() const
64 {
65 return block->GetRegionName();
66 }
Jerry Ge264f7fa2023-04-21 22:49:57 +000067 TosaSerializationHandler* getTsh() const
68 {
69 return tsh;
70 }
Eric Kunzee5e26762020-10-13 16:11:07 -070071 int getNumInputTensors() const;
72 Tensor* getInputTensor(const unsigned int idx) const;
73 Tensor* getInputTensorByName(const std::string name) const;
74 int getNumOutputTensors() const;
75 Tensor* getOutputTensor(const unsigned int idx) const;
76 Tensor* getOutputTensorByName(const std::string name) const;
Tai Lycf84bc92023-09-07 20:49:09 +000077 int getNumVariableTensors() const;
78 Tensor* getVariableTensor(const unsigned int idx) const;
79 Tensor* getVariableTensorByName(const std::string name) const;
80 int registerVariableTensor(Tensor* tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -070081 int addToNextNodeList(GraphNode*);
82
83private:
Tai Lycf84bc92023-09-07 20:49:09 +000084 int addTensor(const TosaSerializationTensor* ts);
Eric Kunzee5e26762020-10-13 16:11:07 -070085 int addNode(GraphNode* cn);
86
87 Tensor* findTensorByName(const std::string& name) const;
88
89 GraphNode* getNextNode();
90
Kevin Chengacb550f2021-06-29 15:32:19 -070091 GraphStatus graph_status;
92
Jerry Ge9e94af82022-10-27 09:57:00 -070093 // pointer to the parent subgraph traversal if exists
94 // e.g., Control Flow Ops will have nested blocks (subgraph traversals)
95 SubgraphTraverser* parent_sgt;
96
Eric Kunzee5e26762020-10-13 16:11:07 -070097 // pointer to serialization library and corresponding basic block
98 TosaSerializationBasicBlock* block;
99 TosaSerializationHandler* tsh;
100
101 // The definitive list of all tensors
102 std::vector<Tensor*> tensors;
103
104 // The subset of tensors that are also input tensors
105 std::vector<Tensor*> inputTensors;
106
107 // The subset of tensors that are also output tensors
108 std::vector<Tensor*> outputTensors;
109
Tai Lycf84bc92023-09-07 20:49:09 +0000110 // The subset of tensors that are also variable tensors
111 std::vector<Tensor*> variableTensors;
112
Eric Kunzee5e26762020-10-13 16:11:07 -0700113 // The definitive list of all nodes in the graph
114 std::vector<GraphNode*> nodes;
115
116 // The subset of node that have all of their input tensors ready, but
117 // have not yet been evaluated to produce their output tensors.
118 // With control flow, a node may appear on this list more than once during its
119 // lifetime, although the list itself should only contain unique nodes.
120 std::list<GraphNode*> nextNodeList;
121
Kevin Chengcc61be32021-10-14 17:09:57 -0700122 // tensor name set which contains all the name used by operator
123 std::unordered_set<std::string> used_tensor_name_set;
124
Eric Kunzee5e26762020-10-13 16:11:07 -0700125 // Maximum number of times to evalute a node before
126 // warning.
127 const int MAX_EVAL_COUNT = 10000;
128};
129}; // namespace TosaReference
130
131#endif