blob: bcd8ce5d2a23f8006d1dbe2af00d336c8d7211f2 [file] [log] [blame]

// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#include "scatter_gather.h"
17#include "quant_util.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
Kevin Cheng77d0f762020-11-24 10:26:32 -080023template <DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070024OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
25 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070026 uint64_t id_)
27 : GraphNode(sgt_, Op_GATHER, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070028{
29 setRequiredOperands(2, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070030}
31
Kevin Cheng77d0f762020-11-24 10:26:32 -080032template <DType Dtype>
33OpGather<Dtype>::~OpGather()
34{}
Eric Kunzee5e26762020-10-13 16:11:07 -070035
Kevin Cheng77d0f762020-11-24 10:26:32 -080036template <DType Dtype>
37int OpGather<Dtype>::checkTensorAttributes()
Eric Kunzee5e26762020-10-13 16:11:07 -070038{
39 if (validateRequiredOperands())
40 return 1;
41
Kevin Cheng77d0f762020-11-24 10:26:32 -080042 if (inputs[0]->getRank() != 3)
Eric Kunzee5e26762020-10-13 16:11:07 -070043 {
Kevin Cheng77d0f762020-11-24 10:26:32 -080044 printNodeValidationError("OpGather: values needs to be rank 3 tensor");
45 return 1;
46 }
47
48 if (inputs[1]->getRank() != 2)
49 {
50 printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
51 return 1;
52 }
53
54 if (outputs[0]->getRank() != 3)
55 {
56 printNodeValidationError("OpGather: output needs to be rank 3 tensor");
57 return 1;
58 }
59
60 K = inputs[0]->getShape()[1];
61 N = outputs[0]->getShape()[0];
62 W = outputs[0]->getShape()[1];
63 C = outputs[0]->getShape()[2];
64
65 if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0])
66 {
67 printNodeValidationError("OpGather: dimension N mismatch");
68 return 1;
69 }
70
71 if (W != inputs[1]->getShape()[1])
72 {
73 printNodeValidationError("OpGather: dimension W mismatch");
74 return 1;
75 }
76
77 if (C != inputs[0]->getShape()[2])
78 {
79 printNodeValidationError("OpGather: dimension C mismatch");
Eric Kunzee5e26762020-10-13 16:11:07 -070080 return 1;
81 }
82
83 // output and input must be the same types
84 if (inputs[0]->matchType(*outputs[0]))
85 {
86 printNodeValidationError("Failure to match input and output type");
87 return 1;
88 }
89
Kevin Cheng77d0f762020-11-24 10:26:32 -080090 values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
91 indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
92 output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
Eric Kunzee5e26762020-10-13 16:11:07 -070093
Kevin Cheng77d0f762020-11-24 10:26:32 -080094 ASSERT_MEM(values && indices && output);
Eric Kunzee5e26762020-10-13 16:11:07 -070095
96 return 0;
97}
98
Kevin Cheng77d0f762020-11-24 10:26:32 -080099template <DType Dtype>
100int OpGather<Dtype>::eval()
Eric Kunzee5e26762020-10-13 16:11:07 -0700101{
Kevin Cheng77d0f762020-11-24 10:26:32 -0800102 for (int32_t n = 0; n < N; n++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700103 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800104 for (int32_t w = 0; w < W; w++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700105 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800106 int32_t k = this->indices->getTensor()(n, w);
Kevin Chengacb550f2021-06-29 15:32:19 -0700107 REQUIRE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800108 for (int32_t c = 0; c < C; c++)
109 {
110 EigenType value = this->values->getTensor()(n, k, c);
111 this->output->getTensor()(n, w, c) = value;
112 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700113 }
114 }
115
Kevin Cheng77d0f762020-11-24 10:26:32 -0800116 return GraphNode::eval();
117}
118
119template <DType Dtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700120OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
121 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700122 uint64_t id_)
123 : GraphNode(sgt_, Op_SCATTER, id_)
Kevin Cheng77d0f762020-11-24 10:26:32 -0800124{
125 setRequiredOperands(3, 1);
126}
127
128template <DType Dtype>
129OpScatter<Dtype>::~OpScatter()
130{}
131
132template <DType Dtype>
133int OpScatter<Dtype>::checkTensorAttributes()
134{
135 if (validateRequiredOperands())
136 return 1;
137
138 if (inputs[0]->getRank() != 3)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800140 printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
141 return 1;
142 }
143
144 if (inputs[1]->getRank() != 2)
145 {
146 printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
147 return 1;
148 }
149
150 if (inputs[2]->getRank() != 3)
151 {
152 printNodeValidationError("OpGather: input needs to be rank 3 tensor");
153 return 1;
154 }
155
156 if (outputs[0]->getRank() != 3)
157 {
158 printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
159 return 1;
160 }
161
162 W = inputs[2]->getShape()[1];
163 N = outputs[0]->getShape()[0];
164 K = outputs[0]->getShape()[1];
165 C = outputs[0]->getShape()[2];
166
167 if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0])
168 {
169 printNodeValidationError("OpScatter: dimension N mismatch");
170 return 1;
171 }
172
173 if (W != inputs[1]->getShape()[1])
174 {
175 printNodeValidationError("OpGather: dimension W mismatch");
176 return 1;
177 }
178
179 if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
180 {
181 printNodeValidationError("OpGather: dimension C mismatch");
182 return 1;
183 }
184
185 // output and input must be the same types
186 if (inputs[0]->matchType(*outputs[0]))
187 {
188 printNodeValidationError("Failure to match input and output type");
189 return 1;
190 }
191
192 values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
193 indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
194 input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]);
195 values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
196
197 ASSERT_MEM(values_in && indices && input && values_out);
198
199 return 0;
200}
201
202template <DType Dtype>
203int OpScatter<Dtype>::eval()
204{
205 // Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
206 this->values_out->getTensor() = this->values_in->getTensor();
207
208 for (int n = 0; n < N; n++)
209 {
210 for (int w = 0; w < W; w++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700211 {
Kevin Cheng77d0f762020-11-24 10:26:32 -0800212 int32_t k = this->indices->getTensor()(n, w);
Kevin Chengacb550f2021-06-29 15:32:19 -0700213 REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800214 for (int c = 0; c < C; c++)
215 {
216 EigenType value = this->input->getTensor()(n, w, c);
217 this->values_out->getTensor()(n, k, c) = value;
218 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 }
220 }
221
222 return GraphNode::eval();
223}
224
225// template explicit instantiation
Kevin Cheng3a478572021-01-22 17:21:02 -0800226DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800227DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
228DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
James Ward8b390432022-08-12 20:48:56 +0100229DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100230DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100231DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800232
Kevin Cheng3a478572021-01-22 17:21:02 -0800233DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
Kevin Cheng77d0f762020-11-24 10:26:32 -0800234DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
235DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
James Ward8b390432022-08-12 20:48:56 +0100236DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
James Ward24dbc422022-10-19 12:20:31 +0100237DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100238DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);