blob: 2ee4935be2d2ea3ec391ed2983e4ab338f282234 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "data_nodes.h"
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
22OpConst::OpConst(uint64_t id_)
23 : GraphNode(Op_CONST, id_)
24{
25 setRequiredOperands(0, 1);
26}
27
28OpConst::~OpConst()
29{}
30
31int OpConst::checkTensorAttributes()
32{
33 if (validateRequiredOperands())
34 return 1;
35
36 return 0;
37}
38
39int OpConst::eval()
40{
41 // Evaluation is trivial for constants
42 return GraphNode::eval();
43}
44
45OpPlaceholder::OpPlaceholder(uint64_t id_)
46 : GraphNode(Op_PLACEHOLDER, id_)
47{
48 setRequiredOperands(0, 1);
49}
50
51OpPlaceholder::~OpPlaceholder()
52{}
53
54int OpPlaceholder::checkTensorAttributes()
55{
56 if (validateRequiredOperands())
57 return 1;
58
59 return 0;
60}
61
62int OpPlaceholder::eval()
63{
64 // Evaluation is trivial for placeholders
65 return GraphNode::eval();
66}
67
68template <int Rank, DType Dtype>
69OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
70 : GraphNode(Op_IDENTITY, id_)
71{
72 setRequiredOperands(1, 1);
73 setRequiredRank(0, 6);
74}
75
76template <int Rank, DType Dtype>
77OpIdentity<Rank, Dtype>::~OpIdentity()
78{}
79
80template <int Rank, DType Dtype>
81int OpIdentity<Rank, Dtype>::checkTensorAttributes()
82{
83
84 if (inputs.size() != outputs.size())
85 {
86 printNodeValidationError("Input and output tensor list lengths are not equal");
87 return 1;
88 }
89
90 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
91 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
92
93 if (in->matchRankTypeShape(*out))
94 {
95 printNodeValidationError("Input and output tensor rank, type, or shape do not match");
96 return 1;
97 }
98
99 return 0;
100}
101
102template <int Rank, DType Dtype>
103int OpIdentity<Rank, Dtype>::eval()
104{
105 out->getTensor() = in->getTensor();
106
107 return GraphNode::eval();
108}
109
110template <int Rank, DType Dtype>
111OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
112 : GraphNode(Op_IDENTITYN, id_)
113{
114 setRequiredRank(0, 6);
115}
116
117template <int Rank, DType Dtype>
118OpIdentityN<Rank, Dtype>::~OpIdentityN()
119{}
120
121template <int Rank, DType Dtype>
122int OpIdentityN<Rank, Dtype>::checkTensorAttributes()
123{
124
125 if (inputs.size() != outputs.size())
126 {
127 printNodeValidationError("Input and output tensor list lengths are not equal");
128 return 1;
129 }
130
131 for (size_t i = 0; i < inputs.size(); i++)
132 {
133 ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
134 outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i]));
135
136 if (ins[i]->matchRankTypeShape(*outs[i]))
137 {
138 printNodeValidationError("Input and output tensor rank, type, or shape do not match");
139 return 1;
140 }
141 }
142
143 return 0;
144}
145
146template <int Rank, DType Dtype>
147int OpIdentityN<Rank, Dtype>::eval()
148{
149 for (size_t i = 0; i < ins.size(); i++)
150 {
151 outs[i]->getTensor() = ins[i]->getTensor();
152 }
153
154 return GraphNode::eval();
155}
156
157// template explicit instantiation
158// note OpConst and OpPlaceholder are not templated
159
160DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
161DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
162DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
163DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
164DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
165DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
166
167DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
168DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
169DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
170DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
171DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
172DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);