blob: e46ab3897f101c20b7993e0f3027120924ff01b5 [file] [log] [blame]

// Copyright (c) 2020-2021, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

#include "type_conversion.h"
#include "quant_util.h"
#include "template_types.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <string>

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070026OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
27 TosaAttributeBase* attribute_,
28 TosaQuantInfoBase* qinfo_,
29 uint64_t id_)
30 : GraphNode(sgt_, Op_RESCALE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(1, 1);
Kevin Chengcc61be32021-10-14 17:09:57 -070033 setRequiredRank(0, 4);
Eric Kunzee5e26762020-10-13 16:11:07 -070034 INIT_ATTRIBUTE(Rescale);
35}
36
37template <int Rank, DType InDtype, DType OutDtype>
38OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
39{
40 if (attribute)
41 delete attribute;
42}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
55 // output and input must be the same rank and size
56 if (inputs[0]->matchRankSize(*outputs[0]))
57 {
58 printNodeValidationError("OpRescale: input and output rank/size must match");
59 return 1;
60 }
61
62 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
63 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
64
65 ASSERT_MEM(in && out);
66
Kevin Chengcc61be32021-10-14 17:09:57 -070067 if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
68 {
69 printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
70 return 1;
71 }
72
73 if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
74 {
75 printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
76 return 1;
77 }
78
79 if (attribute->scale32() && (InDtype == DType_INT48))
80 {
81 printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
82 return 1;
83 }
84
85 if ((!attribute->scale32()) && attribute->double_round())
86 {
87 printNodeValidationError("OpRescale: Scale set to false but double round set to true");
88 return 1;
89 }
90
Eric Kunzee5e26762020-10-13 16:11:07 -070091 return 0;
92}
93
94template <int Rank, DType InDtype, DType OutDtype>
95int OpRescale<Rank, InDtype, OutDtype>::eval()
96{
97 int32_t input_zp = attribute->input_zp();
98 int32_t output_zp = attribute->output_zp();
99 std::vector<int32_t> multiplier = attribute->multiplier();
100 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700101 bool scale32 = attribute->scale32();
102 bool double_round = attribute->double_round();
103 bool per_channel = attribute->per_channel();
Eric Kunzee5e26762020-10-13 16:11:07 -0700104
Eric Kunzee5e26762020-10-13 16:11:07 -0700105 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
106 Eigen::array<Eigen::Index, 2> shape_2d;
107 shape_2d[0] = 1;
108 if (Rank > 0)
109 {
110 for (int i = 0; i < Rank - 1; i++)
111 {
112 shape_2d[0] *= this->in->getShape()[i];
113 }
114 shape_2d[1] = this->in->getShape()[Rank - 1];
115 }
116 else
117 {
118 shape_2d[1] = 1;
119 }
120 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
121
122 ETensor2<OutEigenType> output_2d(shape_2d);
123
Eric Kunzee5e26762020-10-13 16:11:07 -0700124 if (per_channel)
125 {
126 ETensor2<InEigenType> curr_channel_slice_prescaled;
127 ETensor2<OutEigenType> curr_channel_slice_postscaled;
128 int32_t channel_multiplier, channel_shift;
129 Eigen::array<Eigen::Index, 2> begin, size;
130 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700131 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700132 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700133 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700134 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700135 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
136 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
137 channel_multiplier = multiplier[i];
138 channel_shift = shift[i];
139 curr_channel_slice_postscaled =
140 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
141 double_round, scale32](InEigenType in_val) -> OutEigenType {
142 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
143 int32_t scaled;
144 if (scale32)
145 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
146 channel_shift, double_round);
147 else
148 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
149 channel_shift);
150 OutEigenType out_val = (OutEigenType)(scaled + output_zp);
151 out_val = std::max<OutEigenType>(out_val, QMin);
152 out_val = std::min<OutEigenType>(out_val, QMax);
153 return out_val;
154 });
155
156 for (int32_t j = 0; j < shape_2d[0]; j++)
157 {
158 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
159 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 }
161 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700162 catch (std::string desc)
163 {
164 REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str());
165 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 }
167 else
168 {
169 int32_t tensor_multiplier = multiplier[0];
170 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700171 try
172 {
173 output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
174 scale32](InEigenType in_val) -> OutEigenType {
175 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
176 int32_t scaled;
177 if (scale32)
178 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
179 double_round);
180 else
181 scaled =
182 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
183 OutEigenType out_val = (OutEigenType)(scaled + output_zp);
184 out_val = std::max<OutEigenType>(out_val, QMin);
185 out_val = std::min<OutEigenType>(out_val, QMax);
186 return out_val;
187 });
188 }
189 catch (std::string desc)
190 {
191 REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str());
192 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 }
194
195 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
196 Eigen::array<Eigen::Index, Rank> output_shape;
197 for (int i = 0; i < Rank; i++)
198 {
199 output_shape[i] = this->out->getShape()[i];
200 }
201 this->out->getTensor() = output_2d.reshape(output_shape);
202
203 return GraphNode::eval();
204}
205
206template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700207OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
208 TosaAttributeBase* attribute_,
209 TosaQuantInfoBase* qinfo_,
210 uint64_t id_)
211 : GraphNode(sgt_, Op_CAST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700212{
213 setRequiredOperands(1, 1);
214 setRequiredRank(0, 6);
215}
216
217template <int Rank, DType InDtype, DType OutDtype>
218OpCast<Rank, InDtype, OutDtype>::~OpCast()
219{}
220
221template <int Rank, DType InDtype, DType OutDtype>
222int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
223{
224 if (validateRequiredOperands())
225 return 1;
226
227 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
228 {
229 return 1;
230 }
231
232 // output and input must be the same rank and size
233 if (inputs[0]->matchRankSize(*outputs[0]))
234 {
235 printNodeValidationError("OpCast: input and output rank/size must match");
236 return 1;
237 }
238
239 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
240 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
241
242 ASSERT_MEM(in && out);
243
244 return 0;
245}
246
247template <int Rank, DType InDtype, DType OutDtype>
248int OpCast<Rank, InDtype, OutDtype>::eval()
249{
250 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
251
252 return GraphNode::eval();
253}
254
255template <DType InDtype, DType OutDtype>
256CastHelper<InDtype, OutDtype>::CastHelper()
257{
258 fcn = [](InEigenType in) -> OutEigenType {
259 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700260 return out;
261 };
262}
263
264template <DType InDtype>
265CastHelper<InDtype, DType_BOOL>::CastHelper()
266{
267 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
268}
269
270template <DType OutDtype>
271CastHelper<DType_BOOL, OutDtype>::CastHelper()
272{
273 fcn = [](bool in) -> OutEigenType {
274 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
275 return out;
276 };
277}
278
279template <DType InDtype>
280CastHelper<InDtype, DType_FLOAT>::CastHelper()
281{
282 fcn = [](InEigenType in) -> float {
283 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
284 return out;
285 };
286}
287
288template <DType OutDtype>
289CastHelper<DType_FLOAT, OutDtype>::CastHelper()
290{
291 fcn = [](float in) -> OutEigenType {
292 OutEigenType out = std::round(in);
293 out = std::max<OutEigenType>(out, OutMin);
294 out = std::min<OutEigenType>(out, OutMax);
295 return out;
296 };
297}
298
299// template explicit instantiation
300DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
301DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
302DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
303DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
304DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
305DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
306DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
307DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
308DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
309DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
310DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
311DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
312DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
313DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
314DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
315DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
316DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
317DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
318
Kevin Cheng3a478572021-01-22 17:21:02 -0800319DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
320DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
321DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
322DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700323DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
324DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800325DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700326DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
327DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800328DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700329DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
330DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800331DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
332DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);