blob: 9464fd95dd23da1cc2c3224d2fdf7b0d626bb40e [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
15
16#include "type_conversion.h"
James Ward736fd1a2023-01-23 17:13:37 +000017#include "arith_util.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000018#include "half.hpp"
19#include "quant_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070020#include "template_types.h"
21#include <cmath>
22
23using namespace TosaReference;
24using namespace Eigen;
25using namespace tosa;
26
Tai Lya4d748b2023-03-28 22:06:56 +000027template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +000028OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -070029 : GraphNode(sgt_, Op_RESCALE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070030{
31 setRequiredOperands(1, 1);
Eric Kunzee5e26762020-10-13 16:11:07 -070032 INIT_ATTRIBUTE(Rescale);
33}
34
Tai Lya4d748b2023-03-28 22:06:56 +000035template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070036OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
37{
38 if (attribute)
39 delete attribute;
40}
41
Tai Lya4d748b2023-03-28 22:06:56 +000042template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070043int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
44{
Jerry Gea793f462023-04-11 00:05:02 +000045 // Check Tosa Level
46 auto tosa_level = g_func_config.tosa_level;
47 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
48
Eric Kunzee5e26762020-10-13 16:11:07 -070049 if (validateRequiredOperands())
50 return 1;
51
Eric Kunzee5e26762020-10-13 16:11:07 -070052 // output and input must be the same rank and size
53 if (inputs[0]->matchRankSize(*outputs[0]))
54 {
55 printNodeValidationError("OpRescale: input and output rank/size must match");
56 return 1;
57 }
58
59 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
60 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
61
62 ASSERT_MEM(in && out);
63
Tai Lya4d748b2023-03-28 22:06:56 +000064 if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
65 (attribute->input_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070066 {
Tai Lya4d748b2023-03-28 22:06:56 +000067 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Kevin Chengcc61be32021-10-14 17:09:57 -070068 return 1;
69 }
70
Tai Lya4d748b2023-03-28 22:06:56 +000071 if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
72 (attribute->output_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070073 {
Tai Lya4d748b2023-03-28 22:06:56 +000074 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010075 return 1;
76 }
77
Tai Lya4d748b2023-03-28 22:06:56 +000078 if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010079 {
Tai Lya4d748b2023-03-28 22:06:56 +000080 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010081 return 1;
82 }
83
Tai Lya4d748b2023-03-28 22:06:56 +000084 if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010085 {
Tai Lya4d748b2023-03-28 22:06:56 +000086 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Kevin Chengcc61be32021-10-14 17:09:57 -070087 return 1;
88 }
89
Tai Lya4d748b2023-03-28 22:06:56 +000090 if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
Kevin Chengcc61be32021-10-14 17:09:57 -070091 {
92 printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
93 return 1;
94 }
95
96 if ((!attribute->scale32()) && attribute->double_round())
97 {
98 printNodeValidationError("OpRescale: Scale set to false but double round set to true");
99 return 1;
100 }
101
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 return 0;
103}
104
Tai Lya4d748b2023-03-28 22:06:56 +0000105template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700106int OpRescale<Rank, InDtype, OutDtype>::eval()
107{
108 int32_t input_zp = attribute->input_zp();
109 int32_t output_zp = attribute->output_zp();
110 std::vector<int32_t> multiplier = attribute->multiplier();
111 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700112 bool scale32 = attribute->scale32();
113 bool double_round = attribute->double_round();
114 bool per_channel = attribute->per_channel();
Eric Kunzee5e26762020-10-13 16:11:07 -0700115
Eric Kunzee5e26762020-10-13 16:11:07 -0700116 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
117 Eigen::array<Eigen::Index, 2> shape_2d;
118 shape_2d[0] = 1;
119 if (Rank > 0)
120 {
121 for (int i = 0; i < Rank - 1; i++)
122 {
123 shape_2d[0] *= this->in->getShape()[i];
124 }
125 shape_2d[1] = this->in->getShape()[Rank - 1];
126 }
127 else
128 {
129 shape_2d[1] = 1;
130 }
131 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
132
133 ETensor2<OutEigenType> output_2d(shape_2d);
134
Eric Kunzee5e26762020-10-13 16:11:07 -0700135 if (per_channel)
136 {
137 ETensor2<InEigenType> curr_channel_slice_prescaled;
138 ETensor2<OutEigenType> curr_channel_slice_postscaled;
139 int32_t channel_multiplier, channel_shift;
140 Eigen::array<Eigen::Index, 2> begin, size;
141 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700142 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700144 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700146 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
147 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
148 channel_multiplier = multiplier[i];
149 channel_shift = shift[i];
150 curr_channel_slice_postscaled =
151 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
152 double_round, scale32](InEigenType in_val) -> OutEigenType {
153 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
154 int32_t scaled;
155 if (scale32)
156 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
157 channel_shift, double_round);
158 else
159 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
160 channel_shift);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000161 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp;
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000162 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
163 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
164 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
165 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000166 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
167 std::to_string(output_zp) + "] not in i32 range";
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000168 throw desc;
169 }
170 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700171 out_val = std::max<OutEigenType>(out_val, QMin);
172 out_val = std::min<OutEigenType>(out_val, QMax);
173 return out_val;
174 });
175
176 for (int32_t j = 0; j < shape_2d[0]; j++)
177 {
178 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
179 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700180 }
181 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700182 catch (std::string desc)
183 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000184 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700185 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 }
187 else
188 {
189 int32_t tensor_multiplier = multiplier[0];
190 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700191 try
192 {
193 output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
194 scale32](InEigenType in_val) -> OutEigenType {
195 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
196 int32_t scaled;
197 if (scale32)
198 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
199 double_round);
200 else
201 scaled =
202 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000203 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp;
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000204 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
205 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
206 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
207 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000208 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
209 std::to_string(output_zp) + "] not in i32 range";
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000210 throw desc;
211 }
212
213 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700214 out_val = std::max<OutEigenType>(out_val, QMin);
215 out_val = std::min<OutEigenType>(out_val, QMax);
216 return out_val;
217 });
218 }
219 catch (std::string desc)
220 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000221 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700222 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 }
224
225 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
226 Eigen::array<Eigen::Index, Rank> output_shape;
227 for (int i = 0; i < Rank; i++)
228 {
229 output_shape[i] = this->out->getShape()[i];
230 }
231 this->out->getTensor() = output_2d.reshape(output_shape);
232
233 return GraphNode::eval();
234}
235
Tai Lya4d748b2023-03-28 22:06:56 +0000236template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000237OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
Kevin Chengacb550f2021-06-29 15:32:19 -0700238 : GraphNode(sgt_, Op_CAST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700239{
240 setRequiredOperands(1, 1);
241 setRequiredRank(0, 6);
242}
243
Tai Lya4d748b2023-03-28 22:06:56 +0000244template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700245OpCast<Rank, InDtype, OutDtype>::~OpCast()
246{}
247
Tai Lya4d748b2023-03-28 22:06:56 +0000248template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700249int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
250{
Jerry Gea793f462023-04-11 00:05:02 +0000251 // Check Tosa Level
252 auto tosa_level = g_func_config.tosa_level;
253 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
254
Eric Kunzee5e26762020-10-13 16:11:07 -0700255 if (validateRequiredOperands())
256 return 1;
257
258 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
259 {
260 return 1;
261 }
262
263 // output and input must be the same rank and size
264 if (inputs[0]->matchRankSize(*outputs[0]))
265 {
266 printNodeValidationError("OpCast: input and output rank/size must match");
267 return 1;
268 }
269
270 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
271 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
272
273 ASSERT_MEM(in && out);
274
275 return 0;
276}
277
Tai Lya4d748b2023-03-28 22:06:56 +0000278template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700279int OpCast<Rank, InDtype, OutDtype>::eval()
280{
281 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
282
283 return GraphNode::eval();
284}
285
Tai Lya4d748b2023-03-28 22:06:56 +0000286template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700287CastHelper<InDtype, OutDtype>::CastHelper()
288{
289 fcn = [](InEigenType in) -> OutEigenType {
290 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700291 return out;
292 };
293}
294
Tai Lya4d748b2023-03-28 22:06:56 +0000295template <TOSA_REF_TYPE InDtype>
296CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700297{
298 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
299}
300
Tai Lya4d748b2023-03-28 22:06:56 +0000301template <TOSA_REF_TYPE OutDtype>
302CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700303{
304 fcn = [](bool in) -> OutEigenType {
305 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
306 return out;
307 };
308}
309
Tai Lya4d748b2023-03-28 22:06:56 +0000310template <TOSA_REF_TYPE InDtype>
311CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100312{
James Ward736fd1a2023-01-23 17:13:37 +0000313 // Integer data converted to fp16 (stored as fp32)
James Ward8b390432022-08-12 20:48:56 +0100314 fcn = [](InEigenType in) -> float {
James Ward736fd1a2023-01-23 17:13:37 +0000315 half_float::half h = half_float::half(in);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000316 float out = half_float::half_cast<float, half_float::half>(h);
James Ward736fd1a2023-01-23 17:13:37 +0000317 return out;
318 };
319}
320
Tai Lya4d748b2023-03-28 22:06:56 +0000321CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000322{
323 // fp32 data converted to fp16 (stored as fp32)
324 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000325 float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
James Ward736fd1a2023-01-23 17:13:37 +0000326 return out;
327 };
328}
329
Tai Lya4d748b2023-03-28 22:06:56 +0000330template <TOSA_REF_TYPE InDtype>
331CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000332{
333 // Integer data converted to bf16 (stored as fp32)
334 fcn = [](InEigenType in) -> float {
335 float out = (float)in; // default cast to float is round_to_nearest_float()
336 return out;
337 };
338}
339
Tai Lya4d748b2023-03-28 22:06:56 +0000340CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000341{
342 // fp32 data converted to bf16 (stored as fp32)
343 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000344 return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
James Ward8b390432022-08-12 20:48:56 +0100345 };
346}
347
Tai Lya4d748b2023-03-28 22:06:56 +0000348template <TOSA_REF_TYPE OutDtype>
349CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100350{
James Ward736fd1a2023-01-23 17:13:37 +0000351 // fp16 data (stored as fp32) converted to integer
James Ward8b390432022-08-12 20:48:56 +0100352 fcn = [](float in) -> OutEigenType {
James Ward736fd1a2023-01-23 17:13:37 +0000353 // Cast from float representation back to half_float before rounding
354 half_float::half h = half_float::half(in);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000355 h = std::rint(h);
356 OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h);
357 out = std::max<OutEigenType>(out, OutMin);
358 out = std::min<OutEigenType>(out, OutMax);
James Ward8b390432022-08-12 20:48:56 +0100359 return out;
360 };
361}
362
Tai Lya4d748b2023-03-28 22:06:56 +0000363CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000364{
365 // No-op since fp16 values treated internally as their fp32 representation
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000366 fcn = [](float in) -> OutEigenType { return in; };
James Ward736fd1a2023-01-23 17:13:37 +0000367}
368
Tai Lya4d748b2023-03-28 22:06:56 +0000369template <TOSA_REF_TYPE OutDtype>
370CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000371{
372 // bf16 data (stored as fp32) converted to integer
373 fcn = [](float in) -> OutEigenType {
374 OutEigenType out = std::round(in);
375 out = std::max<OutEigenType>(out, OutMin);
376 out = std::min<OutEigenType>(out, OutMax);
377 return out;
378 };
379}
380
Tai Lya4d748b2023-03-28 22:06:56 +0000381CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000382{
383 // No-op since bf16 values treated as truncated fp32 internally
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000384 fcn = [](InEigenType in) -> OutEigenType { return in; };
James Ward736fd1a2023-01-23 17:13:37 +0000385}
386
Tai Lya4d748b2023-03-28 22:06:56 +0000387template <TOSA_REF_TYPE InDtype>
388CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700389{
James Ward736fd1a2023-01-23 17:13:37 +0000390 // Integer data converted to fp32
Eric Kunzee5e26762020-10-13 16:11:07 -0700391 fcn = [](InEigenType in) -> float {
392 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
393 return out;
394 };
395}
396
Tai Lya4d748b2023-03-28 22:06:56 +0000397template <TOSA_REF_TYPE OutDtype>
398CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700399{
James Ward736fd1a2023-01-23 17:13:37 +0000400 // fp32 data converted to integer
Eric Kunzee5e26762020-10-13 16:11:07 -0700401 fcn = [](float in) -> OutEigenType {
Eric Kunze57bc0792023-01-25 10:05:51 -0800402 OutEigenType out = std::rint(in);
Eric Kunzee5e26762020-10-13 16:11:07 -0700403 out = std::max<OutEigenType>(out, OutMin);
404 out = std::min<OutEigenType>(out, OutMax);
405 return out;
406 };
407}
408
Tai Lya4d748b2023-03-28 22:06:56 +0000409template <TOSA_REF_TYPE OutDtype>
410CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
411{
412 switch (OutDtype)
413 {
414 case TOSA_REF_TYPE_INT8:
415 case TOSA_REF_TYPE_INT16:
416 case TOSA_REF_TYPE_INT32:
417 // fp64 data converted to integer
418 fcn = [](InEigenType in) -> OutEigenType {
419 OutEigenType out = std::rint(in);
420 out = std::max<OutEigenType>(out, OutMin);
421 out = std::min<OutEigenType>(out, OutMax);
422 return out;
423 };
424 break;
425 case TOSA_REF_TYPE_FP64:
426 // no op
427 fcn = [](InEigenType in) -> OutEigenType { return in; };
428 break;
429 default:
430 ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
431 }
432}
433
Eric Kunzee5e26762020-10-13 16:11:07 -0700434// template explicit instantiation
435DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
436DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
437DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
438DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
439DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
440DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
James Ward8b390432022-08-12 20:48:56 +0100441DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000442DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100443DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700444DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
445DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
446DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
James Ward8b390432022-08-12 20:48:56 +0100447DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000448DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100449DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700450DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
451DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
452DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
James Ward8b390432022-08-12 20:48:56 +0100453DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000454DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100455DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
James Ward8b390432022-08-12 20:48:56 +0100456DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
457DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
458DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000459DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
James Ward24dbc422022-10-19 12:20:31 +0100460DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
461DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
462DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000463DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100464DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
465DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
466DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000467DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
468DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
Tai Lya4d748b2023-03-28 22:06:56 +0000469DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
470DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
471DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
472DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
473DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
474DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
475DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700476
Kevin Cheng3a478572021-01-22 17:21:02 -0800477DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
478DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
479DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
480DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700481DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
482DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800483DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700484DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
485DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800486DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700487DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
488DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800489DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100490DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
491DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
Kevin Cheng3a478572021-01-22 17:21:02 -0800492DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100493DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
494DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);