blob: 02ec54f1698076627c0893793d9e34921647901b [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "scatter_gather.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
Kevin Cheng77d0f762020-11-24 10:26:32 -080023template <DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070024OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
25 TosaAttributeBase* attribute_,
26 TosaQuantInfoBase* qinfo_,
27 uint64_t id_)
28 : GraphNode(sgt_, Op_GATHER, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070029{
30 setRequiredOperands(2, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070031}
32
Kevin Cheng77d0f762020-11-24 10:26:32 -080033template <DType Dtype>
34OpGather<Dtype>::~OpGather()
35{}
Eric Kunzee5e26762020-10-13 16:11:07 -070036
Kevin Cheng77d0f762020-11-24 10:26:32 -080037template <DType Dtype>
38int OpGather<Dtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -070039{
40 if (validateRequiredOperands())
41 return 1;
42
Kevin Cheng77d0f762020-11-24 10:26:32 -080043 if (inputs[0]->getRank() != 3)
Eric Kunzee5e26762020-10-13 16:11:07 -070044 {
Kevin Cheng77d0f762020-11-24 10:26:32 -080045 printNodeValidationError("OpGather: values needs to be rank 3 tensor");
46 return 1;
47 }
48
49 if (inputs[1]->getRank() != 2)
50 {
51 printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
52 return 1;
53 }
54
55 if (outputs[0]->getRank() != 3)
56 {
57 printNodeValidationError("OpGather: output needs to be rank 3 tensor");
58 return 1;
59 }
60
61 K = inputs[0]->getShape()[1];
62 N = outputs[0]->getShape()[0];
63 W = outputs[0]->getShape()[1];
64 C = outputs[0]->getShape()[2];
65
66 if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0])
67 {
68 printNodeValidationError("OpGather: dimension N mismatch");
69 return 1;
70 }
71
72 if (W != inputs[1]->getShape()[1])
73 {
74 printNodeValidationError("OpGather: dimension W mismatch");
75 return 1;
76 }
77
78 if (C != inputs[0]->getShape()[2])
79 {
80 printNodeValidationError("OpGather: dimension C mismatch");
Eric Kunzee5e26762020-10-13 16:11:07 -070081 return 1;
82 }
83
84 // output and input must be the same types
85 if (inputs[0]->matchType(*outputs[0]))
86 {
87 printNodeValidationError("Failure to match input and output type");
88 return 1;
89 }
90
Kevin Cheng77d0f762020-11-24 10:26:32 -080091 values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
92 indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
93 output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -070094
Kevin Cheng77d0f762020-11-24 10:26:32 -080095 ASSERT_MEM(values && indices && output);
Eric Kunzee5e26762020-10-13 16:11:07 -070096
97 return 0;
98}
99
Kevin Cheng77d0f762020-11-24 10:26:32 -0800100template <DType Dtype>
101int OpGather<Dtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700102{
Kevin Cheng77d0f762020-11-24 10:26:32 -0800103 for (int32_t n = 0; n < N; n++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700104 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800105 for (int32_t w = 0; w < W; w++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700106 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800107 int32_t k = this->indices->getTensor()(n, w);
Kevin Chengacb550f2021-06-29 15:32:19 -0700108 REQUIRE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800109 for (int32_t c = 0; c < C; c++)
110 {
111 EigenType value = this->values->getTensor()(n, k, c);
112 this->output->getTensor()(n, w, c) = value;
113 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700114 }
115 }
116
Kevin Cheng77d0f762020-11-24 10:26:32 -0800117 return GraphNode::eval();
118}
119
120template <DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700121OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
122 TosaAttributeBase* attribute_,
123 TosaQuantInfoBase* qinfo_,
124 uint64_t id_)
125 : GraphNode(sgt_, Op_SCATTER, id_)
Kevin Cheng77d0f762020-11-24 10:26:32 -0800126{
127 setRequiredOperands(3, 1);
128}
129
130template <DType Dtype>
131OpScatter<Dtype>::~OpScatter()
132{}
133
134template <DType Dtype>
135int OpScatter<Dtype>::checkTensorAttributes()
136{
137 if (validateRequiredOperands())
138 return 1;
139
140 if (inputs[0]->getRank() != 3)
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800142 printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
143 return 1;
144 }
145
146 if (inputs[1]->getRank() != 2)
147 {
148 printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
149 return 1;
150 }
151
152 if (inputs[2]->getRank() != 3)
153 {
154 printNodeValidationError("OpGather: input needs to be rank 3 tensor");
155 return 1;
156 }
157
158 if (outputs[0]->getRank() != 3)
159 {
160 printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
161 return 1;
162 }
163
164 W = inputs[2]->getShape()[1];
165 N = outputs[0]->getShape()[0];
166 K = outputs[0]->getShape()[1];
167 C = outputs[0]->getShape()[2];
168
169 if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0])
170 {
171 printNodeValidationError("OpScatter: dimension N mismatch");
172 return 1;
173 }
174
175 if (W != inputs[1]->getShape()[1])
176 {
177 printNodeValidationError("OpGather: dimension W mismatch");
178 return 1;
179 }
180
181 if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
182 {
183 printNodeValidationError("OpGather: dimension C mismatch");
184 return 1;
185 }
186
187 // output and input must be the same types
188 if (inputs[0]->matchType(*outputs[0]))
189 {
190 printNodeValidationError("Failure to match input and output type");
191 return 1;
192 }
193
194 values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
195 indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
196 input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]);
197 values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
198
199 ASSERT_MEM(values_in && indices && input && values_out);
200
201 return 0;
202}
203
204template <DType Dtype>
205int OpScatter<Dtype>::eval()
206{
207 // Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
208 this->values_out->getTensor() = this->values_in->getTensor();
209
210 for (int n = 0; n < N; n++)
211 {
212 for (int w = 0; w < W; w++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800214 int32_t k = this->indices->getTensor()(n, w);
Kevin Chengacb550f2021-06-29 15:32:19 -0700215 REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800216 for (int c = 0; c < C; c++)
217 {
218 EigenType value = this->input->getTensor()(n, w, c);
219 this->values_out->getTensor()(n, k, c) = value;
220 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700221 }
222 }
223
224 return GraphNode::eval();
225}
226
227// template explicit instantiation
Kevin Cheng3a478572021-01-22 17:21:02 -0800228DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800229DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
230DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
231DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
232
Kevin Cheng3a478572021-01-22 17:21:02 -0800233DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800234DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
235DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
236DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);