// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#include "type_conversion.h"
17#include "quant_util.h"
18#include "template_types.h"
19#include <cmath>
James Ward8b390432022-08-12 20:48:56 +010020#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070021
22using namespace TosaReference;
23using namespace Eigen;
24using namespace tosa;
25
26template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070027OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
28 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070029 uint64_t id_)
30 : GraphNode(sgt_, Op_RESCALE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070031{
32 setRequiredOperands(1, 1);
TatWai Chongfd629052022-07-25 04:01:58 +000033 setRequiredRank(0, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -070034 INIT_ATTRIBUTE(Rescale);
35}
36
37template <int Rank, DType InDtype, DType OutDtype>
38OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
39{
40 if (attribute)
41 delete attribute;
42}
43
44template <int Rank, DType InDtype, DType OutDtype>
45int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
46{
47 if (validateRequiredOperands())
48 return 1;
49
50 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
51 {
52 return 1;
53 }
54
55 // output and input must be the same rank and size
56 if (inputs[0]->matchRankSize(*outputs[0]))
57 {
58 printNodeValidationError("OpRescale: input and output rank/size must match");
59 return 1;
60 }
61
62 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
63 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
64
65 ASSERT_MEM(in && out);
66
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010067 if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070068 {
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010069 printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
Kevin Chengcc61be32021-10-14 17:09:57 -070070 return 1;
71 }
72
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010073 if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070074 {
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010075 printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
76 return 1;
77 }
78
79 if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
80 {
81 printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
82 return 1;
83 }
84
85 if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
86 {
87 printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
Kevin Chengcc61be32021-10-14 17:09:57 -070088 return 1;
89 }
90
91 if (attribute->scale32() && (InDtype == DType_INT48))
92 {
93 printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
94 return 1;
95 }
96
97 if ((!attribute->scale32()) && attribute->double_round())
98 {
99 printNodeValidationError("OpRescale: Scale set to false but double round set to true");
100 return 1;
101 }
102
Eric Kunzee5e26762020-10-13 16:11:07 -0700103 return 0;
104}
105
106template <int Rank, DType InDtype, DType OutDtype>
107int OpRescale<Rank, InDtype, OutDtype>::eval()
108{
109 int32_t input_zp = attribute->input_zp();
110 int32_t output_zp = attribute->output_zp();
111 std::vector<int32_t> multiplier = attribute->multiplier();
112 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700113 bool scale32 = attribute->scale32();
114 bool double_round = attribute->double_round();
115 bool per_channel = attribute->per_channel();
Eric Kunzee5e26762020-10-13 16:11:07 -0700116
Eric Kunzee5e26762020-10-13 16:11:07 -0700117 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
118 Eigen::array<Eigen::Index, 2> shape_2d;
119 shape_2d[0] = 1;
120 if (Rank > 0)
121 {
122 for (int i = 0; i < Rank - 1; i++)
123 {
124 shape_2d[0] *= this->in->getShape()[i];
125 }
126 shape_2d[1] = this->in->getShape()[Rank - 1];
127 }
128 else
129 {
130 shape_2d[1] = 1;
131 }
132 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
133
134 ETensor2<OutEigenType> output_2d(shape_2d);
135
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 if (per_channel)
137 {
138 ETensor2<InEigenType> curr_channel_slice_prescaled;
139 ETensor2<OutEigenType> curr_channel_slice_postscaled;
140 int32_t channel_multiplier, channel_shift;
141 Eigen::array<Eigen::Index, 2> begin, size;
142 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700143 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700145 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700147 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
148 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
149 channel_multiplier = multiplier[i];
150 channel_shift = shift[i];
151 curr_channel_slice_postscaled =
152 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
153 double_round, scale32](InEigenType in_val) -> OutEigenType {
154 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
155 int32_t scaled;
156 if (scale32)
157 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
158 channel_shift, double_round);
159 else
160 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
161 channel_shift);
162 OutEigenType out_val = (OutEigenType)(scaled + output_zp);
163 out_val = std::max<OutEigenType>(out_val, QMin);
164 out_val = std::min<OutEigenType>(out_val, QMax);
165 return out_val;
166 });
167
168 for (int32_t j = 0; j < shape_2d[0]; j++)
169 {
170 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
171 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700172 }
173 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700174 catch (std::string desc)
175 {
176 REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str());
177 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 }
179 else
180 {
181 int32_t tensor_multiplier = multiplier[0];
182 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700183 try
184 {
185 output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
186 scale32](InEigenType in_val) -> OutEigenType {
187 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
188 int32_t scaled;
189 if (scale32)
190 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
191 double_round);
192 else
193 scaled =
194 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
195 OutEigenType out_val = (OutEigenType)(scaled + output_zp);
196 out_val = std::max<OutEigenType>(out_val, QMin);
197 out_val = std::min<OutEigenType>(out_val, QMax);
198 return out_val;
199 });
200 }
201 catch (std::string desc)
202 {
203 REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str());
204 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700205 }
206
207 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
208 Eigen::array<Eigen::Index, Rank> output_shape;
209 for (int i = 0; i < Rank; i++)
210 {
211 output_shape[i] = this->out->getShape()[i];
212 }
213 this->out->getTensor() = output_2d.reshape(output_shape);
214
215 return GraphNode::eval();
216}
217
218template <int Rank, DType InDtype, DType OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700219OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
220 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700221 uint64_t id_)
222 : GraphNode(sgt_, Op_CAST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700223{
224 setRequiredOperands(1, 1);
225 setRequiredRank(0, 6);
226}
227
228template <int Rank, DType InDtype, DType OutDtype>
229OpCast<Rank, InDtype, OutDtype>::~OpCast()
230{}
231
232template <int Rank, DType InDtype, DType OutDtype>
233int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
234{
235 if (validateRequiredOperands())
236 return 1;
237
238 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
239 {
240 return 1;
241 }
242
243 // output and input must be the same rank and size
244 if (inputs[0]->matchRankSize(*outputs[0]))
245 {
246 printNodeValidationError("OpCast: input and output rank/size must match");
247 return 1;
248 }
249
250 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
251 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
252
253 ASSERT_MEM(in && out);
254
255 return 0;
256}
257
258template <int Rank, DType InDtype, DType OutDtype>
259int OpCast<Rank, InDtype, OutDtype>::eval()
260{
261 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
262
263 return GraphNode::eval();
264}
265
266template <DType InDtype, DType OutDtype>
267CastHelper<InDtype, OutDtype>::CastHelper()
268{
269 fcn = [](InEigenType in) -> OutEigenType {
270 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271 return out;
272 };
273}
274
275template <DType InDtype>
276CastHelper<InDtype, DType_BOOL>::CastHelper()
277{
278 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
279}
280
281template <DType OutDtype>
282CastHelper<DType_BOOL, OutDtype>::CastHelper()
283{
284 fcn = [](bool in) -> OutEigenType {
285 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
286 return out;
287 };
288}
289
290template <DType InDtype>
James Ward8b390432022-08-12 20:48:56 +0100291CastHelper<InDtype, DType_FP16>::CastHelper()
292{
293 fcn = [](InEigenType in) -> float {
294 half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in); // Cast to half_float
295 return half_float::half_cast<float, half_float::half>(out); // Cast to float (underlying FP16 EigenType)
296 };
297}
298
299template <DType OutDtype>
300CastHelper<DType_FP16, OutDtype>::CastHelper()
301{
302 // Assuming InEigenType = float.
303 fcn = [](float in) -> OutEigenType {
304 // Perform initial rounding in half-precision then cast back to float
305 half_float::half h = half_float::half_cast<half_float::half, float>(in);
306 h = std::round(h);
307 OutEigenType out = half_float::half_cast<float, half_float::half>(h);
308 out = std::max<OutEigenType>(out, OutMin);
309 out = std::min<OutEigenType>(out, OutMax);
310 return out;
311 };
312}
313
314template <DType InDtype>
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100315CastHelper<InDtype, DType_FP32>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700316{
317 fcn = [](InEigenType in) -> float {
318 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
319 return out;
320 };
321}
322
323template <DType OutDtype>
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100324CastHelper<DType_FP32, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700325{
326 fcn = [](float in) -> OutEigenType {
327 OutEigenType out = std::round(in);
328 out = std::max<OutEigenType>(out, OutMin);
329 out = std::min<OutEigenType>(out, OutMax);
330 return out;
331 };
332}
333
334// template explicit instantiation
335DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
336DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
337DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
338DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
339DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
340DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
James Ward8b390432022-08-12 20:48:56 +0100341DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100342DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700343DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
344DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
345DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
James Ward8b390432022-08-12 20:48:56 +0100346DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100347DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700348DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
349DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
350DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
James Ward8b390432022-08-12 20:48:56 +0100351DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100352DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
James Ward8b390432022-08-12 20:48:56 +0100353DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
354DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
355DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100356DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
357DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
358DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700359
Kevin Cheng3a478572021-01-22 17:21:02 -0800360DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
361DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
362DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
363DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700364DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
365DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800366DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700367DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
368DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800369DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700370DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
371DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800372DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100373DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
374DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
Kevin Cheng3a478572021-01-22 17:21:02 -0800375DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100376DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
377DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);