// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef OPS_CONTROL_FLOW_H
17#define OPS_CONTROL_FLOW_H
18
19#include "graph_node.h"
20
21#define MAX_WHILE_LOOP_ITERATION 10000
22
23namespace TosaReference
24{
25class OpControlFlow : public GraphNode
26{
27public:
Kevin Chengacb550f2021-06-29 15:32:19 -070028 OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070029 ~OpControlFlow();
30
31 virtual int evalBlock(TosaSerializationBasicBlock* block,
32 std::vector<TosaReference::Tensor*>& block_inputs,
33 std::vector<TosaReference::Tensor*>& block_outputs);
34
35protected:
36 TosaSerializationHandler* tsh;
37};
38
39class OpCondIf : public OpControlFlow
40{
41public:
Kevin Chengacb550f2021-06-29 15:32:19 -070042 OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070043 virtual ~OpCondIf();
44
45 virtual int checkTensorAttributes();
46 virtual int eval();
47
48protected:
49 TosaCondIfAttribute* attribute;
50 TosaReference::Tensor0<bool>* cond;
51 TosaSerializationBasicBlock* then_block;
52 TosaSerializationBasicBlock* else_block;
53};
54
55class OpWhileLoop : public OpControlFlow
56{
57public:
Kevin Chengacb550f2021-06-29 15:32:19 -070058 OpWhileLoop(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -070059 virtual ~OpWhileLoop();
60
61 virtual int checkTensorAttributes();
62 virtual int eval();
63
64protected:
65 TosaWhileLoopAttribute* attribute;
66 TosaSerializationBasicBlock* cond_block;
67 TosaSerializationBasicBlock* body_block;
68};
69
70}; // namespace TosaReference
71
72#endif