Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 1 | |
Jerry Ge | 9e94af8 | 2022-10-27 09:57:00 -0700 | [diff] [blame] | 2 | // Copyright (c) 2020-2023, ARM Limited. |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 3 | // |
| 4 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | // you may not use this file except in compliance with the License. |
| 6 | // You may obtain a copy of the License at |
| 7 | // |
| 8 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | // |
| 10 | // Unless required by applicable law or agreed to in writing, software |
| 11 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | // See the License for the specific language governing permissions and |
| 14 | // limitations under the License. |
| 15 | |
| 16 | #ifndef SUBGRAPH_TRAVERSER_H |
| 17 | #define SUBGRAPH_TRAVERSER_H |
| 18 | |
#include "graph_node.h"
#include "graph_status.h"
#include "model_common.h"
#include "ops/op_factory.h"
#include "tensor.h"
#include "tosa_serialization_handler.h"
#include <cstdio>
#include <list>
#include <string>
#include <unordered_set>
#include <vector>
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 26 | |
| 27 | namespace TosaReference |
| 28 | { |
| 29 | |
| 30 | class SubgraphTraverser |
| 31 | { |
| 32 | public: |
Jerry Ge | 9e94af8 | 2022-10-27 09:57:00 -0700 | [diff] [blame] | 33 | SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt); |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 34 | ~SubgraphTraverser(); |
| 35 | |
| 36 | int initializeGraph(); |
| 37 | int isFullyEvaluated() const; |
| 38 | int evaluateNextNode(); |
| 39 | int evaluateAll(); |
| 40 | |
Kevin Cheng | acb550f | 2021-06-29 15:32:19 -0700 | [diff] [blame] | 41 | GraphStatus getGraphStatus() const |
| 42 | { |
| 43 | return graph_status; |
| 44 | } |
| 45 | void setGraphStatus(GraphStatus status) |
| 46 | { |
| 47 | graph_status = status; |
| 48 | } |
| 49 | |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 50 | int linkTensorsAndNodes(); |
| 51 | int validateGraph(); |
Jerry Ge | e5cabbf | 2023-07-17 21:33:17 +0000 | [diff] [blame] | 52 | int allocateInputTensors(); |
| 53 | int allocateTensor(std::string name); |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 54 | |
| 55 | int dumpGraph(FILE* out) const; |
| 56 | int dumpNextNodeList(FILE* out) const; |
| 57 | int clearAllNodeMarkings(); |
| 58 | |
Kevin Cheng | 5d00c69 | 2021-10-15 20:06:00 +0000 | [diff] [blame] | 59 | std::string getBlockName() const |
| 60 | { |
| 61 | return block->GetName(); |
| 62 | } |
Jerry Ge | 9e94af8 | 2022-10-27 09:57:00 -0700 | [diff] [blame] | 63 | std::string getRegionName() const |
| 64 | { |
| 65 | return block->GetRegionName(); |
| 66 | } |
Jerry Ge | 264f7fa | 2023-04-21 22:49:57 +0000 | [diff] [blame] | 67 | TosaSerializationHandler* getTsh() const |
| 68 | { |
| 69 | return tsh; |
| 70 | } |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 71 | int getNumInputTensors() const; |
| 72 | Tensor* getInputTensor(const unsigned int idx) const; |
| 73 | Tensor* getInputTensorByName(const std::string name) const; |
| 74 | int getNumOutputTensors() const; |
| 75 | Tensor* getOutputTensor(const unsigned int idx) const; |
| 76 | Tensor* getOutputTensorByName(const std::string name) const; |
Tai Ly | cf84bc9 | 2023-09-07 20:49:09 +0000 | [diff] [blame] | 77 | int getNumVariableTensors() const; |
| 78 | Tensor* getVariableTensor(const unsigned int idx) const; |
| 79 | Tensor* getVariableTensorByName(const std::string name) const; |
| 80 | int registerVariableTensor(Tensor* tensor); |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 81 | int addToNextNodeList(GraphNode*); |
| 82 | |
| 83 | private: |
Tai Ly | cf84bc9 | 2023-09-07 20:49:09 +0000 | [diff] [blame] | 84 | int addTensor(const TosaSerializationTensor* ts); |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 85 | int addNode(GraphNode* cn); |
| 86 | |
| 87 | Tensor* findTensorByName(const std::string& name) const; |
| 88 | |
| 89 | GraphNode* getNextNode(); |
| 90 | |
Kevin Cheng | acb550f | 2021-06-29 15:32:19 -0700 | [diff] [blame] | 91 | GraphStatus graph_status; |
| 92 | |
Jerry Ge | 9e94af8 | 2022-10-27 09:57:00 -0700 | [diff] [blame] | 93 | // pointer to the parent subgraph traversal if exists |
| 94 | // e.g., Control Flow Ops will have nested blocks (subgraph traversals) |
| 95 | SubgraphTraverser* parent_sgt; |
| 96 | |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 97 | // pointer to serialization library and corresponding basic block |
| 98 | TosaSerializationBasicBlock* block; |
| 99 | TosaSerializationHandler* tsh; |
| 100 | |
| 101 | // The definitive list of all tensors |
| 102 | std::vector<Tensor*> tensors; |
| 103 | |
| 104 | // The subset of tensors that are also input tensors |
| 105 | std::vector<Tensor*> inputTensors; |
| 106 | |
| 107 | // The subset of tensors that are also output tensors |
| 108 | std::vector<Tensor*> outputTensors; |
| 109 | |
Tai Ly | cf84bc9 | 2023-09-07 20:49:09 +0000 | [diff] [blame] | 110 | // The subset of tensors that are also variable tensors |
| 111 | std::vector<Tensor*> variableTensors; |
| 112 | |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 113 | // The definitive list of all nodes in the graph |
| 114 | std::vector<GraphNode*> nodes; |
| 115 | |
| 116 | // The subset of node that have all of their input tensors ready, but |
| 117 | // have not yet been evaluated to produce their output tensors. |
| 118 | // With control flow, a node may appear on this list more than once during its |
| 119 | // lifetime, although the list itself should only contain unique nodes. |
| 120 | std::list<GraphNode*> nextNodeList; |
| 121 | |
Kevin Cheng | cc61be3 | 2021-10-14 17:09:57 -0700 | [diff] [blame] | 122 | // tensor name set which contains all the name used by operator |
| 123 | std::unordered_set<std::string> used_tensor_name_set; |
| 124 | |
Eric Kunze | e5e2676 | 2020-10-13 16:11:07 -0700 | [diff] [blame] | 125 | // Maximum number of times to evalute a node before |
| 126 | // warning. |
| 127 | const int MAX_EVAL_COUNT = 10000; |
| 128 | }; |
| 129 | }; // namespace TosaReference |
| 130 | |
| 131 | #endif |