blob: 68ffb1fd4ec325dc187579b60267725c1ab87e26 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Tai Lya4d748b2023-03-28 22:06:56 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "type_conversion.h"
17#include "quant_util.h"
James Ward736fd1a2023-01-23 17:13:37 +000018#include "arith_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070019#include "template_types.h"
20#include <cmath>
James Ward8b390432022-08-12 20:48:56 +010021#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070022
23using namespace TosaReference;
24using namespace Eigen;
25using namespace tosa;
26
Tai Lya4d748b2023-03-28 22:06:56 +000027template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -070028OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
29 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -070030 uint64_t id_)
31 : GraphNode(sgt_, Op_RESCALE, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070032{
33 setRequiredOperands(1, 1);
TatWai Chongfd629052022-07-25 04:01:58 +000034 setRequiredRank(0, 6);
Eric Kunzee5e26762020-10-13 16:11:07 -070035 INIT_ATTRIBUTE(Rescale);
36}
37
Tai Lya4d748b2023-03-28 22:06:56 +000038template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070039OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
40{
41 if (attribute)
42 delete attribute;
43}
44
Tai Lya4d748b2023-03-28 22:06:56 +000045template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -070046int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
47{
Jerry Gea793f462023-04-11 00:05:02 +000048 // Check Tosa Level
49 auto tosa_level = g_func_config.tosa_level;
50 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
51
Eric Kunzee5e26762020-10-13 16:11:07 -070052 if (validateRequiredOperands())
53 return 1;
54
55 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
56 {
57 return 1;
58 }
59
60 // output and input must be the same rank and size
61 if (inputs[0]->matchRankSize(*outputs[0]))
62 {
63 printNodeValidationError("OpRescale: input and output rank/size must match");
64 return 1;
65 }
66
67 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
68 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
69
70 ASSERT_MEM(in && out);
71
Tai Lya4d748b2023-03-28 22:06:56 +000072 if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
73 (attribute->input_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070074 {
Tai Lya4d748b2023-03-28 22:06:56 +000075 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Kevin Chengcc61be32021-10-14 17:09:57 -070076 return 1;
77 }
78
Tai Lya4d748b2023-03-28 22:06:56 +000079 if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
80 (attribute->output_zp() != 0))
Kevin Chengcc61be32021-10-14 17:09:57 -070081 {
Tai Lya4d748b2023-03-28 22:06:56 +000082 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010083 return 1;
84 }
85
Tai Lya4d748b2023-03-28 22:06:56 +000086 if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010087 {
Tai Lya4d748b2023-03-28 22:06:56 +000088 printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010089 return 1;
90 }
91
Tai Lya4d748b2023-03-28 22:06:56 +000092 if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010093 {
Tai Lya4d748b2023-03-28 22:06:56 +000094 printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
Kevin Chengcc61be32021-10-14 17:09:57 -070095 return 1;
96 }
97
Tai Lya4d748b2023-03-28 22:06:56 +000098 if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
Kevin Chengcc61be32021-10-14 17:09:57 -070099 {
100 printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
101 return 1;
102 }
103
104 if ((!attribute->scale32()) && attribute->double_round())
105 {
106 printNodeValidationError("OpRescale: Scale set to false but double round set to true");
107 return 1;
108 }
109
Eric Kunzee5e26762020-10-13 16:11:07 -0700110 return 0;
111}
112
Tai Lya4d748b2023-03-28 22:06:56 +0000113template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700114int OpRescale<Rank, InDtype, OutDtype>::eval()
115{
116 int32_t input_zp = attribute->input_zp();
117 int32_t output_zp = attribute->output_zp();
118 std::vector<int32_t> multiplier = attribute->multiplier();
119 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700120 bool scale32 = attribute->scale32();
121 bool double_round = attribute->double_round();
122 bool per_channel = attribute->per_channel();
Eric Kunzee5e26762020-10-13 16:11:07 -0700123
Eric Kunzee5e26762020-10-13 16:11:07 -0700124 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
125 Eigen::array<Eigen::Index, 2> shape_2d;
126 shape_2d[0] = 1;
127 if (Rank > 0)
128 {
129 for (int i = 0; i < Rank - 1; i++)
130 {
131 shape_2d[0] *= this->in->getShape()[i];
132 }
133 shape_2d[1] = this->in->getShape()[Rank - 1];
134 }
135 else
136 {
137 shape_2d[1] = 1;
138 }
139 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
140
141 ETensor2<OutEigenType> output_2d(shape_2d);
142
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 if (per_channel)
144 {
145 ETensor2<InEigenType> curr_channel_slice_prescaled;
146 ETensor2<OutEigenType> curr_channel_slice_postscaled;
147 int32_t channel_multiplier, channel_shift;
148 Eigen::array<Eigen::Index, 2> begin, size;
149 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700150 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700152 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700153 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700154 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
155 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
156 channel_multiplier = multiplier[i];
157 channel_shift = shift[i];
158 curr_channel_slice_postscaled =
159 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
160 double_round, scale32](InEigenType in_val) -> OutEigenType {
161 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
162 int32_t scaled;
163 if (scale32)
164 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
165 channel_shift, double_round);
166 else
167 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
168 channel_shift);
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000169 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp;
170 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
171 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
172 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
173 {
174 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + std::to_string(output_zp) + "] not in i32 range";
175 throw desc;
176 }
177 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700178 out_val = std::max<OutEigenType>(out_val, QMin);
179 out_val = std::min<OutEigenType>(out_val, QMax);
180 return out_val;
181 });
182
183 for (int32_t j = 0; j < shape_2d[0]; j++)
184 {
185 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
186 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 }
188 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700189 catch (std::string desc)
190 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000191 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700192 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 }
194 else
195 {
196 int32_t tensor_multiplier = multiplier[0];
197 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700198 try
199 {
200 output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
201 scale32](InEigenType in_val) -> OutEigenType {
202 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
203 int32_t scaled;
204 if (scale32)
205 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
206 double_round);
207 else
208 scaled =
209 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000210 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp;
211 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
212 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
213 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
214 {
215 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + std::to_string(output_zp) + "] not in i32 range";
216 throw desc;
217 }
218
219 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700220 out_val = std::max<OutEigenType>(out_val, QMin);
221 out_val = std::min<OutEigenType>(out_val, QMax);
222 return out_val;
223 });
224 }
225 catch (std::string desc)
226 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000227 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700228 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 }
230
231 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
232 Eigen::array<Eigen::Index, Rank> output_shape;
233 for (int i = 0; i < Rank; i++)
234 {
235 output_shape[i] = this->out->getShape()[i];
236 }
237 this->out->getTensor() = output_2d.reshape(output_shape);
238
239 return GraphNode::eval();
240}
241
Tai Lya4d748b2023-03-28 22:06:56 +0000242template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Kevin Chengacb550f2021-06-29 15:32:19 -0700243OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
244 TosaAttributeBase* attribute_,
Kevin Chengacb550f2021-06-29 15:32:19 -0700245 uint64_t id_)
246 : GraphNode(sgt_, Op_CAST, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700247{
248 setRequiredOperands(1, 1);
249 setRequiredRank(0, 6);
250}
251
Tai Lya4d748b2023-03-28 22:06:56 +0000252template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700253OpCast<Rank, InDtype, OutDtype>::~OpCast()
254{}
255
Tai Lya4d748b2023-03-28 22:06:56 +0000256template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700257int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
258{
Jerry Gea793f462023-04-11 00:05:02 +0000259 // Check Tosa Level
260 auto tosa_level = g_func_config.tosa_level;
261 LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
262
Eric Kunzee5e26762020-10-13 16:11:07 -0700263 if (validateRequiredOperands())
264 return 1;
265
266 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
267 {
268 return 1;
269 }
270
271 // output and input must be the same rank and size
272 if (inputs[0]->matchRankSize(*outputs[0]))
273 {
274 printNodeValidationError("OpCast: input and output rank/size must match");
275 return 1;
276 }
277
278 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
279 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
280
281 ASSERT_MEM(in && out);
282
283 return 0;
284}
285
Tai Lya4d748b2023-03-28 22:06:56 +0000286template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700287int OpCast<Rank, InDtype, OutDtype>::eval()
288{
289 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
290
291 return GraphNode::eval();
292}
293
Tai Lya4d748b2023-03-28 22:06:56 +0000294template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
Eric Kunzee5e26762020-10-13 16:11:07 -0700295CastHelper<InDtype, OutDtype>::CastHelper()
296{
297 fcn = [](InEigenType in) -> OutEigenType {
298 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700299 return out;
300 };
301}
302
Tai Lya4d748b2023-03-28 22:06:56 +0000303template <TOSA_REF_TYPE InDtype>
304CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700305{
306 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
307}
308
Tai Lya4d748b2023-03-28 22:06:56 +0000309template <TOSA_REF_TYPE OutDtype>
310CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700311{
312 fcn = [](bool in) -> OutEigenType {
313 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
314 return out;
315 };
316}
317
Tai Lya4d748b2023-03-28 22:06:56 +0000318template <TOSA_REF_TYPE InDtype>
319CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100320{
James Ward736fd1a2023-01-23 17:13:37 +0000321 // Integer data converted to fp16 (stored as fp32)
James Ward8b390432022-08-12 20:48:56 +0100322 fcn = [](InEigenType in) -> float {
James Ward736fd1a2023-01-23 17:13:37 +0000323 half_float::half h = half_float::half(in);
324 float out = half_float::half_cast<float, half_float::half>(h);
325 return out;
326 };
327}
328
Tai Lya4d748b2023-03-28 22:06:56 +0000329CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000330{
331 // fp32 data converted to fp16 (stored as fp32)
332 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000333 float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
James Ward736fd1a2023-01-23 17:13:37 +0000334 return out;
335 };
336}
337
Tai Lya4d748b2023-03-28 22:06:56 +0000338template <TOSA_REF_TYPE InDtype>
339CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000340{
341 // Integer data converted to bf16 (stored as fp32)
342 fcn = [](InEigenType in) -> float {
343 float out = (float)in; // default cast to float is round_to_nearest_float()
344 return out;
345 };
346}
347
Tai Lya4d748b2023-03-28 22:06:56 +0000348CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000349{
350 // fp32 data converted to bf16 (stored as fp32)
351 fcn = [](float in) -> float {
Tai Lya4d748b2023-03-28 22:06:56 +0000352 return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
James Ward8b390432022-08-12 20:48:56 +0100353 };
354}
355
Tai Lya4d748b2023-03-28 22:06:56 +0000356template <TOSA_REF_TYPE OutDtype>
357CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
James Ward8b390432022-08-12 20:48:56 +0100358{
James Ward736fd1a2023-01-23 17:13:37 +0000359 // fp16 data (stored as fp32) converted to integer
James Ward8b390432022-08-12 20:48:56 +0100360 fcn = [](float in) -> OutEigenType {
James Ward736fd1a2023-01-23 17:13:37 +0000361 // Cast from float representation back to half_float before rounding
362 half_float::half h = half_float::half(in);
Eric Kunze57bc0792023-01-25 10:05:51 -0800363 h = std::rint(h);
James Ward736fd1a2023-01-23 17:13:37 +0000364 OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h);
James Ward8b390432022-08-12 20:48:56 +0100365 out = std::max<OutEigenType>(out, OutMin);
366 out = std::min<OutEigenType>(out, OutMax);
367 return out;
368 };
369}
370
Tai Lya4d748b2023-03-28 22:06:56 +0000371CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000372{
373 // No-op since fp16 values treated internally as their fp32 representation
374 fcn = [](float in) -> OutEigenType {
375 return in;
376 };
377}
378
Tai Lya4d748b2023-03-28 22:06:56 +0000379template <TOSA_REF_TYPE OutDtype>
380CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000381{
382 // bf16 data (stored as fp32) converted to integer
383 fcn = [](float in) -> OutEigenType {
384 OutEigenType out = std::round(in);
385 out = std::max<OutEigenType>(out, OutMin);
386 out = std::min<OutEigenType>(out, OutMax);
387 return out;
388 };
389}
390
Tai Lya4d748b2023-03-28 22:06:56 +0000391CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
James Ward736fd1a2023-01-23 17:13:37 +0000392{
393 // No-op since bf16 values treated as truncated fp32 internally
394 fcn = [](InEigenType in) -> OutEigenType {
395 return in;
396 };
397}
398
Tai Lya4d748b2023-03-28 22:06:56 +0000399template <TOSA_REF_TYPE InDtype>
400CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700401{
James Ward736fd1a2023-01-23 17:13:37 +0000402 // Integer data converted to fp32
Eric Kunzee5e26762020-10-13 16:11:07 -0700403 fcn = [](InEigenType in) -> float {
404 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
405 return out;
406 };
407}
408
Tai Lya4d748b2023-03-28 22:06:56 +0000409template <TOSA_REF_TYPE OutDtype>
410CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
Eric Kunzee5e26762020-10-13 16:11:07 -0700411{
James Ward736fd1a2023-01-23 17:13:37 +0000412 // fp32 data converted to integer
Eric Kunzee5e26762020-10-13 16:11:07 -0700413 fcn = [](float in) -> OutEigenType {
Eric Kunze57bc0792023-01-25 10:05:51 -0800414 OutEigenType out = std::rint(in);
Eric Kunzee5e26762020-10-13 16:11:07 -0700415 out = std::max<OutEigenType>(out, OutMin);
416 out = std::min<OutEigenType>(out, OutMax);
417 return out;
418 };
419}
420
Tai Lya4d748b2023-03-28 22:06:56 +0000421template <TOSA_REF_TYPE OutDtype>
422CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
423{
424 switch (OutDtype)
425 {
426 case TOSA_REF_TYPE_INT8:
427 case TOSA_REF_TYPE_INT16:
428 case TOSA_REF_TYPE_INT32:
429 // fp64 data converted to integer
430 fcn = [](InEigenType in) -> OutEigenType {
431 OutEigenType out = std::rint(in);
432 out = std::max<OutEigenType>(out, OutMin);
433 out = std::min<OutEigenType>(out, OutMax);
434 return out;
435 };
436 break;
437 case TOSA_REF_TYPE_FP64:
438 // no op
439 fcn = [](InEigenType in) -> OutEigenType { return in; };
440 break;
441 default:
442 ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
443 }
444}
445
Eric Kunzee5e26762020-10-13 16:11:07 -0700446// template explicit instantiation
447DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
448DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
449DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
450DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
451DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
452DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
James Ward8b390432022-08-12 20:48:56 +0100453DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000454DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100455DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700456DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
457DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
458DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
James Ward8b390432022-08-12 20:48:56 +0100459DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000460DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100461DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
Eric Kunzee5e26762020-10-13 16:11:07 -0700462DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
463DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
464DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
James Ward8b390432022-08-12 20:48:56 +0100465DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
James Ward736fd1a2023-01-23 17:13:37 +0000466DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100467DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
James Ward8b390432022-08-12 20:48:56 +0100468DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
469DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
470DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000471DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
James Ward24dbc422022-10-19 12:20:31 +0100472DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
473DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
474DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000475DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100476DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
477DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
478DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
James Ward736fd1a2023-01-23 17:13:37 +0000479DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
480DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
Tai Lya4d748b2023-03-28 22:06:56 +0000481DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
482DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
483DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
484DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
485DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
486DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
487DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
Eric Kunzee5e26762020-10-13 16:11:07 -0700488
Kevin Cheng3a478572021-01-22 17:21:02 -0800489DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
490DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
491DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
492DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700493DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
494DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800495DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700496DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
497DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800498DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700499DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
500DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800501DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100502DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
503DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
Kevin Cheng3a478572021-01-22 17:21:02 -0800504DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100505DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
506DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);