blob: 484f7681ca573d8a9d74619257a00044bd6cf580 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ged6a04612023-12-15 22:45:39 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "type_conversion.h"
James Ward736fd1a2023-01-23 17:13:37 +000017#include "arith_util.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000018#include "half.hpp"
19#include "quant_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070020#include "template_types.h"
21#include <cmath>
22
23using namespace TosaReference;
24using namespace Eigen;
25using namespace tosa;
26
Tai Lya4d748b2023-03-28 22:06:56 +000027template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +000028OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070029 : GraphNode(sgt_, Op_RESCALE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 INIT_ATTRIBUTE(Rescale);
33}
34
Tai Lya4d748b2023-03-28 22:06:56 +000035template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070036OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
37{
38 if (attribute)
39 delete attribute;
40}
41
Tai Lya4d748b2023-03-28 22:06:56 +000042template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070043int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
44{
Jerry Gea793f462023-04-11 00:05:02 +000045 // Check Tosa Level
46 auto tosa_level = g_func_config.tosa_level;
47 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
48
Eric Kunzee5e26762020-10-13 16:11:07 -070049 if (validateRequiredOperands())
50 return 1;
51
Eric Kunzee5e26762020-10-13 16:11:07 -070052 // output and input must be the same rank and size
53 if (inputs[0]->matchRankSize(*outputs[0]))
54 {
55 printNodeValidationError("OpRescale: input and output rank/size must match");
56 return 1;
57 }
58
59 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
60 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
61
62 ASSERT_MEM(in && out);
63
Tai Lya4d748b2023-03-28 22:06:56 +000064 if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
65 (attribute->input_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070066 {
Tai Lya4d748b2023-03-28 22:06:56 +000067 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Kevin Chengcc61be32021-10-14 17:09:57 -070068 return 1;
69 }
70
Tai Lya4d748b2023-03-28 22:06:56 +000071 if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
72 (attribute->output_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070073 {
Tai Lya4d748b2023-03-28 22:06:56 +000074 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010075 return 1;
76 }
77
Tai Lya4d748b2023-03-28 22:06:56 +000078 if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010079 {
Tai Lya4d748b2023-03-28 22:06:56 +000080 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010081 return 1;
82 }
83
Tai Lya4d748b2023-03-28 22:06:56 +000084 if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010085 {
Tai Lya4d748b2023-03-28 22:06:56 +000086 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Kevin Chengcc61be32021-10-14 17:09:57 -070087 return 1;
88 }
89
Tai Lya4d748b2023-03-28 22:06:56 +000090 if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
Kevin Chengcc61be32021-10-14 17:09:57 -070091 {
92 printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
93 return 1;
94 }
95
96 if ((!attribute->scale32()) && attribute->double_round())
97 {
98 printNodeValidationError("OpRescale: Scale set to false but double round set to true");
99 return 1;
100 }
101
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 return 0;
103}
104
// helpers to convert types

// Zero-extend an 8-bit value to int64 by reinterpreting its bits as unsigned,
// e.g. int8_t(-1) -> 255. The unsigned conversion is well-defined modular arithmetic,
// so no pointer punning is needed.
static int64_t zero_extend(int8_t val)
{
    return static_cast<int64_t>(static_cast<uint8_t>(val));
}

// Zero-extend a 16-bit value to int64 by reinterpreting its bits as unsigned,
// e.g. int16_t(-1) -> 65535.
static int64_t zero_extend(int16_t val)
{
    return static_cast<int64_t>(static_cast<uint16_t>(val));
}
116
Tai Lya4d748b2023-03-28 22:06:56 +0000117template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700118int OpRescale<Rank, InDtype, OutDtype>::eval()
119{
120 int32_t input_zp = attribute->input_zp();
121 int32_t output_zp = attribute->output_zp();
122 std::vector<int32_t> multiplier = attribute->multiplier();
123 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700124 bool scale32 = attribute->scale32();
125 bool double_round = attribute->double_round();
126 bool per_channel = attribute->per_channel();
Eric Kunzea8098c02023-09-07 00:31:54 +0000127 bool input_unsigned = attribute->input_unsigned();
128 bool output_unsigned = attribute->output_unsigned();
Eric Kunzee5e26762020-10-13 16:11:07 -0700129
Eric Kunzee5e26762020-10-13 16:11:07 -0700130 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
131 Eigen::array<Eigen::Index, 2> shape_2d;
132 shape_2d[0] = 1;
133 if (Rank > 0)
134 {
135 for (int i = 0; i < Rank - 1; i++)
136 {
137 shape_2d[0] *= this->in->getShape()[i];
138 }
139 shape_2d[1] = this->in->getShape()[Rank - 1];
140 }
141 else
142 {
143 shape_2d[1] = 1;
144 }
145 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
146
147 ETensor2<OutEigenType> output_2d(shape_2d);
148
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 if (per_channel)
150 {
151 ETensor2<InEigenType> curr_channel_slice_prescaled;
152 ETensor2<OutEigenType> curr_channel_slice_postscaled;
153 int32_t channel_multiplier, channel_shift;
154 Eigen::array<Eigen::Index, 2> begin, size;
155 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700156 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700157 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700158 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700159 {
Eric Kunzea8098c02023-09-07 00:31:54 +0000160 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
161 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
162 channel_multiplier = multiplier[i];
163 channel_shift = shift[i];
164 curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr(
165 [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned,
166 output_unsigned](InEigenType in_val) -> OutEigenType {
167 int64_t input_zp_shifted;
168 if (input_unsigned)
169 {
170 int64_t in_val64;
171 int64_t in_zp64;
172 switch (GetNumBits<InDtype>::value)
173 {
174 case 8:
175 in_val64 = zero_extend(static_cast<int8_t>(in_val));
176 in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
177 break;
178 case 16:
179 in_val64 = zero_extend(static_cast<int16_t>(in_val));
180 in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
181 break;
182 default:
183 in_val64 = static_cast<int64_t>(in_val);
184 in_zp64 = static_cast<int64_t>(input_zp);
185 break;
186 }
187 input_zp_shifted = in_val64 - in_zp64;
188 }
189 else
190 {
191 input_zp_shifted = in_val - input_zp;
192 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700193 int32_t scaled;
194 if (scale32)
Eric Kunzea8098c02023-09-07 00:31:54 +0000195 scaled = TosaReference::QuantUtil::apply_scale_32(static_cast<int32_t>(input_zp_shifted),
196 channel_multiplier, channel_shift,
197 double_round);
Kevin Chengacb550f2021-06-29 15:32:19 -0700198 else
199 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
200 channel_shift);
Eric Kunzea8098c02023-09-07 00:31:54 +0000201 int64_t output_zp_extended;
202 if (output_unsigned)
203 {
204 switch (GetNumBits<OutDtype>::value)
205 {
206 case 8:
207 output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
208 break;
209 case 16:
210 output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
211 break;
212 default:
213 output_zp_extended = static_cast<int64_t>(output_zp);
214 break;
215 }
216 }
217 else
218 {
219 output_zp_extended = static_cast<int64_t>(output_zp);
220 }
221 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000222 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
223 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
224 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
225 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000226 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
227 std::to_string(output_zp) + "] not in i32 range";
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000228 throw desc;
229 }
230 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700231 out_val = std::max<OutEigenType>(out_val, QMin);
232 out_val = std::min<OutEigenType>(out_val, QMax);
233 return out_val;
234 });
235
236 for (int32_t j = 0; j < shape_2d[0]; j++)
237 {
238 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
239 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700240 }
241 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700242 catch (std::string desc)
243 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000244 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700245 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700246 }
247 else
248 {
249 int32_t tensor_multiplier = multiplier[0];
250 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700251 try
252 {
Eric Kunzea8098c02023-09-07 00:31:54 +0000253 output_2d =
254 input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32,
255 input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType {
256 int64_t input_zp_shifted;
257 if (input_unsigned)
258 {
259 int64_t in_val64;
260 int64_t in_zp64;
261 switch (GetNumBits<InDtype>::value)
262 {
263 case 8:
264 in_val64 = zero_extend(static_cast<int8_t>(in_val));
265 in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
266 break;
267 case 16:
268 in_val64 = zero_extend(static_cast<int16_t>(in_val));
269 in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
270 break;
271 default:
272 in_val64 = static_cast<int64_t>(in_val);
273 in_zp64 = static_cast<int64_t>(input_zp);
274 break;
275 }
276 input_zp_shifted = in_val64 - in_zp64;
277 }
278 else
279 {
280 input_zp_shifted = in_val - input_zp;
281 }
282 int32_t scaled;
283 if (scale32)
284 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
285 tensor_shift, double_round);
286 else
287 scaled =
288 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000289
Eric Kunzea8098c02023-09-07 00:31:54 +0000290 int64_t output_zp_extended;
291 if (output_unsigned)
292 {
293 switch (GetNumBits<OutDtype>::value)
294 {
295 case 8:
296 output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
297 break;
298 case 16:
299 output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
300 break;
301 default:
302 output_zp_extended = static_cast<int64_t>(output_zp);
303 break;
304 }
305 }
306 else
307 {
308 output_zp_extended = static_cast<int64_t>(output_zp);
309 }
310 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
311 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
312 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
313 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
314 {
315 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
316 std::to_string(output_zp) + "] not in i32 range";
317 throw desc;
318 }
319
320 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
321 out_val = std::max<OutEigenType>(out_val, QMin);
322 out_val = std::min<OutEigenType>(out_val, QMax);
323 return out_val;
324 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700325 }
326 catch (std::string desc)
327 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000328 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700329 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700330 }
331
332 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
333 Eigen::array<Eigen::Index, Rank> output_shape;
334 for (int i = 0; i < Rank; i++)
335 {
336 output_shape[i] = this->out->getShape()[i];
337 }
338 this->out->getTensor() = output_2d.reshape(output_shape);
339
340 return GraphNode::eval();
341}
342
Tai Lya4d748b2023-03-28 22:06:56 +0000343template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000344OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700345 : GraphNode(sgt_, Op_CAST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700346{
347 setRequiredOperands(1, 1);
348 setRequiredRank(0, 6);
349}
350
Tai Lya4d748b2023-03-28 22:06:56 +0000351template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700352OpCast<Rank, InDtype, OutDtype>::~OpCast()
353{}
354
Tai Lya4d748b2023-03-28 22:06:56 +0000355template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700356int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
357{
Jerry Gea793f462023-04-11 00:05:02 +0000358 // Check Tosa Level
359 auto tosa_level = g_func_config.tosa_level;
360 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
361
Eric Kunzee5e26762020-10-13 16:11:07 -0700362 if (validateRequiredOperands())
363 return 1;
364
365 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
366 {
367 return 1;
368 }
369
370 // output and input must be the same rank and size
371 if (inputs[0]->matchRankSize(*outputs[0]))
372 {
373 printNodeValidationError("OpCast: input and output rank/size must match");
374 return 1;
375 }
376
377 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
378 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
379
380 ASSERT_MEM(in && out);
381
382 return 0;
383}
384
Tai Lya4d748b2023-03-28 22:06:56 +0000385template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700386int OpCast<Rank, InDtype, OutDtype>::eval()
387{
388 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
389
390 return GraphNode::eval();
391}
392
Tai Lya4d748b2023-03-28 22:06:56 +0000393template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700394CastHelper<InDtype, OutDtype>::CastHelper()
395{
396 fcn = [](InEigenType in) -> OutEigenType {
397 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700398 return out;
399 };
400}
401
Tai Lya4d748b2023-03-28 22:06:56 +0000402template <TOSA_REF_TYPE InDtype>
403CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700404{
405 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
406}
407
Tai Lya4d748b2023-03-28 22:06:56 +0000408template <TOSA_REF_TYPE OutDtype>
409CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700410{
411 fcn = [](bool in) -> OutEigenType {
412 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
413 return out;
414 };
415}
416
Tai Lya4d748b2023-03-28 22:06:56 +0000417template <TOSA_REF_TYPE InDtype>
418CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100419{
James Ward736fd1a2023-01-23 17:13:37 +0000420 // Integer data converted to fp16 (stored as fp32)
James Ward8b390432022-08-12 20:48:56 +0100421 fcn = [](InEigenType in) -> float {
James Ward736fd1a2023-01-23 17:13:37 +0000422 half_float::half h = half_float::half(in);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000423 float out = half_float::half_cast<float, half_float::half>(h);
James Ward736fd1a2023-01-23 17:13:37 +0000424 return out;
425 };
426}
427
Tai Lya4d748b2023-03-28 22:06:56 +0000428CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000429{
430 // fp32 data converted to fp16 (stored as fp32)
431 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000432 float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
James Ward736fd1a2023-01-23 17:13:37 +0000433 return out;
434 };
435}
436
Tai Lya4d748b2023-03-28 22:06:56 +0000437template <TOSA_REF_TYPE InDtype>
438CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000439{
440 // Integer data converted to bf16 (stored as fp32)
441 fcn = [](InEigenType in) -> float {
442 float out = (float)in; // default cast to float is round_to_nearest_float()
443 return out;
444 };
445}
446
Tai Lya4d748b2023-03-28 22:06:56 +0000447CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000448{
449 // fp32 data converted to bf16 (stored as fp32)
450 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000451 return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
James Ward8b390432022-08-12 20:48:56 +0100452 };
453}
454
Tai Lya4d748b2023-03-28 22:06:56 +0000455template <TOSA_REF_TYPE OutDtype>
456CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100457{
James Ward736fd1a2023-01-23 17:13:37 +0000458 // fp16 data (stored as fp32) converted to integer
James Ward8b390432022-08-12 20:48:56 +0100459 fcn = [](float in) -> OutEigenType {
James Ward736fd1a2023-01-23 17:13:37 +0000460 // Cast from float representation back to half_float before rounding
461 half_float::half h = half_float::half(in);
Jerry Ged6a04612023-12-15 22:45:39 +0000462 if (h >= half_float::half(float(OutMax)))
463 return OutMax;
464
465 if (h <= half_float::half(float(OutMin)))
466 return OutMin;
467
468 h = std::rint(h);
469 OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h);
470
James Ward8b390432022-08-12 20:48:56 +0100471 return out;
472 };
473}
474
Tai Lya4d748b2023-03-28 22:06:56 +0000475CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000476{
477 // No-op since fp16 values treated internally as their fp32 representation
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000478 fcn = [](float in) -> OutEigenType { return in; };
James Ward736fd1a2023-01-23 17:13:37 +0000479}
480
Tai Lya4d748b2023-03-28 22:06:56 +0000481template <TOSA_REF_TYPE OutDtype>
482CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000483{
484 // bf16 data (stored as fp32) converted to integer
485 fcn = [](float in) -> OutEigenType {
Jerry Ged6a04612023-12-15 22:45:39 +0000486 if (in >= float(OutMax))
487 return OutMax;
488
489 if (in <= float(OutMin))
490 return OutMin;
491
492 OutEigenType out = std::rint(in);
James Ward736fd1a2023-01-23 17:13:37 +0000493 return out;
494 };
495}
496
Tai Lya4d748b2023-03-28 22:06:56 +0000497CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000498{
499 // No-op since bf16 values treated as truncated fp32 internally
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000500 fcn = [](InEigenType in) -> OutEigenType { return in; };
James Ward736fd1a2023-01-23 17:13:37 +0000501}
502
Tai Lya4d748b2023-03-28 22:06:56 +0000503template <TOSA_REF_TYPE InDtype>
504CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700505{
James Ward736fd1a2023-01-23 17:13:37 +0000506 // Integer data converted to fp32
Eric Kunzee5e26762020-10-13 16:11:07 -0700507 fcn = [](InEigenType in) -> float {
508 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
509 return out;
510 };
511}
512
Tai Lya4d748b2023-03-28 22:06:56 +0000513template <TOSA_REF_TYPE OutDtype>
514CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700515{
James Ward736fd1a2023-01-23 17:13:37 +0000516 // fp32 data converted to integer
Eric Kunzee5e26762020-10-13 16:11:07 -0700517 fcn = [](float in) -> OutEigenType {
Jerry Ge44827be2023-12-15 00:05:37 +0000518 if (in >= float(OutMax))
519 return OutMax;
520
521 if (in <= float(OutMin))
522 return OutMin;
523
Eric Kunze57bc0792023-01-25 10:05:51 -0800524 OutEigenType out = std::rint(in);
Eric Kunzee5e26762020-10-13 16:11:07 -0700525 return out;
526 };
527}
528
Tai Lya4d748b2023-03-28 22:06:56 +0000529template <TOSA_REF_TYPE OutDtype>
530CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
531{
532 switch (OutDtype)
533 {
534 case TOSA_REF_TYPE_INT8:
535 case TOSA_REF_TYPE_INT16:
536 case TOSA_REF_TYPE_INT32:
537 // fp64 data converted to integer
538 fcn = [](InEigenType in) -> OutEigenType {
Jerry Ged6a04612023-12-15 22:45:39 +0000539 if (in >= double(OutMax))
540 return OutMax;
541
542 if (in <= double(OutMin))
543 return OutMin;
544
Tai Lya4d748b2023-03-28 22:06:56 +0000545 OutEigenType out = std::rint(in);
Tai Lya4d748b2023-03-28 22:06:56 +0000546 return out;
547 };
548 break;
549 case TOSA_REF_TYPE_FP64:
550 // no op
551 fcn = [](InEigenType in) -> OutEigenType { return in; };
552 break;
553 default:
554 ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
555 }
556}
557
Eric Kunzee5e26762020-10-13 16:11:07 -0700558// template explicit instantiation
559DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
560DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
561DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
562DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
563DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
564DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
James Ward8b390432022-08-12 20:48:56 +0100565DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000566DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100567DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700568DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
569DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
570DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
James Ward8b390432022-08-12 20:48:56 +0100571DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000572DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100573DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700574DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
575DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
576DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
James Ward8b390432022-08-12 20:48:56 +0100577DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000578DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100579DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
James Ward8b390432022-08-12 20:48:56 +0100580DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
581DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
582DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000583DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
James Ward24dbc422022-10-19 12:20:31 +0100584DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
585DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
586DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000587DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100588DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
589DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
590DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000591DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
592DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
Tai Lya4d748b2023-03-28 22:06:56 +0000593DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
594DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
595DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
596DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
597DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
598DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
599DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700600
Kevin Cheng3a478572021-01-22 17:21:02 -0800601DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
602DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
603DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
604DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700605DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
606DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800607DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700608DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
609DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800610DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700611DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
612DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800613DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100614DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
615DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
Kevin Cheng3a478572021-01-22 17:21:02 -0800616DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100617DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
618DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);