blob: b087dd8c00059c300fba3f226120e4b1337d27d9 [file] [log] [blame]
// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "shape.h"
16
17using namespace TosaReference;
18using namespace Eigen;
19using namespace tosa;
20
21OpConstShape::OpConstShape(SubgraphTraverser* sgt_, uint64_t id_)
22 : GraphNode(sgt_, Op_CONST, id_)
23{
24 setRequiredOperands(0, 1);
25}
26
27OpConstShape::~OpConstShape()
28{}
29
30int OpConstShape::checkTensorAttributes()
31{
32 if (validateRequiredOperands())
33 return 1;
34
35 return 0;
36}
37
38int OpConstShape::eval()
39{
40 for (auto ct : getOutputs())
41 {
42 if (!ct->getIsValid())
43 {
44 std::string err = "Constant Shape tensor " + ct->getName() + " not correctly initialized";
45 printNodeValidationError(err.c_str());
46 return 1;
47 }
48 }
49
50 // Evaluation is trivial for constants
51 return GraphNode::eval();
52}
53
54OpConcatShape::OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_)
55 : GraphNode(sgt_, Op_CONCAT_SHAPE, id_)
56{
57 setRequiredOperands(-1, 1);
58 setRequiredRank(1, 1);
59}
60
61OpConcatShape::~OpConcatShape()
62{}
63
64int OpConcatShape::checkTensorAttributes()
65{
66 if (validateRequiredOperands())
67 return 1;
68
69 if (inputs.empty())
70 {
71 printNodeValidationError("ConcatShape operator must have at least one input tensor");
72 return 1;
73 }
74
75 int32_t num_inputs = inputs.size();
76 int32_t elements_count = 0;
77 for (int32_t i = 0; i < num_inputs; i++)
78 {
79 if (validateRequiredRank(inputs[i]))
80 return 1;
81 ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
82 elements_count += inputs[i]->getShape()[0];
83 }
84
85 ERROR_IF(elements_count != outputs[0]->getShape()[0],
86 "OpConcatShape: sum of input elements not equal to output number of elements");
87
88 num_dims = outputs[0]->getShape()[0];
89 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
90
91 return 0;
92}
93
94int OpConcatShape::eval()
95{
96 ETensor1<EigenType> out_tensor(num_dims);
97 int32_t out_idx = 0;
98 for (size_t i = 0; i < ins.size(); i++)
99 {
100 // all tosa.shape values are 1-d tensors
101 // interate in_idx in range of [0, rank of 1-d tensor]
102 for (int32_t in_idx = 0; in_idx < inputs[i]->getShape()[0]; in_idx++)
103 {
104 out_tensor(out_idx) = ins[i]->getTensor()(in_idx);
105 out_idx++;
106 }
107 }
108 out->getTensor() = out_tensor;
109 return GraphNode::eval();
110}
111
112ShapeBinaryNodeBase::ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
113 : GraphNode(sgt_, op_, id_)
114{
115 setRequiredOperands(2, 1);
116 setRequiredRank(1, 1);
117
118 fcn = [](EigenType a, EigenType b) -> EigenType { return EigenType(); };
119}
120
121ShapeBinaryNodeBase::~ShapeBinaryNodeBase()
122{}
123
124int ShapeBinaryNodeBase::checkTensorAttributes()
125{
126 if (validateRequiredOperands())
127 return 1;
128 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
129 return 1;
130
131 num_dims = outputs[0]->getShape()[0];
132
133 if (inputs[0]->getShape()[0] != num_dims)
134 {
135 std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
136 " lhs input and output rank/shape must match";
137 printNodeValidationError(err.c_str());
138 return 1;
139 }
140
141 if (inputs[1]->getShape()[0] != num_dims)
142 {
143 std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
144 " rhs input and output rank/shape must match";
145 printNodeValidationError(err.c_str());
146 return 1;
147 }
148
149 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
150 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
151 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
152
153 ASSERT_MEM(a && b && result);
154
155 return 0;
156}
157
158int ShapeBinaryNodeBase::eval()
159{
160 auto ia = a->getTensor();
161 auto ib = b->getTensor();
162 ETensor1<EigenType> out_tens(num_dims);
163 for (int32_t i = 0; i < num_dims; i++)
164 {
165 EigenType lhs = ia(i);
166 EigenType rhs = ib(i);
167 out_tens(i) = (lhs < 0 || rhs < 0) ? static_cast<EigenType>(-1) : fcn(lhs, rhs);
168 }
169
170 result->getTensor() = out_tens;
171 return GraphNode::eval();
172}
173
174int OpAddShape::register_fcn()
175{
176 fcn = [](EigenType a, EigenType b) -> EigenType { return a + b; };
177 return 0;
178}
179
180int OpSubShape::register_fcn()
181{
182 fcn = [](EigenType a, EigenType b) -> EigenType { return a - b; };
183 return 0;
184}
185
186int OpMulShape::register_fcn()
187{
188 fcn = [](EigenType a, EigenType b) -> EigenType { return a * b; };
189 return 0;
190}
191
192int OpDivShape::register_fcn()
193{
194 fcn = [](EigenType a, EigenType b) -> EigenType {
195 return (b == static_cast<EigenType>(0)) ? static_cast<EigenType>(-1) : (a / b);
196 };
197 return 0;
198}