blob: 478b7765ede361f7599c4ce27e3f9b1071c3e763 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Kevin Cheng3a478572021-01-22 17:21:02 -08002// Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "scatter_gather.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
Kevin Cheng77d0f762020-11-24 10:26:32 -080023template <DType Dtype>
24OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070025 : GraphNode(Op_GATHER, id_)
26{
27 setRequiredOperands(2, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070028}
29
Kevin Cheng77d0f762020-11-24 10:26:32 -080030template <DType Dtype>
31OpGather<Dtype>::~OpGather()
32{}
Eric Kunzee5e26762020-10-13 16:11:07 -070033
Kevin Cheng77d0f762020-11-24 10:26:32 -080034template <DType Dtype>
35int OpGather<Dtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -070036{
37 if (validateRequiredOperands())
38 return 1;
39
Kevin Cheng77d0f762020-11-24 10:26:32 -080040 if (inputs[0]->getRank() != 3)
Eric Kunzee5e26762020-10-13 16:11:07 -070041 {
Kevin Cheng77d0f762020-11-24 10:26:32 -080042 printNodeValidationError("OpGather: values needs to be rank 3 tensor");
43 return 1;
44 }
45
46 if (inputs[1]->getRank() != 2)
47 {
48 printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
49 return 1;
50 }
51
52 if (outputs[0]->getRank() != 3)
53 {
54 printNodeValidationError("OpGather: output needs to be rank 3 tensor");
55 return 1;
56 }
57
58 K = inputs[0]->getShape()[1];
59 N = outputs[0]->getShape()[0];
60 W = outputs[0]->getShape()[1];
61 C = outputs[0]->getShape()[2];
62
63 if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0])
64 {
65 printNodeValidationError("OpGather: dimension N mismatch");
66 return 1;
67 }
68
69 if (W != inputs[1]->getShape()[1])
70 {
71 printNodeValidationError("OpGather: dimension W mismatch");
72 return 1;
73 }
74
75 if (C != inputs[0]->getShape()[2])
76 {
77 printNodeValidationError("OpGather: dimension C mismatch");
Eric Kunzee5e26762020-10-13 16:11:07 -070078 return 1;
79 }
80
81 // output and input must be the same types
82 if (inputs[0]->matchType(*outputs[0]))
83 {
84 printNodeValidationError("Failure to match input and output type");
85 return 1;
86 }
87
Kevin Cheng77d0f762020-11-24 10:26:32 -080088 values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
89 indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
90 output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -070091
Kevin Cheng77d0f762020-11-24 10:26:32 -080092 ASSERT_MEM(values && indices && output);
Eric Kunzee5e26762020-10-13 16:11:07 -070093
94 return 0;
95}
96
Kevin Cheng77d0f762020-11-24 10:26:32 -080097template <DType Dtype>
98int OpGather<Dtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -070099{
Kevin Cheng77d0f762020-11-24 10:26:32 -0800100 for (int32_t n = 0; n < N; n++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700101 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800102 for (int32_t w = 0; w < W; w++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700103 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800104 int32_t k = this->indices->getTensor()(n, w);
105 ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
106 for (int32_t c = 0; c < C; c++)
107 {
108 EigenType value = this->values->getTensor()(n, k, c);
109 this->output->getTensor()(n, w, c) = value;
110 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700111 }
112 }
113
Kevin Cheng77d0f762020-11-24 10:26:32 -0800114 return GraphNode::eval();
115}
116
117template <DType Dtype>
118OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
119 : GraphNode(Op_SCATTER, id_)
120{
121 setRequiredOperands(3, 1);
122}
123
124template <DType Dtype>
125OpScatter<Dtype>::~OpScatter()
126{}
127
128template <DType Dtype>
129int OpScatter<Dtype>::checkTensorAttributes()
130{
131 if (validateRequiredOperands())
132 return 1;
133
134 if (inputs[0]->getRank() != 3)
Eric Kunzee5e26762020-10-13 16:11:07 -0700135 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800136 printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
137 return 1;
138 }
139
140 if (inputs[1]->getRank() != 2)
141 {
142 printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
143 return 1;
144 }
145
146 if (inputs[2]->getRank() != 3)
147 {
148 printNodeValidationError("OpGather: input needs to be rank 3 tensor");
149 return 1;
150 }
151
152 if (outputs[0]->getRank() != 3)
153 {
154 printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
155 return 1;
156 }
157
158 W = inputs[2]->getShape()[1];
159 N = outputs[0]->getShape()[0];
160 K = outputs[0]->getShape()[1];
161 C = outputs[0]->getShape()[2];
162
163 if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0])
164 {
165 printNodeValidationError("OpScatter: dimension N mismatch");
166 return 1;
167 }
168
169 if (W != inputs[1]->getShape()[1])
170 {
171 printNodeValidationError("OpGather: dimension W mismatch");
172 return 1;
173 }
174
175 if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
176 {
177 printNodeValidationError("OpGather: dimension C mismatch");
178 return 1;
179 }
180
181 // output and input must be the same types
182 if (inputs[0]->matchType(*outputs[0]))
183 {
184 printNodeValidationError("Failure to match input and output type");
185 return 1;
186 }
187
188 values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
189 indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
190 input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]);
191 values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
192
193 ASSERT_MEM(values_in && indices && input && values_out);
194
195 return 0;
196}
197
198template <DType Dtype>
199int OpScatter<Dtype>::eval()
200{
201 // Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
202 this->values_out->getTensor() = this->values_in->getTensor();
203
204 for (int n = 0; n < N; n++)
205 {
206 for (int w = 0; w < W; w++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800208 int32_t k = this->indices->getTensor()(n, w);
209 ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
210 for (int c = 0; c < C; c++)
211 {
212 EigenType value = this->input->getTensor()(n, w, c);
213 this->values_out->getTensor()(n, k, c) = value;
214 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 }
216 }
217
218 return GraphNode::eval();
219}
220
221// template explicit instantiation
Kevin Cheng3a478572021-01-22 17:21:02 -0800222DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800223DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
224DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
225DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
226
Kevin Cheng3a478572021-01-22 17:21:02 -0800227DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800228DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
229DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
230DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);