blob: 5dbc7bd79a4c0c85f509f6c30a0ff78fe5ffad22 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ged6a04612023-12-15 22:45:39 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "type_conversion.h"
James Ward736fd1a2023-01-23 17:13:37 +000017#include "arith_util.h"
Won Jeon2c34b462024-02-06 18:37:00 +000018#include "float_utils.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000019#include "half.hpp"
20#include "quant_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "template_types.h"
22#include <cmath>
23
24using namespace TosaReference;
25using namespace Eigen;
26using namespace tosa;
27
Won Jeon2c34b462024-02-06 18:37:00 +000028using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;
29using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;
30using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
31using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
32using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;
33
Tai Lya4d748b2023-03-28 22:06:56 +000034template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +000035OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070036 : GraphNode(sgt_, Op_RESCALE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070037{
38 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070039 INIT_ATTRIBUTE(Rescale);
40}
41
Tai Lya4d748b2023-03-28 22:06:56 +000042template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070043OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
44{
45 if (attribute)
46 delete attribute;
47}
48
Tai Lya4d748b2023-03-28 22:06:56 +000049template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070050int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
51{
Jerry Gea793f462023-04-11 00:05:02 +000052 // Check Tosa Level
53 auto tosa_level = g_func_config.tosa_level;
54 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
55
Eric Kunzee5e26762020-10-13 16:11:07 -070056 if (validateRequiredOperands())
57 return 1;
58
Eric Kunzee5e26762020-10-13 16:11:07 -070059 // output and input must be the same rank and size
60 if (inputs[0]->matchRankSize(*outputs[0]))
61 {
62 printNodeValidationError("OpRescale: input and output rank/size must match");
63 return 1;
64 }
65
66 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
67 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
68
69 ASSERT_MEM(in && out);
70
Tai Lya4d748b2023-03-28 22:06:56 +000071 if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
72 (attribute->input_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070073 {
Tai Lya4d748b2023-03-28 22:06:56 +000074 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Kevin Chengcc61be32021-10-14 17:09:57 -070075 return 1;
76 }
77
Tai Lya4d748b2023-03-28 22:06:56 +000078 if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
79 (attribute->output_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070080 {
Tai Lya4d748b2023-03-28 22:06:56 +000081 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010082 return 1;
83 }
84
Tai Lya4d748b2023-03-28 22:06:56 +000085 if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010086 {
Tai Lya4d748b2023-03-28 22:06:56 +000087 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010088 return 1;
89 }
90
Tai Lya4d748b2023-03-28 22:06:56 +000091 if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010092 {
Tai Lya4d748b2023-03-28 22:06:56 +000093 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Kevin Chengcc61be32021-10-14 17:09:57 -070094 return 1;
95 }
96
Tai Lya4d748b2023-03-28 22:06:56 +000097 if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
Kevin Chengcc61be32021-10-14 17:09:57 -070098 {
99 printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
100 return 1;
101 }
102
103 if ((!attribute->scale32()) && attribute->double_round())
104 {
105 printNodeValidationError("OpRescale: Scale set to false but double round set to true");
106 return 1;
107 }
108
Eric Kunzee5e26762020-10-13 16:11:07 -0700109 return 0;
110}
111
// Helpers to convert types: treat a narrow signed value as its unsigned bit pattern and
// widen it to int64_t, e.g. int8_t(-1) -> 255 and int16_t(-1) -> 65535.
static int64_t zero_extend(int8_t val)
{
    const uint8_t bits = static_cast<uint8_t>(val);
    return static_cast<int64_t>(bits);
}

static int64_t zero_extend(int16_t val)
{
    const uint16_t bits = static_cast<uint16_t>(val);
    return static_cast<int64_t>(bits);
}
123
Tai Lya4d748b2023-03-28 22:06:56 +0000124template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700125int OpRescale<Rank, InDtype, OutDtype>::eval()
126{
127 int32_t input_zp = attribute->input_zp();
128 int32_t output_zp = attribute->output_zp();
129 std::vector<int32_t> multiplier = attribute->multiplier();
130 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700131 bool scale32 = attribute->scale32();
132 bool double_round = attribute->double_round();
133 bool per_channel = attribute->per_channel();
Eric Kunzea8098c02023-09-07 00:31:54 +0000134 bool input_unsigned = attribute->input_unsigned();
135 bool output_unsigned = attribute->output_unsigned();
Eric Kunzee5e26762020-10-13 16:11:07 -0700136
Eric Kunzee5e26762020-10-13 16:11:07 -0700137 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
138 Eigen::array<Eigen::Index, 2> shape_2d;
139 shape_2d[0] = 1;
140 if (Rank > 0)
141 {
142 for (int i = 0; i < Rank - 1; i++)
143 {
144 shape_2d[0] *= this->in->getShape()[i];
145 }
146 shape_2d[1] = this->in->getShape()[Rank - 1];
147 }
148 else
149 {
150 shape_2d[1] = 1;
151 }
152 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
153
154 ETensor2<OutEigenType> output_2d(shape_2d);
155
Eric Kunzee5e26762020-10-13 16:11:07 -0700156 if (per_channel)
157 {
158 ETensor2<InEigenType> curr_channel_slice_prescaled;
159 ETensor2<OutEigenType> curr_channel_slice_postscaled;
160 int32_t channel_multiplier, channel_shift;
161 Eigen::array<Eigen::Index, 2> begin, size;
162 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700163 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700165 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 {
Eric Kunzea8098c02023-09-07 00:31:54 +0000167 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
168 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
169 channel_multiplier = multiplier[i];
170 channel_shift = shift[i];
171 curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr(
172 [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned,
173 output_unsigned](InEigenType in_val) -> OutEigenType {
174 int64_t input_zp_shifted;
175 if (input_unsigned)
176 {
177 int64_t in_val64;
178 int64_t in_zp64;
179 switch (GetNumBits<InDtype>::value)
180 {
181 case 8:
182 in_val64 = zero_extend(static_cast<int8_t>(in_val));
183 in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
184 break;
185 case 16:
186 in_val64 = zero_extend(static_cast<int16_t>(in_val));
187 in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
188 break;
189 default:
190 in_val64 = static_cast<int64_t>(in_val);
191 in_zp64 = static_cast<int64_t>(input_zp);
192 break;
193 }
194 input_zp_shifted = in_val64 - in_zp64;
195 }
196 else
197 {
198 input_zp_shifted = in_val - input_zp;
199 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700200 int32_t scaled;
201 if (scale32)
Eric Kunzea8098c02023-09-07 00:31:54 +0000202 scaled = TosaReference::QuantUtil::apply_scale_32(static_cast<int32_t>(input_zp_shifted),
203 channel_multiplier, channel_shift,
204 double_round);
Kevin Chengacb550f2021-06-29 15:32:19 -0700205 else
206 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
207 channel_shift);
Eric Kunzea8098c02023-09-07 00:31:54 +0000208 int64_t output_zp_extended;
209 if (output_unsigned)
210 {
211 switch (GetNumBits<OutDtype>::value)
212 {
213 case 8:
214 output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
215 break;
216 case 16:
217 output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
218 break;
219 default:
220 output_zp_extended = static_cast<int64_t>(output_zp);
221 break;
222 }
223 }
224 else
225 {
226 output_zp_extended = static_cast<int64_t>(output_zp);
227 }
228 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000229 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
230 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
231 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
232 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000233 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
234 std::to_string(output_zp) + "] not in i32 range";
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000235 throw desc;
236 }
237 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700238 out_val = std::max<OutEigenType>(out_val, QMin);
239 out_val = std::min<OutEigenType>(out_val, QMax);
240 return out_val;
241 });
242
243 for (int32_t j = 0; j < shape_2d[0]; j++)
244 {
245 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
246 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700247 }
248 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700249 catch (std::string desc)
250 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000251 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700252 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700253 }
254 else
255 {
256 int32_t tensor_multiplier = multiplier[0];
257 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700258 try
259 {
Eric Kunzea8098c02023-09-07 00:31:54 +0000260 output_2d =
261 input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32,
262 input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType {
263 int64_t input_zp_shifted;
264 if (input_unsigned)
265 {
266 int64_t in_val64;
267 int64_t in_zp64;
268 switch (GetNumBits<InDtype>::value)
269 {
270 case 8:
271 in_val64 = zero_extend(static_cast<int8_t>(in_val));
272 in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
273 break;
274 case 16:
275 in_val64 = zero_extend(static_cast<int16_t>(in_val));
276 in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
277 break;
278 default:
279 in_val64 = static_cast<int64_t>(in_val);
280 in_zp64 = static_cast<int64_t>(input_zp);
281 break;
282 }
283 input_zp_shifted = in_val64 - in_zp64;
284 }
285 else
286 {
287 input_zp_shifted = in_val - input_zp;
288 }
289 int32_t scaled;
290 if (scale32)
291 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
292 tensor_shift, double_round);
293 else
294 scaled =
295 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000296
Eric Kunzea8098c02023-09-07 00:31:54 +0000297 int64_t output_zp_extended;
298 if (output_unsigned)
299 {
300 switch (GetNumBits<OutDtype>::value)
301 {
302 case 8:
303 output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
304 break;
305 case 16:
306 output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
307 break;
308 default:
309 output_zp_extended = static_cast<int64_t>(output_zp);
310 break;
311 }
312 }
313 else
314 {
315 output_zp_extended = static_cast<int64_t>(output_zp);
316 }
317 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
318 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
319 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
320 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
321 {
322 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
323 std::to_string(output_zp) + "] not in i32 range";
324 throw desc;
325 }
326
327 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
328 out_val = std::max<OutEigenType>(out_val, QMin);
329 out_val = std::min<OutEigenType>(out_val, QMax);
330 return out_val;
331 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700332 }
333 catch (std::string desc)
334 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000335 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700336 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700337 }
338
339 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
340 Eigen::array<Eigen::Index, Rank> output_shape;
341 for (int i = 0; i < Rank; i++)
342 {
343 output_shape[i] = this->out->getShape()[i];
344 }
345 this->out->getTensor() = output_2d.reshape(output_shape);
346
347 return GraphNode::eval();
348}
349
Tai Lya4d748b2023-03-28 22:06:56 +0000350template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000351OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700352 : GraphNode(sgt_, Op_CAST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700353{
354 setRequiredOperands(1, 1);
355 setRequiredRank(0, 6);
356}
357
Tai Lya4d748b2023-03-28 22:06:56 +0000358template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700359OpCast<Rank, InDtype, OutDtype>::~OpCast()
360{}
361
Tai Lya4d748b2023-03-28 22:06:56 +0000362template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700363int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
364{
Jerry Gea793f462023-04-11 00:05:02 +0000365 // Check Tosa Level
366 auto tosa_level = g_func_config.tosa_level;
367 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
368
Eric Kunzee5e26762020-10-13 16:11:07 -0700369 if (validateRequiredOperands())
370 return 1;
371
372 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
373 {
374 return 1;
375 }
376
377 // output and input must be the same rank and size
378 if (inputs[0]->matchRankSize(*outputs[0]))
379 {
380 printNodeValidationError("OpCast: input and output rank/size must match");
381 return 1;
382 }
383
384 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
385 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
386
387 ASSERT_MEM(in && out);
388
389 return 0;
390}
391
Tai Lya4d748b2023-03-28 22:06:56 +0000392template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700393int OpCast<Rank, InDtype, OutDtype>::eval()
394{
395 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
396
397 return GraphNode::eval();
398}
399
Tai Lya4d748b2023-03-28 22:06:56 +0000400template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700401CastHelper<InDtype, OutDtype>::CastHelper()
402{
403 fcn = [](InEigenType in) -> OutEigenType {
404 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700405 return out;
406 };
407}
408
Tai Lya4d748b2023-03-28 22:06:56 +0000409template <TOSA_REF_TYPE InDtype>
410CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700411{
412 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
413}
414
Tai Lya4d748b2023-03-28 22:06:56 +0000415template <TOSA_REF_TYPE OutDtype>
416CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700417{
418 fcn = [](bool in) -> OutEigenType {
419 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
420 return out;
421 };
422}
423
Tai Lya4d748b2023-03-28 22:06:56 +0000424template <TOSA_REF_TYPE InDtype>
425CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100426{
James Ward736fd1a2023-01-23 17:13:37 +0000427 // Integer data converted to fp16 (stored as fp32)
James Ward8b390432022-08-12 20:48:56 +0100428 fcn = [](InEigenType in) -> float {
James Ward736fd1a2023-01-23 17:13:37 +0000429 half_float::half h = half_float::half(in);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000430 float out = half_float::half_cast<float, half_float::half>(h);
James Ward736fd1a2023-01-23 17:13:37 +0000431 return out;
432 };
433}
434
Tai Lya4d748b2023-03-28 22:06:56 +0000435CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000436{
437 // fp32 data converted to fp16 (stored as fp32)
438 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000439 float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
James Ward736fd1a2023-01-23 17:13:37 +0000440 return out;
441 };
442}
443
Tai Lya4d748b2023-03-28 22:06:56 +0000444template <TOSA_REF_TYPE InDtype>
445CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000446{
447 // Integer data converted to bf16 (stored as fp32)
448 fcn = [](InEigenType in) -> float {
449 float out = (float)in; // default cast to float is round_to_nearest_float()
450 return out;
451 };
452}
453
Tai Lya4d748b2023-03-28 22:06:56 +0000454CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000455{
456 // fp32 data converted to bf16 (stored as fp32)
457 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000458 return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
James Ward8b390432022-08-12 20:48:56 +0100459 };
460}
461
Tai Lya4d748b2023-03-28 22:06:56 +0000462template <TOSA_REF_TYPE OutDtype>
463CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100464{
James Ward736fd1a2023-01-23 17:13:37 +0000465 // fp16 data (stored as fp32) converted to integer
James Ward8b390432022-08-12 20:48:56 +0100466 fcn = [](float in) -> OutEigenType {
James Ward736fd1a2023-01-23 17:13:37 +0000467 // Cast from float representation back to half_float before rounding
468 half_float::half h = half_float::half(in);
Jerry Ged6a04612023-12-15 22:45:39 +0000469 if (h >= half_float::half(float(OutMax)))
470 return OutMax;
471
472 if (h <= half_float::half(float(OutMin)))
473 return OutMin;
474
475 h = std::rint(h);
476 OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h);
477
James Ward8b390432022-08-12 20:48:56 +0100478 return out;
479 };
480}
481
Tai Lya4d748b2023-03-28 22:06:56 +0000482CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000483{
484 // No-op since fp16 values treated internally as their fp32 representation
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000485 fcn = [](float in) -> OutEigenType { return in; };
James Ward736fd1a2023-01-23 17:13:37 +0000486}
487
Tai Lya4d748b2023-03-28 22:06:56 +0000488template <TOSA_REF_TYPE OutDtype>
489CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000490{
491 // bf16 data (stored as fp32) converted to integer
492 fcn = [](float in) -> OutEigenType {
Jerry Ged6a04612023-12-15 22:45:39 +0000493 if (in >= float(OutMax))
494 return OutMax;
495
496 if (in <= float(OutMin))
497 return OutMin;
498
499 OutEigenType out = std::rint(in);
James Ward736fd1a2023-01-23 17:13:37 +0000500 return out;
501 };
502}
503
Tai Lya4d748b2023-03-28 22:06:56 +0000504CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000505{
506 // No-op since bf16 values treated as truncated fp32 internally
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000507 fcn = [](InEigenType in) -> OutEigenType { return in; };
James Ward736fd1a2023-01-23 17:13:37 +0000508}
509
Tai Lya4d748b2023-03-28 22:06:56 +0000510template <TOSA_REF_TYPE InDtype>
511CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700512{
James Ward736fd1a2023-01-23 17:13:37 +0000513 // Integer data converted to fp32
Eric Kunzee5e26762020-10-13 16:11:07 -0700514 fcn = [](InEigenType in) -> float {
515 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
516 return out;
517 };
518}
519
Tai Lya4d748b2023-03-28 22:06:56 +0000520template <TOSA_REF_TYPE OutDtype>
521CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700522{
James Ward736fd1a2023-01-23 17:13:37 +0000523 // fp32 data converted to integer
Eric Kunzee5e26762020-10-13 16:11:07 -0700524 fcn = [](float in) -> OutEigenType {
Jerry Ge44827be2023-12-15 00:05:37 +0000525 if (in >= float(OutMax))
526 return OutMax;
527
528 if (in <= float(OutMin))
529 return OutMin;
530
Eric Kunze57bc0792023-01-25 10:05:51 -0800531 OutEigenType out = std::rint(in);
Eric Kunzee5e26762020-10-13 16:11:07 -0700532 return out;
533 };
534}
535
Tai Lya4d748b2023-03-28 22:06:56 +0000536template <TOSA_REF_TYPE OutDtype>
Won Jeon2c34b462024-02-06 18:37:00 +0000537CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>::CastHelper()
538{
539 // fp8e4m3 data (stored as fp32) converted to integer
540 fcn = [](float in) -> OutEigenType {
541 if (in >= float(OutMax))
542 return OutMax;
543 if (in <= float(OutMin))
544 return OutMin;
545
546 OutEigenType out = std::rint(in);
547 return out;
548 };
549}
550
551CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>::CastHelper()
552{
553 // fp8e4m3 data (stored as fp32) converted to fp16 (stored as fp32)
554 fcn = [](float in) -> float {
555 half_float::half h = half_float::half(in);
556 float out = half_float::half_cast<half_float::half, float>(h);
557 return out;
558 };
559}
560
561CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>::CastHelper()
562{
563 // fp8e4m3 data (stored as fp32) converted to bf16 (stored as fp32)
564 fcn = [](float in) -> float { return (float)in; };
565}
566
567CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper()
568{
569 // fp8e4m3 data (stored as fp32) converted to fp32
570 fcn = [](InEigenType in) -> OutEigenType { return in; };
571}
572
573template <TOSA_REF_TYPE OutDtype>
574CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper()
575{
576 // fp8e5m2 data (stored as fp32) converted to integer
577 fcn = [](float in) -> OutEigenType {
578 if (in >= float(OutMax))
579 return OutMax;
580 if (in <= float(OutMin))
581 return OutMin;
582
583 OutEigenType out = std::rint(in);
584 return out;
585 };
586}
587
588CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>::CastHelper()
589{
590 // fp8e5m2 data (stored as fp32) converted to fp16 (stored as fp32)
591 fcn = [](float in) -> float {
592 half_float::half h = half_float::half(in);
593 float out = half_float::half_cast<half_float::half, float>(h);
594 return out;
595 };
596}
597
598CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>::CastHelper()
599{
600 // fp8e5m2 data (stored as fp32) converted to bf16 (stored as fp32)
601 fcn = [](float in) -> float { return (float)in; };
602}
603
604CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper()
605{
606 // fp8e5m2 data (stored as fp32) converted to fp32
607 fcn = [](InEigenType in) -> OutEigenType { return in; };
608}
609
610template <TOSA_REF_TYPE InDtype>
611CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
612{
613 // Integer data converted to fp8e4m3 (stored as fp32)
614 fcn = [](InEigenType in) -> float {
615 auto f = static_cast<fp32>(static_cast<fp8e4m3>(float(in)));
616 float out = static_cast<float>(f);
617 return out;
618 };
619}
620
621CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
622{
623 // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
624 fcn = [](float in) -> float {
625 auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
626 float out = static_cast<float>(f);
627 return out;
628 };
629}
630
631CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
632{
633 // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
634 fcn = [](float in) -> float {
635 auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
636 float out = static_cast<float>(f);
637 return out;
638 };
639}
640
641CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
642{
643 // fp32 data converted to fp8e4m3 (stored as fp32)
644 fcn = [](float in) -> float {
645 auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
646 float out = static_cast<float>(f);
647 return out;
648 };
649}
650
651template <TOSA_REF_TYPE InDtype>
652CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
653{
654 // Integer data converted to fp8e5m2 (stored as fp32)
655 fcn = [](InEigenType in) -> float {
656 auto f = static_cast<fp32>(static_cast<fp8e5m2>(float(in)));
657 float out = static_cast<float>(f);
658 return out;
659 };
660}
661
662CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
663{
664 // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
665 fcn = [](float in) -> float {
666 auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
667 float out = static_cast<float>(f);
668 return out;
669 };
670}
671
672CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
673{
674 // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
675 fcn = [](float in) -> float {
676 auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
677 float out = static_cast<float>(f);
678 return out;
679 };
680}
681
682CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
683{
684 // fp32 data converted to fp8e5m2 (stored as fp32)
685 fcn = [](float in) -> float {
686 auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
687 float out = static_cast<float>(f);
688 return out;
689 };
690}
691
692template <TOSA_REF_TYPE OutDtype>
Tai Lya4d748b2023-03-28 22:06:56 +0000693CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
694{
695 switch (OutDtype)
696 {
697 case TOSA_REF_TYPE_INT8:
698 case TOSA_REF_TYPE_INT16:
699 case TOSA_REF_TYPE_INT32:
700 // fp64 data converted to integer
701 fcn = [](InEigenType in) -> OutEigenType {
Jerry Ged6a04612023-12-15 22:45:39 +0000702 if (in >= double(OutMax))
703 return OutMax;
704
705 if (in <= double(OutMin))
706 return OutMin;
707
Tai Lya4d748b2023-03-28 22:06:56 +0000708 OutEigenType out = std::rint(in);
Tai Lya4d748b2023-03-28 22:06:56 +0000709 return out;
710 };
711 break;
712 case TOSA_REF_TYPE_FP64:
713 // no op
714 fcn = [](InEigenType in) -> OutEigenType { return in; };
715 break;
716 default:
717 ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
718 }
719}
720
Eric Kunzee5e26762020-10-13 16:11:07 -0700721// template explicit instantiation
722DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
723DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
724DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
725DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
726DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
727DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
James Ward8b390432022-08-12 20:48:56 +0100728DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000729DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100730DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700731DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
732DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
733DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
James Ward8b390432022-08-12 20:48:56 +0100734DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000735DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100736DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700737DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
738DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
739DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
James Ward8b390432022-08-12 20:48:56 +0100740DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000741DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100742DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
James Ward8b390432022-08-12 20:48:56 +0100743DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
744DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
745DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000746DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
James Ward24dbc422022-10-19 12:20:31 +0100747DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
748DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
749DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000750DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100751DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
752DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
753DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000754DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
755DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
Tai Lya4d748b2023-03-28 22:06:56 +0000756DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
757DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
758DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
759DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
760DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
761DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
762DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
Won Jeon2c34b462024-02-06 18:37:00 +0000763DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
764DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
765DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
766DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
767DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
768DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
769DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
770DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
771DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
772DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
773DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
774DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
Eric Kunzee5e26762020-10-13 16:11:07 -0700775
Kevin Cheng3a478572021-01-22 17:21:02 -0800776DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
777DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
778DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
779DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700780DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
781DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800782DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700783DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
784DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800785DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700786DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
787DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800788DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100789DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
790DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
Kevin Cheng3a478572021-01-22 17:21:02 -0800791DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100792DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
793DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);