
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "type_conversion.h"
17#include "quant_util.h"
18#include "template_types.h"
19#include <cmath>
20
21using namespace TosaReference;
22using namespace Eigen;
23using namespace tosa;
24
25template <int Rank, DType InDtype, DType OutDtype>
26OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
27 : GraphNode(Op_RESCALE, id_)
28{
29 setRequiredOperands(1, 1);
30 setRequiredRank(0, 6);
31 INIT_ATTRIBUTE(Rescale);
32}
33
34template <int Rank, DType InDtype, DType OutDtype>
35OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
36{
37 if (attribute)
38 delete attribute;
39}
40
41template <int Rank, DType InDtype, DType OutDtype>
42int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
43{
44 if (validateRequiredOperands())
45 return 1;
46
47 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
48 {
49 return 1;
50 }
51
52 // output and input must be the same rank and size
53 if (inputs[0]->matchRankSize(*outputs[0]))
54 {
55 printNodeValidationError("OpRescale: input and output rank/size must match");
56 return 1;
57 }
58
59 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
60 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
61
62 ASSERT_MEM(in && out);
63
64 return 0;
65}
66
67template <int Rank, DType InDtype, DType OutDtype>
68int OpRescale<Rank, InDtype, OutDtype>::eval()
69{
70 int32_t input_zp = attribute->input_zp();
71 int32_t output_zp = attribute->output_zp();
72 std::vector<int32_t> multiplier = attribute->multiplier();
73 std::vector<int32_t> shift = attribute->shift();
74 //bool scale32 = attribute->scale32();
75 bool double_round = attribute->double_round();
76 bool per_channel = attribute->per_channel();
77
Eric Kunzee5e26762020-10-13 16:11:07 -070078 // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
79 Eigen::array<Eigen::Index, 2> shape_2d;
80 shape_2d[0] = 1;
81 if (Rank > 0)
82 {
83 for (int i = 0; i < Rank - 1; i++)
84 {
85 shape_2d[0] *= this->in->getShape()[i];
86 }
87 shape_2d[1] = this->in->getShape()[Rank - 1];
88 }
89 else
90 {
91 shape_2d[1] = 1;
92 }
93 ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
94
95 ETensor2<OutEigenType> output_2d(shape_2d);
96
97 // TODO: pass scale32 in when 16-bit mode implemented
98 if (per_channel)
99 {
100 ETensor2<InEigenType> curr_channel_slice_prescaled;
101 ETensor2<OutEigenType> curr_channel_slice_postscaled;
102 int32_t channel_multiplier, channel_shift;
103 Eigen::array<Eigen::Index, 2> begin, size;
104 size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
105 for (int32_t i = 0; i < shape_2d[1]; i++)
106 {
107 begin = Eigen::array<Eigen::Index, 2>({ 0, i });
108 curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
109 channel_multiplier = multiplier[i];
110 channel_shift = shift[i];
111 curr_channel_slice_postscaled =
112 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
113 double_round](InEigenType in_val) -> OutEigenType {
114 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800115 int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
116 channel_shift, double_round);
Eric Kunzee5e26762020-10-13 16:11:07 -0700117 OutEigenType out_val = (OutEigenType)(scaled + output_zp);
118 out_val = std::max<OutEigenType>(out_val, QMin);
119 out_val = std::min<OutEigenType>(out_val, QMax);
120 return out_val;
121 });
122
123 for (int32_t j = 0; j < shape_2d[0]; j++)
124 {
125 output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
126 }
127 }
128 }
129 else
130 {
131 int32_t tensor_multiplier = multiplier[0];
132 int32_t tensor_shift = shift[0];
133 output_2d = input_reshaped.unaryExpr(
134 [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
135 InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
Kevin Chengaee1fac2020-11-11 13:54:06 -0800136 int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
Kevin Cheng99bea142020-10-19 12:35:05 -0700137 tensor_shift, double_round);
Eric Kunzee5e26762020-10-13 16:11:07 -0700138 OutEigenType out_val = (OutEigenType)(scaled + output_zp);
139 out_val = std::max<OutEigenType>(out_val, QMin);
140 out_val = std::min<OutEigenType>(out_val, QMax);
141 return out_val;
142 });
143 }
144
145 // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
146 Eigen::array<Eigen::Index, Rank> output_shape;
147 for (int i = 0; i < Rank; i++)
148 {
149 output_shape[i] = this->out->getShape()[i];
150 }
151 this->out->getTensor() = output_2d.reshape(output_shape);
152
153 return GraphNode::eval();
154}
155
156template <int Rank, DType InDtype, DType OutDtype>
157OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
158 : GraphNode(Op_CAST, id_)
159{
160 setRequiredOperands(1, 1);
161 setRequiredRank(0, 6);
162}
163
164template <int Rank, DType InDtype, DType OutDtype>
165OpCast<Rank, InDtype, OutDtype>::~OpCast()
166{}
167
168template <int Rank, DType InDtype, DType OutDtype>
169int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
170{
171 if (validateRequiredOperands())
172 return 1;
173
174 if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
175 {
176 return 1;
177 }
178
179 // output and input must be the same rank and size
180 if (inputs[0]->matchRankSize(*outputs[0]))
181 {
182 printNodeValidationError("OpCast: input and output rank/size must match");
183 return 1;
184 }
185
186 in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
187 out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
188
189 ASSERT_MEM(in && out);
190
191 return 0;
192}
193
194template <int Rank, DType InDtype, DType OutDtype>
195int OpCast<Rank, InDtype, OutDtype>::eval()
196{
197 this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
198
199 return GraphNode::eval();
200}
201
202template <DType InDtype, DType OutDtype>
203CastHelper<InDtype, OutDtype>::CastHelper()
204{
205 fcn = [](InEigenType in) -> OutEigenType {
206 OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 return out;
208 };
209}
210
211template <DType InDtype>
212CastHelper<InDtype, DType_BOOL>::CastHelper()
213{
214 fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
215}
216
217template <DType OutDtype>
218CastHelper<DType_BOOL, OutDtype>::CastHelper()
219{
220 fcn = [](bool in) -> OutEigenType {
221 OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
222 return out;
223 };
224}
225
226template <DType InDtype>
227CastHelper<InDtype, DType_FLOAT>::CastHelper()
228{
229 fcn = [](InEigenType in) -> float {
230 float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
231 return out;
232 };
233}
234
235template <DType OutDtype>
236CastHelper<DType_FLOAT, OutDtype>::CastHelper()
237{
238 fcn = [](float in) -> OutEigenType {
239 OutEigenType out = std::round(in);
240 out = std::max<OutEigenType>(out, OutMin);
241 out = std::min<OutEigenType>(out, OutMax);
242 return out;
243 };
244}
245
246// template explicit instantiation
247DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
248DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
249DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
250DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
251DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
252DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
253DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
254DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
255DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
256DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
257DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
258DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
259DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
260DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
261DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
262DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
263DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
264DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
265
Kevin Cheng3a478572021-01-22 17:21:02 -0800266DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
267DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
268DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
269DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700270DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
271DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800272DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700273DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
274DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800275DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
Eric Kunzee5e26762020-10-13 16:11:07 -0700276DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
277DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
Kevin Cheng3a478572021-01-22 17:21:02 -0800278DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
279DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);