blob: 425dfc2660e9f640331d5c93b3239c3dc305ffee [file] [log] [blame]
// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#include "shape.h"
16
17using namespace TosaReference;
18using namespace Eigen;
19using namespace tosa;
20
21OpConstShape::OpConstShape(SubgraphTraverser* sgt_, uint64_t id_)
22 : GraphNode(sgt_, Op_CONST, id_)
23{
24 setRequiredOperands(0, 1);
25}
26
27OpConstShape::~OpConstShape()
28{}
29
30int OpConstShape::checkTensorAttributes()
31{
32 if (validateRequiredOperands())
33 return 1;
34
35 return 0;
36}
37
38int OpConstShape::eval()
39{
Jerry Ge12159fc2024-04-01 17:05:10 +000040 // set the shapeValue given the actual tensor value
41 using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
42 auto out = dynamic_cast<TosaReference::TensorTemplate<Eigen::Tensor<EigenType, 1>>*>(this->getOutputs()[0]);
43
44 std::vector<int> shapeValue;
45 for (int i = 0; out != nullptr && i < out->getTensor().size(); ++i)
46 {
47 shapeValue.push_back(out->getTensor()(i));
48 }
49
50 this->getOutputs()[0]->setShapeValue(shapeValue);
51
Tai Ly8690a082023-12-18 20:40:24 +000052 for (auto ct : getOutputs())
53 {
54 if (!ct->getIsValid())
55 {
56 std::string err = "Constant Shape tensor " + ct->getName() + " not correctly initialized";
57 printNodeValidationError(err.c_str());
58 return 1;
59 }
60 }
61
62 // Evaluation is trivial for constants
63 return GraphNode::eval();
64}
65
66OpConcatShape::OpConcatShape(SubgraphTraverser* sgt_, uint64_t id_)
67 : GraphNode(sgt_, Op_CONCAT_SHAPE, id_)
68{
69 setRequiredOperands(-1, 1);
70 setRequiredRank(1, 1);
71}
72
73OpConcatShape::~OpConcatShape()
74{}
75
76int OpConcatShape::checkTensorAttributes()
77{
78 if (validateRequiredOperands())
79 return 1;
80
81 if (inputs.empty())
82 {
83 printNodeValidationError("ConcatShape operator must have at least one input tensor");
84 return 1;
85 }
86
87 int32_t num_inputs = inputs.size();
88 int32_t elements_count = 0;
89 for (int32_t i = 0; i < num_inputs; i++)
90 {
91 if (validateRequiredRank(inputs[i]))
92 return 1;
93 ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
94 elements_count += inputs[i]->getShape()[0];
95 }
96
97 ERROR_IF(elements_count != outputs[0]->getShape()[0],
98 "OpConcatShape: sum of input elements not equal to output number of elements");
99
100 num_dims = outputs[0]->getShape()[0];
101 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
102
103 return 0;
104}
105
106int OpConcatShape::eval()
107{
108 ETensor1<EigenType> out_tensor(num_dims);
109 int32_t out_idx = 0;
110 for (size_t i = 0; i < ins.size(); i++)
111 {
112 // all tosa.shape values are 1-d tensors
113 // interate in_idx in range of [0, rank of 1-d tensor]
114 for (int32_t in_idx = 0; in_idx < inputs[i]->getShape()[0]; in_idx++)
115 {
116 out_tensor(out_idx) = ins[i]->getTensor()(in_idx);
117 out_idx++;
118 }
119 }
120 out->getTensor() = out_tensor;
Jerry Ge12159fc2024-04-01 17:05:10 +0000121
122 // set the shapeValue given the actual tensor value
123 std::vector<int> shapeValue;
124 for (int i = 0; i < out->getTensor().size(); ++i)
125 {
126 shapeValue.push_back(out->getTensor()(i));
127 }
128 this->getOutputs()[0]->setShapeValue(shapeValue);
129
Tai Ly8690a082023-12-18 20:40:24 +0000130 return GraphNode::eval();
131}
132
133ShapeBinaryNodeBase::ShapeBinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
134 : GraphNode(sgt_, op_, id_)
135{
136 setRequiredOperands(2, 1);
137 setRequiredRank(1, 1);
138
139 fcn = [](EigenType a, EigenType b) -> EigenType { return EigenType(); };
140}
141
142ShapeBinaryNodeBase::~ShapeBinaryNodeBase()
143{}
144
145int ShapeBinaryNodeBase::checkTensorAttributes()
146{
147 if (validateRequiredOperands())
148 return 1;
149 if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
150 return 1;
151
152 num_dims = outputs[0]->getShape()[0];
153
154 if (inputs[0]->getShape()[0] != num_dims)
155 {
156 std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
157 " lhs input and output rank/shape must match";
158 printNodeValidationError(err.c_str());
159 return 1;
160 }
161
162 if (inputs[1]->getShape()[0] != num_dims)
163 {
164 std::string err = "Binary shape operators " + std::string(EnumNamesOp()[nodeType]) +
165 " rhs input and output rank/shape must match";
166 printNodeValidationError(err.c_str());
167 return 1;
168 }
169
170 a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
171 b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
172 result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
173
174 ASSERT_MEM(a && b && result);
175
176 return 0;
177}
178
179int ShapeBinaryNodeBase::eval()
180{
181 auto ia = a->getTensor();
182 auto ib = b->getTensor();
183 ETensor1<EigenType> out_tens(num_dims);
184 for (int32_t i = 0; i < num_dims; i++)
185 {
186 EigenType lhs = ia(i);
187 EigenType rhs = ib(i);
188 out_tens(i) = (lhs < 0 || rhs < 0) ? static_cast<EigenType>(-1) : fcn(lhs, rhs);
189 }
190
191 result->getTensor() = out_tens;
Jerry Ge12159fc2024-04-01 17:05:10 +0000192
193 // set the shapeValue given the actual tensor value
194 std::vector<int> shapeValue;
195 for (int i = 0; i < result->getTensor().size(); ++i)
196 {
197 shapeValue.push_back(result->getTensor()(i));
198 }
199 this->getOutputs()[0]->setShapeValue(shapeValue);
200
Tai Ly8690a082023-12-18 20:40:24 +0000201 return GraphNode::eval();
202}
203
204int OpAddShape::register_fcn()
205{
206 fcn = [](EigenType a, EigenType b) -> EigenType { return a + b; };
207 return 0;
208}
209
210int OpSubShape::register_fcn()
211{
212 fcn = [](EigenType a, EigenType b) -> EigenType { return a - b; };
213 return 0;
214}
215
216int OpMulShape::register_fcn()
217{
218 fcn = [](EigenType a, EigenType b) -> EigenType { return a * b; };
219 return 0;
220}
221
222int OpDivShape::register_fcn()
223{
224 fcn = [](EigenType a, EigenType b) -> EigenType {
225 return (b == static_cast<EigenType>(0)) ? static_cast<EigenType>(-1) : (a / b);
226 };
227 return 0;
228}