blob: 9034add2d51e8308ca79ab835a191561be2859fb [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
James Ward8b390432022-08-12 20:48:56 +01002// Copyright (c) 2020-2022, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "type_conversion.h"
17#include "quant_util.h"
James Ward736fd1a2023-01-23 17:13:37 +000018#include "arith_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070019#include "template_types.h"
20#include <cmath>
James Ward8b390432022-08-12 20:48:56 +010021#include "half.hpp"
Eric Kunzee5e26762020-10-13 16:11:07 -070022
23using namespace TosaReference;
24using namespace Eigen;
25using namespace tosa;
26
template <int Rank, DType InDtype, DType OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
                                              TosaAttributeBase* attribute_,
                                              uint64_t id_)
    : GraphNode(sgt_, Op_RESCALE, id_)
{
    // RESCALE is a unary elementwise op: exactly one input, one output,
    // with any rank from 0 (scalar) up to 6.
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);
    // Deserialize the RescaleAttribute (zero points, multiplier/shift,
    // scale32/double_round/per_channel flags) into this->attribute.
    INIT_ATTRIBUTE(Rescale);
}
37
38template <int Rank, DType InDtype, DType OutDtype>
39OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
40{
41 if (attribute)
42 delete attribute;
43}
44
template <int Rank, DType InDtype, DType OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
    // Returns 0 on success, 1 on any validation failure (after printing a
    // node validation error). The order of checks determines which error
    // message is reported first, so it must not be rearranged.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same rank and size
    if (inputs[0]->matchRankSize(*outputs[0]))
    {
        printNodeValidationError("OpRescale: input and output rank/size must match");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    // Non-zero zero points are only allowed for 8-bit signed and 8/16-bit
    // unsigned data; every other input dtype must use input_zp == 0.
    if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
    {
        printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
        return 1;
    }

    // Same restriction on the output side.
    if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
    {
        printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
        return 1;
    }

    // UINT16 tensors additionally restrict the zero point to 0 or 32768.
    if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
    {
        printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
        return 1;
    }

    if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
    {
        printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
        return 1;
    }

    // 32-bit scaling cannot be applied to 48-bit input.
    if (attribute->scale32() && (InDtype == DType_INT48))
    {
        printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
        return 1;
    }

    // double_round only has meaning for the 32-bit scaling path.
    if ((!attribute->scale32()) && attribute->double_round())
    {
        printNodeValidationError("OpRescale: Scale set to false but double round set to true");
        return 1;
    }

    return 0;
}
110
111template <int Rank, DType InDtype, DType OutDtype>
112int OpRescale<Rank, InDtype, OutDtype>::eval()
113{
114 int32_t input_zp = attribute->input_zp();
115 int32_t output_zp = attribute->output_zp();
116 std::vector<int32_t> multiplier = attribute->multiplier();
117 std::vector<int32_t> shift = attribute->shift();
Kevin Cheng0f87c952021-03-18 17:41:39 -0700118 bool scale32 = attribute->scale32();
119 bool double_round = attribute->double_round();
120 bool per_channel = attribute->per_channel();
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
Eric Kunzee5e26762020-10-13 16:11:07 -0700122 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
123 Eigen::array<Eigen::Index, 2> shape_2d;
124 shape_2d[0] = 1;
125 if (Rank > 0)
126 {
127 for (int i = 0; i < Rank - 1; i++)
128 {
129 shape_2d[0] *= this->in->getShape()[i];
130 }
131 shape_2d[1] = this->in->getShape()[Rank - 1];
132 }
133 else
134 {
135 shape_2d[1] = 1;
136 }
137 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
138
139 ETensor2<OutEigenType> output_2d(shape_2d);
140
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 if (per_channel)
142 {
143 ETensor2<InEigenType> curr_channel_slice_prescaled;
144 ETensor2<OutEigenType> curr_channel_slice_postscaled;
145 int32_t channel_multiplier, channel_shift;
146 Eigen::array<Eigen::Index, 2> begin, size;
147 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
Kevin Chengacb550f2021-06-29 15:32:19 -0700148 try
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700150 for (int32_t i = 0; i < shape_2d[1]; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 {
Kevin Chengacb550f2021-06-29 15:32:19 -0700152 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
153 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
154 channel_multiplier = multiplier[i];
155 channel_shift = shift[i];
156 curr_channel_slice_postscaled =
157 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
158 double_round, scale32](InEigenType in_val) -> OutEigenType {
159 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
160 int32_t scaled;
161 if (scale32)
162 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
163 channel_shift, double_round);
164 else
165 scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
166 channel_shift);
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000167 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp;
168 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
169 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
170 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
171 {
172 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + std::to_string(output_zp) + "] not in i32 range";
173 throw desc;
174 }
175 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700176 out_val = std::max<OutEigenType>(out_val, QMin);
177 out_val = std::min<OutEigenType>(out_val, QMax);
178 return out_val;
179 });
180
181 for (int32_t j = 0; j < shape_2d[0]; j++)
182 {
183 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
184 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700185 }
186 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700187 catch (std::string desc)
188 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000189 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700190 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700191 }
192 else
193 {
194 int32_t tensor_multiplier = multiplier[0];
195 int32_t tensor_shift = shift[0];
Kevin Chengacb550f2021-06-29 15:32:19 -0700196 try
197 {
198 output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
199 scale32](InEigenType in_val) -> OutEigenType {
200 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
201 int32_t scaled;
202 if (scale32)
203 scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
204 double_round);
205 else
206 scaled =
207 TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000208 int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp;
209 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
210 int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
211 if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
212 {
213 std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + std::to_string(output_zp) + "] not in i32 range";
214 throw desc;
215 }
216
217 OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
Kevin Chengacb550f2021-06-29 15:32:19 -0700218 out_val = std::max<OutEigenType>(out_val, QMin);
219 out_val = std::min<OutEigenType>(out_val, QMax);
220 return out_val;
221 });
222 }
223 catch (std::string desc)
224 {
Jeremy Johnsondf628d42023-01-10 14:40:54 +0000225 REQUIRE(false, "OpRescale failure: %s.", desc.c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700226 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700227 }
228
229 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
230 Eigen::array<Eigen::Index, Rank> output_shape;
231 for (int i = 0; i < Rank; i++)
232 {
233 output_shape[i] = this->out->getShape()[i];
234 }
235 this->out->getTensor() = output_2d.reshape(output_shape);
236
237 return GraphNode::eval();
238}
239
template <int Rank, DType InDtype, DType OutDtype>
OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
                                        TosaAttributeBase* attribute_,
                                        uint64_t id_)
    : GraphNode(sgt_, Op_CAST, id_)
{
    // CAST is a unary elementwise op: one input, one output, rank 0-6.
    // It carries no attribute; the conversion behavior is fully determined
    // by the <InDtype, OutDtype> template pair via CastHelper.
    setRequiredOperands(1, 1);
    setRequiredRank(0, 6);
}
249
template <int Rank, DType InDtype, DType OutDtype>
OpCast<Rank, InDtype, OutDtype>::~OpCast()
{}    // nothing owned beyond what GraphNode cleans up
253
template <int Rank, DType InDtype, DType OutDtype>
int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
    // Returns 0 on success, 1 on validation failure.

    // Check Tosa Level
    auto tosa_level = g_func_config.tosa_level;
    LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");

    if (validateRequiredOperands())
        return 1;

    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
    {
        return 1;
    }

    // output and input must be the same rank and size
    if (inputs[0]->matchRankSize(*outputs[0]))
    {
        printNodeValidationError("OpCast: input and output rank/size must match");
        return 1;
    }

    in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
    out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);

    ASSERT_MEM(in && out);

    return 0;
}
283
template <int Rank, DType InDtype, DType OutDtype>
int OpCast<Rank, InDtype, OutDtype>::eval()
{
    // Apply the element-wise conversion functor selected by the
    // CastHelper<InDtype, OutDtype> specialization.
    this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());

    return GraphNode::eval();
}
291
template <DType InDtype, DType OutDtype>
CastHelper<InDtype, OutDtype>::CastHelper()
{
    // Generic (integer-to-integer) conversion: a plain C++ cast between the
    // Eigen storage types.
    fcn = [](InEigenType in) -> OutEigenType {
        OutEigenType out = (OutEigenType)in;    // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
        return out;
    };
}
300
301template <DType InDtype>
302CastHelper<InDtype, DType_BOOL>::CastHelper()
303{
304 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
305}
306
307template <DType OutDtype>
308CastHelper<DType_BOOL, OutDtype>::CastHelper()
309{
310 fcn = [](bool in) -> OutEigenType {
311 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
312 return out;
313 };
314}
315
template <DType InDtype>
CastHelper<InDtype, DType_FP16>::CastHelper()
{
    // Integer data converted to fp16 (stored as fp32)
    fcn = [](InEigenType in) -> float {
        // Round-trip through half_float::half so the result carries fp16
        // precision even though it is stored in a float.
        half_float::half h = half_float::half(in);
        float out          = half_float::half_cast<float, half_float::half>(h);
        return out;
    };
}
326
327CastHelper<DType_FP32, DType_FP16>::CastHelper()
328{
329 // fp32 data converted to fp16 (stored as fp32)
330 fcn = [](float in) -> float {
331 float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision
332 return out;
333 };
334}
335
template <DType InDtype>
CastHelper<InDtype, DType_BF16>::CastHelper()
{
    // Integer data converted to bf16 (stored as fp32)
    fcn = [](InEigenType in) -> float {
        float out = (float)in;    // default cast to float is round_to_nearest_float()
        return out;
    };
}
345
CastHelper<DType_FP32, DType_BF16>::CastHelper()
{
    // fp32 data converted to bf16 (stored as fp32)
    fcn = [](float in) -> float {
        return fpTrunc<DType_BF16>(in);    // truncate required for conversions from higher precision
    };
}
353
template <DType OutDtype>
CastHelper<DType_FP16, OutDtype>::CastHelper()
{
    // fp16 data (stored as fp32) converted to integer
    fcn = [](float in) -> OutEigenType {
        // Cast from float representation back to half_float before rounding
        half_float::half h = half_float::half(in);
        // std::rint rounds per the current FP rounding mode
        // (round-to-nearest-even by default).
        h                = std::rint(h);
        OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h);
        // Saturate to the output type's representable range.
        out = std::max<OutEigenType>(out, OutMin);
        out = std::min<OutEigenType>(out, OutMax);
        return out;
    };
}
368
James Ward736fd1a2023-01-23 17:13:37 +0000369CastHelper<DType_FP16, DType_FP32>::CastHelper()
370{
371 // No-op since fp16 values treated internally as their fp32 representation
372 fcn = [](float in) -> OutEigenType {
373 return in;
374 };
375}
376
377template <DType OutDtype>
378CastHelper<DType_BF16, OutDtype>::CastHelper()
379{
380 // bf16 data (stored as fp32) converted to integer
381 fcn = [](float in) -> OutEigenType {
382 OutEigenType out = std::round(in);
383 out = std::max<OutEigenType>(out, OutMin);
384 out = std::min<OutEigenType>(out, OutMax);
385 return out;
386 };
387}
388
CastHelper<DType_BF16, DType_FP32>::CastHelper()
{
    // No-op since bf16 values treated as truncated fp32 internally
    fcn = [](InEigenType in) -> OutEigenType {
        return in;
    };
}
396
template <DType InDtype>
CastHelper<InDtype, DType_FP32>::CastHelper()
{
    // Integer data converted to fp32
    fcn = [](InEigenType in) -> float {
        float out = (OutEigenType)in;    // default cast to float is round_to_nearest_float()
        return out;
    };
}
406
template <DType OutDtype>
CastHelper<DType_FP32, OutDtype>::CastHelper()
{
    // fp32 data converted to integer
    fcn = [](float in) -> OutEigenType {
        // std::rint rounds per the current FP rounding mode
        // (round-to-nearest-even by default).
        OutEigenType out = std::rint(in);
        // Saturate to the output type's representable range.
        out = std::max<OutEigenType>(out, OutMin);
        out = std::min<OutEigenType>(out, OutMax);
        return out;
    };
}
418
// template explicit instantiation

// OpCast: every supported <input, output> dtype pair.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);

// OpRescale: signed 8/16/32/48-bit and unsigned 8/16-bit combinations.
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);