blob: c37d6519bba351027260bf9a91ad5ede625b1a0f [file] [log] [blame]
//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <armnn/Exceptions.hpp>
7
8#pragma once
9
10inline void CreateRescaleTosaOperator(const std::string& inputName,
11 const std::string& outputName,
John Mcloughlinceb44282024-04-23 16:47:04 +010012 const std::vector<int32_t>& multipliers,
13 const std::vector<int32_t>& shifts,
Tracy Narine10403ec2023-11-28 11:55:08 +000014 int32_t input_zp,
15 int32_t output_zp,
16 bool double_round,
17 bool scale32,
John Mcloughlinceb44282024-04-23 16:47:04 +010018 bool per_channel,
Teresa Charlince48d1d2024-04-24 13:30:58 +010019 TosaSerializationOperator** op)
Tracy Narine10403ec2023-11-28 11:55:08 +000020{
21 if (!op)
22 {
23 throw armnn::Exception("CreateRescaleTosaOperator: nullptr op");
24 }
25
Tracy Narine10403ec2023-11-28 11:55:08 +000026 TosaRescaleAttribute attribute(input_zp,
27 output_zp,
28 multipliers,
29 shifts,
30 scale32,
31 double_round,
John Mcloughlinceb44282024-04-23 16:47:04 +010032 per_channel,
Teresa Charlin571a4f72024-03-26 11:18:42 +000033 false, // input_unsigned
34 false); // output_unsigned
Tracy Narine10403ec2023-11-28 11:55:08 +000035
36 // op
37 *op = new TosaSerializationOperator(Op_RESCALE, Attribute_RescaleAttribute, &attribute, {inputName}, {outputName});
38 if (!(*op))
39 {
40 throw armnn::Exception("CreateRescaleTosaOperator: failed to created operator");
41 }
Tracy Narine10403ec2023-11-28 11:55:08 +000042}
43
44inline void CreateRescaleTosaOperator(const std::string& inputName,
45 const std::string& outputName,
John Mcloughlinceb44282024-04-23 16:47:04 +010046 int32_t scale_multiplier,
47 int32_t scale_shift,
48 int32_t input_zp,
49 int32_t output_zp,
50 bool double_round,
51 bool scale32,
52 bool per_channel,
Teresa Charlince48d1d2024-04-24 13:30:58 +010053 TosaSerializationOperator** op)
John Mcloughlinceb44282024-04-23 16:47:04 +010054{
55 const std::vector<int32_t> multipliers{scale_multiplier};
56 const std::vector<int32_t> shifts{scale_shift};
Teresa Charlince48d1d2024-04-24 13:30:58 +010057 CreateRescaleTosaOperator(inputName, outputName, multipliers, shifts,
58 input_zp, output_zp, double_round, scale32, per_channel, op);
John Mcloughlinceb44282024-04-23 16:47:04 +010059}
60
61/// The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project
62/// From a scale value, generates multiplier and shift values where
63/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
64/// multiplier = mantissa*2^shift for 32-bit scaling.
65static void ComputeMultiplierAndShiftTosaScale32(double scale,
66 int32_t &multiplier,
67 int32_t &shift)
68{
69 const double mantissa = std::frexp(scale, &shift);
70 auto shiftedM = std::round(mantissa * (int64_t(1) << 31));
71
72 // Can't be greater than 1.0.
73 if (!(shiftedM <= (int64_t(1) << 31)))
74 {
75 throw armnn::Exception("Shifted mantissa exceeds 32 signed bits");
76 }
77
78 if (shiftedM == (int64_t(1) << 31))
79 {
80 shiftedM /= 2;
81 shift++;
82 }
83
84 // TOSA expects right shift to be positive, and embed (1 << 31) into right
85 // shift bits.
86 shift = (-shift) + 31;
87
88 if (!(shiftedM <= std::numeric_limits<int32_t>::max()))
89 {
90 throw armnn::Exception("Shifted mantissa exceeds 32-bit signed output type");
91 }
92
93 multiplier = static_cast<int32_t>(shiftedM);
94
95 // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
96 // The limit of 62 on shift allows the shift to be decomposed as
97 // two right shifts of 31.
98 if (shift > 62)
99 {
100 // Shifting the multiplier by more than 32-bits is unnecessary.
101 multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
102 shift = 62;
103 }
104}
105
106/// The following is taken from mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp in the LLVM project
107/// From a scale value, generates multiplier and shift values where
108/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
109/// multiplier = mantissa*2^shift for 16-bit scaling.
110static void ComputeMultiplierAndShiftTosaScale16(double scale,
111 int32_t &multiplier,
112 int32_t &shift)
113{
114 const double mantissa = std::frexp(scale, &shift);
115 auto shiftedM = std::round(mantissa * (int64_t(1) << 15));
116
117 // Can't be greater than 1.0.
118 if (!(shiftedM <= (int64_t(1) << 15)))
119 {
120 throw armnn::Exception("Shifted mantissa exceeds 16 signed bits");
121 }
122
123 if (shiftedM == (int64_t(1) << 15))
124 {
125 shiftedM /= 2;
126 shift++;
127 }
128
129 // TOSA expects right shift to be positive and embed (1 << 15) into right
130 // shift bits.
131 shift = (-shift) + 15;
132
133 if (!(shiftedM <= std::numeric_limits<int32_t>::max()))
134 {
135 throw armnn::Exception("Shifted mantissa exceeds 32-bit signed output type");
136 }
137
138 multiplier = static_cast<int32_t>(shiftedM);
139
140 // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
141 // The limit of 62 on shift allows the shift to be decomposed as
142 // two right shifts of 31.
143 if (shift > 62)
144 {
145 // Shifting the multiplier by more than 31-bits is unnecessary.
146 multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
147 shift = 62;
148 }
149}
150
151inline void CreateRescaleTosaOperator(const std::string& inputName,
152 const std::string& outputName,
Tracy Narine10403ec2023-11-28 11:55:08 +0000153 double scale,
154 int32_t input_zp,
155 int32_t output_zp,
156 bool double_round,
157 bool scale32,
Teresa Charlince48d1d2024-04-24 13:30:58 +0100158 TosaSerializationOperator** op)
Tracy Narine10403ec2023-11-28 11:55:08 +0000159{
Tracy Narine10403ec2023-11-28 11:55:08 +0000160 int32_t multiplier;
161 int32_t shift;
John Mcloughlinceb44282024-04-23 16:47:04 +0100162
163 if (scale32)
164 {
165 ComputeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
166 }
167 else
168 {
169 ComputeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
170 }
171
Teresa Charlince48d1d2024-04-24 13:30:58 +0100172 CreateRescaleTosaOperator(inputName, outputName, multiplier, shift,
173 input_zp, output_zp, double_round, scale32, false, op);
John Mcloughlinceb44282024-04-23 16:47:04 +0100174}
175
176inline void CreateRescaleTosaOperatorPerChannel(const std::string& inputName,
177 const std::string& outputName,
John Mcloughlinceb44282024-04-23 16:47:04 +0100178 int32_t input_zp,
179 int32_t output_zp,
180 bool double_round,
181 bool scale32,
182 double input_scale,
183 double output_scale,
184 const std::vector<float>& weight_scales,
Teresa Charlince48d1d2024-04-24 13:30:58 +0100185 TosaSerializationOperator** op)
John Mcloughlinceb44282024-04-23 16:47:04 +0100186{
187 std::vector<int32_t> op_tensor_multipliers;
188 std::vector<int32_t> op_tensor_shifts;
189 op_tensor_multipliers.reserve(weight_scales.size());
190 op_tensor_shifts.reserve(weight_scales.size());
191
192 for (const float& weight_scale : weight_scales)
193 {
194 double op_tensor_scale = (input_scale * weight_scale) / output_scale;
195 int32_t multiplier;
196 int32_t shift;
197
198 if (scale32)
199 {
200 ComputeMultiplierAndShiftTosaScale32(op_tensor_scale, multiplier, shift);
201 }
202 else
203 {
204 ComputeMultiplierAndShiftTosaScale16(op_tensor_scale, multiplier, shift);
205 }
206
207 op_tensor_multipliers.push_back(multiplier);
208 op_tensor_shifts.push_back(shift);
209 }
210
Teresa Charlince48d1d2024-04-24 13:30:58 +0100211 CreateRescaleTosaOperator(inputName, outputName, op_tensor_multipliers, op_tensor_shifts,
212 input_zp, output_zp, double_round, scale32, true, op);
Tracy Narine10403ec2023-11-28 11:55:08 +0000213}
214
215inline void CreateFromInt32RescaleTosaOperator(const std::string& inputName,
216 const std::string& outputName,
John Mcloughlinceb44282024-04-23 16:47:04 +0100217 double output_scale,
218 int32_t output_zp,
Teresa Charlince48d1d2024-04-24 13:30:58 +0100219 TosaSerializationOperator** op)
Tracy Narine10403ec2023-11-28 11:55:08 +0000220{
Teresa Charlince48d1d2024-04-24 13:30:58 +0100221 CreateRescaleTosaOperator(inputName, outputName, output_scale,
222 0, output_zp, true, true, op);
Tracy Narine10403ec2023-11-28 11:55:08 +0000223}