// blob: 543b00807db087884a1fea42d75365947762fcdf [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SUBGRAPH_TRAVERSER_H
#define SUBGRAPH_TRAVERSER_H
#include "graph_status.h"
#include "graph_node.h"
#include "model_common.h"
#include "ops/op_factory.h"
#include "tensor.h"
#include "tosa_serialization_handler.h"
#include <unordered_set>
namespace TosaReference
{
/// Traverser for the graph described by a single serialized basic block.
///
/// The object keeps lists of the block's tensors and nodes, a list of nodes
/// that are ready to be evaluated, and an overall GraphStatus. Only the
/// interface is declared here; all non-inline members are defined in the
/// implementation file, which is also where the meaning of the `int` return
/// codes is established.
class SubgraphTraverser
{
public:
    /// @param block       Serialized basic block this traverser works on.
    /// @param tsh         Serialization handler the block belongs to.
    /// @param parent_sgt  Enclosing traverser for a nested block.
    ///                    NOTE(review): presumably null for a top-level block - confirm.
    SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh, SubgraphTraverser* parent_sgt);
    ~SubgraphTraverser();

    // Copying is disabled. The class has a user-declared destructor and stores
    // raw Tensor* / GraphNode* values, so a member-wise copy would leave two
    // traversers holding the same pointers.
    // NOTE(review): assumes the destructor releases those objects - confirm
    // against the implementation file.
    SubgraphTraverser(const SubgraphTraverser&)            = delete;
    SubgraphTraverser& operator=(const SubgraphTraverser&) = delete;

    // Graph construction and evaluation entry points (defined out of line).
    int initializeGraph();
    int isFullyEvaluated() const;
    int evaluateNextNode();
    int evaluateAll();

    /// Returns the status currently stored for this graph.
    GraphStatus getGraphStatus() const
    {
        return graph_status;
    }

    /// Overwrites the status stored for this graph.
    void setGraphStatus(GraphStatus status)
    {
        graph_status = status;
    }

    // Graph set-up and checking steps (defined out of line).
    int linkTensorsAndNodes();
    int validateGraph();
    int allocateTensor();

    // Debug output helpers; both write to the caller-supplied stream.
    int dumpGraph(FILE* out) const;
    int dumpNextNodeList(FILE* out) const;

    int clearAllNodeMarkings();

    /// Name of the underlying basic block, as reported by the block itself.
    std::string getBlockName() const
    {
        return block->GetName();
    }

    /// Name of the region the underlying basic block reports belonging to.
    std::string getRegionName() const
    {
        return block->GetRegionName();
    }

    // Access to the graph's input tensors, by position or by name.
    int getNumInputTensors() const;
    Tensor* getInputTensor(const unsigned int idx) const;
    Tensor* getInputTensorByName(const std::string name) const;

    // Access to the graph's output tensors, by position or by name.
    int getNumOutputTensors() const;
    Tensor* getOutputTensor(const unsigned int idx) const;
    Tensor* getOutputTensorByName(const std::string name) const;

    /// Queues a node on the list of nodes waiting to be evaluated.
    int addToNextNodeList(GraphNode*);

private:
    // Internal bookkeeping helpers (defined out of line).
    int addTensor(Tensor* ct);
    int addNode(GraphNode* cn);
    Tensor* findTensorByName(const std::string& name) const;
    GraphNode* getNextNode();

    // Overall status of this graph; read and written through the accessors above.
    GraphStatus graph_status;

    // Pointer to the parent subgraph traverser, if one exists.
    // E.g. control flow ops have nested blocks (nested subgraph traversals).
    SubgraphTraverser* parent_sgt;

    // Pointers to the serialization library handler and the corresponding basic block.
    TosaSerializationBasicBlock* block;
    TosaSerializationHandler* tsh;

    // The definitive list of all tensors.
    std::vector<Tensor*> tensors;

    // The subset of tensors that are also input tensors.
    std::vector<Tensor*> inputTensors;

    // The subset of tensors that are also output tensors.
    std::vector<Tensor*> outputTensors;

    // The definitive list of all nodes in the graph.
    std::vector<GraphNode*> nodes;

    // The subset of nodes that have all of their input tensors ready, but
    // have not yet been evaluated to produce their output tensors.
    // With control flow, a node may appear on this list more than once during its
    // lifetime, although the list itself should only contain unique nodes.
    std::list<GraphNode*> nextNodeList;

    // Set of all tensor names that are used by operators.
    std::unordered_set<std::string> used_tensor_name_set;

    // Maximum number of times to evaluate a node before warning.
    const int MAX_EVAL_COUNT = 10000;
};
}; // namespace TosaReference
#endif