/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h"
25
Gian Marco6b77e912017-11-17 09:27:57 +000026#include "arm_compute/core/AccessWindowStatic.h"
Gian Marcoe75a02b2017-11-08 12:24:09 +000027#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/Types.h"
Gian Marco6b77e912017-11-17 09:27:57 +000031#include "arm_compute/core/Utils.h"
Gian Marcoe75a02b2017-11-08 12:24:09 +000032#include "arm_compute/core/Validate.h"
33#include "arm_compute/core/Window.h"
34
35#include <arm_neon.h>
36#include <cstddef>
37#include <cstdint>
38
39using namespace arm_compute;
40
Gian Marco6b77e912017-11-17 09:27:57 +000041namespace
42{
43inline void scale_input(int32x4x4_t &in_s32, int32x4_t result_offset_s32, int32_t result_mult_int)
44{
45 // Add the offset terms to GEMM's result
46 in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_s32);
47 in_s32.val[1] = vaddq_s32(in_s32.val[1], result_offset_s32);
48 in_s32.val[2] = vaddq_s32(in_s32.val[2], result_offset_s32);
49 in_s32.val[3] = vaddq_s32(in_s32.val[3], result_offset_s32);
50
51 // Multiply by result_mult_int
52 in_s32.val[0] = vmulq_n_s32(in_s32.val[0], result_mult_int);
53 in_s32.val[1] = vmulq_n_s32(in_s32.val[1], result_mult_int);
54 in_s32.val[2] = vmulq_n_s32(in_s32.val[2], result_mult_int);
55 in_s32.val[3] = vmulq_n_s32(in_s32.val[3], result_mult_int);
56}
57
58template <bool is_bounded_relu>
59inline uint8x16_t finalize_quantization(int32x4x4_t &in_s32, int32x4_t result_shift_s32, uint8x16_t min_u8, uint8x16_t max_u8)
60{
61 const static int32x4_t zero_s32 = vdupq_n_s32(0);
62
63 // Shift final result (negative value shift right)
64 in_s32.val[0] = vshlq_s32(in_s32.val[0], result_shift_s32);
65 in_s32.val[1] = vshlq_s32(in_s32.val[1], result_shift_s32);
66 in_s32.val[2] = vshlq_s32(in_s32.val[2], result_shift_s32);
67 in_s32.val[3] = vshlq_s32(in_s32.val[3], result_shift_s32);
68
69 // Saturate negative values
70 in_s32.val[0] = vmaxq_s32(in_s32.val[0], zero_s32);
71 in_s32.val[1] = vmaxq_s32(in_s32.val[1], zero_s32);
72 in_s32.val[2] = vmaxq_s32(in_s32.val[2], zero_s32);
73 in_s32.val[3] = vmaxq_s32(in_s32.val[3], zero_s32);
74
75 // Convert S32 to S16
76 const int16x8x2_t in_s16 =
77 {
78 {
79 vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
80 vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
81 }
82 };
83
84 // Convert S16 to U8
85 uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_s16.val[0]), vqmovun_s16(in_s16.val[1]));
86
87 if(is_bounded_relu)
88 {
89 out_u8 = vmaxq_u8(out_u8, min_u8);
90 out_u8 = vminq_u8(out_u8, max_u8);
91 }
92
93 return out_u8;
94}
95} // namespace
96
// Forward declaration of the coordinates type received by the window-loop lambdas below.
namespace arm_compute { class Coordinates; } // namespace arm_compute
101
Gian Marco6b77e912017-11-17 09:27:57 +0000102template <bool is_bounded_relu>
103void NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::run(const Window &window)
104{
105 const int32x4_t result_offset_s32 = vdupq_n_s32(_result_offset);
106 const int32x4_t result_shift_s32 = vdupq_n_s32(-_result_shift);
107 const uint8x16_t min_u8 = vdupq_n_u8(static_cast<uint8_t>(_min));
108 const uint8x16_t max_u8 = vdupq_n_u8(static_cast<uint8_t>(_max));
109
110 ARM_COMPUTE_UNUSED(min_u8);
111 ARM_COMPUTE_UNUSED(max_u8);
112
113 Iterator in(_input, window);
114 Iterator out(_output, window);
115
116 if(_bias != nullptr)
117 {
118 Window win_biases;
119 win_biases.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), window.x().step()));
120 win_biases.set(Window::DimY, Window::Dimension(0, 1, 1));
121
122 Iterator bias(_bias, win_biases);
123 execute_window_loop(window, [&](const Coordinates & id)
124 {
125 int32x4x4_t in_s32 =
126 {
127 {
128 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 0),
129 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 4),
130 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 8),
131 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 12)
132 }
133 };
134
135 const int32x4x4_t bias_s32 =
136 {
137 {
138 vld1q_s32(reinterpret_cast<const int32_t *>(bias.ptr()) + 0),
139 vld1q_s32(reinterpret_cast<const int32_t *>(bias.ptr()) + 4),
140 vld1q_s32(reinterpret_cast<const int32_t *>(bias.ptr()) + 8),
141 vld1q_s32(reinterpret_cast<const int32_t *>(bias.ptr()) + 12)
142 }
143 };
144
145 // Add the offset terms to GEMM's result and multiply by result_mult_int
146 scale_input(in_s32, result_offset_s32, _result_mult_int);
147
148 // Add the bias to GEMM's result
149 in_s32.val[0] = vaddq_s32(in_s32.val[0], bias_s32.val[0]);
150 in_s32.val[1] = vaddq_s32(in_s32.val[1], bias_s32.val[1]);
151 in_s32.val[2] = vaddq_s32(in_s32.val[2], bias_s32.val[2]);
152 in_s32.val[3] = vaddq_s32(in_s32.val[3], bias_s32.val[3]);
153
154 vst1q_u8(out.ptr(), finalize_quantization<is_bounded_relu>(in_s32, result_shift_s32, min_u8, max_u8));
155 },
156 in, bias, out);
157 }
158 else
159 {
160 execute_window_loop(window, [&](const Coordinates & id)
161 {
162 int32x4x4_t in_s32 =
163 {
164 {
165 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 0),
166 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 4),
167 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 8),
168 vld1q_s32(reinterpret_cast<const int32_t *>(in.ptr()) + 12)
169 }
170 };
171
172 // Add the offset terms to GEMM's result and multiply by result_mult_int
173 scale_input(in_s32, result_offset_s32, _result_mult_int);
174
175 vst1q_u8(out.ptr(), finalize_quantization<is_bounded_relu>(in_s32, result_shift_s32, min_u8, max_u8));
176 },
177 in, out);
178 }
179}
180
Gian Marcoe75a02b2017-11-08 12:24:09 +0000181NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel()
Gian Marco6b77e912017-11-17 09:27:57 +0000182 : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_offset(0), _result_mult_int(0), _result_shift(0), _min(0), _max(0)
Gian Marcoe75a02b2017-11-08 12:24:09 +0000183{
184}
185
Gian Marco6b77e912017-11-17 09:27:57 +0000186void NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max)
Gian Marcoe75a02b2017-11-08 12:24:09 +0000187{
188 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::S32);
189 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
Gian Marco6b77e912017-11-17 09:27:57 +0000190 ARM_COMPUTE_ERROR_ON(max > 255);
191 ARM_COMPUTE_ERROR_ON(min < 0 || min > max);
192
193 if(bias != nullptr)
194 {
195 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
196 ARM_COMPUTE_ERROR_ON(bias->info()->num_dimensions() > 1);
197 ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != bias->info()->dimension(0));
198 }
Gian Marcoe75a02b2017-11-08 12:24:09 +0000199
200 _input = input;
Gian Marco6b77e912017-11-17 09:27:57 +0000201 _bias = bias;
Gian Marcoe75a02b2017-11-08 12:24:09 +0000202 _output = output;
203 _result_offset = result_offset;
204 _result_mult_int = result_mult_int;
205 _result_shift = result_shift;
Gian Marco6b77e912017-11-17 09:27:57 +0000206 _min = min;
207 _max = max;
Gian Marcoe75a02b2017-11-08 12:24:09 +0000208
209 constexpr unsigned int num_elems_processed_per_iteration = 16;
210
211 // Configure kernel window
212 Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));
213
214 AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
215 AccessWindowHorizontal output_result_access(output->info(), 0, num_elems_processed_per_iteration);
216
217 update_window_and_padding(win,
218 input_access,
219 output_result_access);
220
Gian Marco6b77e912017-11-17 09:27:57 +0000221 if(bias != nullptr)
222 {
223 AccessWindowStatic bias_access(bias->info(), 0, 0, ceil_to_multiple(bias->info()->dimension(0), num_elems_processed_per_iteration), bias->info()->tensor_shape()[1]);
224
225 update_window_and_padding(win,
226 bias_access);
227 }
228
Gian Marcoe75a02b2017-11-08 12:24:09 +0000229 output_result_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
230
231 INEKernel::configure(win);
Gian Marco6b77e912017-11-17 09:27:57 +0000232
233 const bool is_bounded_relu = ((min != max) && !(min == 0 && max == 255));
234
235 // Check if we need to clamp the result using min and max
236 _func = is_bounded_relu ? &NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::run<true> : &NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::run<false>;
Gian Marcoe75a02b2017-11-08 12:24:09 +0000237}
238
239void NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::run(const Window &window, const ThreadInfo &info)
240{
241 ARM_COMPUTE_UNUSED(info);
242 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
243 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
244
Gian Marco6b77e912017-11-17 09:27:57 +0000245 (this->*_func)(window);
Gian Marcoe75a02b2017-11-08 12:24:09 +0000246}