/*
 * Copyright (c) 2016, 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEDepthConvertKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/ITensor.h"
29#include "arm_compute/core/NEON/NEFixedPoint.h"
30#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Validate.h"
32
33#include <arm_neon.h>
34
35using namespace arm_compute;
36
namespace arm_compute
{
// Forward declaration: this file only names Coordinates as the (const reference)
// parameter type of the callbacks handed to execute_window_loop().
class Coordinates;
} // namespace arm_compute
41
42NEDepthConvertKernel::NEDepthConvertKernel()
43 : _policy(), _shift(0)
44{
45}
46
47void NEDepthConvertKernel::configure(const ITensor *input, ITensor *output, ConvertPolicy policy, uint32_t shift)
48{
49 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::U16, DataType::F32);
50 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F32);
51 ARM_COMPUTE_ERROR_ON(shift >= 8);
52 ARM_COMPUTE_ERROR_ON(input == output);
53 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == output->info()->data_type(), "Input and output data_types must be different");
54
55 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS8 && (output->info()->data_type() != DataType::F32),
56 "Only data_types supported [in] QS8 -> [out] F32");
57
58 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U8 && (output->info()->data_type() != DataType::S16 && output->info()->data_type() != DataType::U16
59 && output->info()->data_type() != DataType::S32),
60 "Only data_types supported [in] U8 -> [out] U16, S16, S32");
61
62 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U32),
63 "Only data_types supported [in] U16 -> [out] U8, U32");
64
65 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::S16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::S32),
66 "Only data_types supported [in] S16 -> [out] U8, S32");
67
68 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::F32 && (output->info()->data_type() != DataType::QS8),
69 "Only data_types supported [in] F32 -> [out] QS8");
70
71 _policy = policy;
72 _shift = shift;
73
74 constexpr unsigned int num_elems_processed_per_iteration = 16;
75 INESimpleKernel::configure(input, output, num_elems_processed_per_iteration);
76}
77
78void NEDepthConvertKernel::run(const Window &window)
79{
80 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
81 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INESimpleKernel::window(), window);
82 ARM_COMPUTE_ERROR_ON(nullptr == _input);
83 ARM_COMPUTE_ERROR_ON(nullptr == _output);
84 ARM_COMPUTE_ERROR_ON(_input == _output);
85
86 Iterator input(_input, window);
87 Iterator output(_output, window);
88
89 switch(_input->info()->data_type())
90 {
91 case DataType::QS8:
92 {
93 const int fixed_point_position = _input->info()->fixed_point_position();
94
95 switch(_output->info()->data_type())
96 {
97 case DataType::F32:
98 {
99 /* Up-conversion QS8 -> F32 */
100 execute_window_loop(window, [&](const Coordinates & id)
101 {
102 const int8x16_t texels_s8 = vld1q_s8(reinterpret_cast<const int8_t *>(input.ptr()));
103
104 float32x4x2_t texels_low = vcvt_f32_qs8(vget_low_s8(texels_s8), fixed_point_position);
105 float32x4x2_t texels_high = vcvt_f32_qs8(vget_high_s8(texels_s8), fixed_point_position);
106
107 vst1q_f32(reinterpret_cast<float *>(output.ptr()), texels_low.val[0]);
108 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4, texels_low.val[1]);
109 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8, texels_high.val[0]);
110 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12, texels_high.val[1]);
111 },
112 input, output);
113 break;
114 }
115 default:
116 ARM_COMPUTE_ERROR("Output data type not supported");
117 }
118 break;
119 }
120 case DataType::U8:
121 {
122 const int16x8_t b = vdupq_n_s16(_shift);
123
124 switch(_output->info()->data_type())
125 {
126 case DataType::S16:
127 {
128 /* Up-conversion U8 -> S16 */
129 execute_window_loop(window, [&](const Coordinates & id)
130 {
131 const uint8x16_t texels_u8 = vld1q_u8(input.ptr());
132
133 const int16x8x2_t texels =
134 {
135 {
136 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), b),
137 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))), b)
138 }
139 };
140
141 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), texels.val[0]);
142 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, texels.val[1]);
143 },
144 input, output);
145 break;
146 }
147 case DataType::S32:
148 {
149 /* Up-conversion U8 -> S32 */
150 execute_window_loop(window, [&](const Coordinates & id)
151 {
152 const uint8x16_t texels_u8 = vld1q_u8(input.ptr());
153
154 const int16x8x2_t texels =
155 {
156 {
157 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), b),
158 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))), b)
159 }
160 };
161
162 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), vmovl_s16(vget_low_s16(texels.val[0])));
163 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, vmovl_s16(vget_high_s16(texels.val[0])));
164 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 8, vmovl_s16(vget_low_s16(texels.val[1])));
165 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 12, vmovl_s16(vget_high_s16(texels.val[1])));
166 },
167 input, output);
168 break;
169 }
170 case DataType::U16:
171 {
172 /* Up-conversion U8 -> U16 */
173 execute_window_loop(window, [&](const Coordinates & id)
174 {
175 const uint8x16_t texels_u8 = vld1q_u8(input.ptr());
176
177 const uint16x8x2_t texels =
178 {
179 {
180 vshlq_u16(vmovl_u8(vget_low_u8(texels_u8)), b),
181 vshlq_u16(vmovl_u8(vget_high_u8(texels_u8)), b)
182 }
183 };
184
185 vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()), texels.val[0]);
186 vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()) + 8, texels.val[1]);
187 },
188 input, output);
189 break;
190 }
191 default:
192 ARM_COMPUTE_ERROR("Output data type not supported");
193 }
194 break;
195 }
196 case DataType::S16:
197 {
198 switch(_output->info()->data_type())
199 {
200 case DataType::U8:
201 {
202 const int16x8_t b = vdupq_n_s16(-static_cast<int16_t>(_shift));
203
204 /* Down-conversion S16 -> U8 */
205 if(ConvertPolicy::SATURATE == _policy)
206 {
207 execute_window_loop(window, [&](const Coordinates & id)
208 {
209 const int16x8x2_t texels =
210 {
211 {
212 vqshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr())), b),
213 vqshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr()) + 8), b)
214 }
215 };
216
217 vst1q_u8(output.ptr(), vcombine_u8(vqmovun_s16(texels.val[0]), vqmovun_s16(texels.val[1])));
218 },
219 input, output);
220 }
221 else
222 {
223 execute_window_loop(window, [&](const Coordinates & id)
224 {
225 const int16x8x2_t texels =
226 {
227 {
228 vshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr())), b),
229 vshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr()) + 8), b)
230 }
231 };
232
233 vst1q_u8(output.ptr(), vcombine_u8(vmovn_u16(vreinterpretq_u16_s16(texels.val[0])),
234 vmovn_u16(vreinterpretq_u16_s16(texels.val[1]))));
235 },
236 input, output);
237 }
238 break;
239 }
240 case DataType::S32:
241 {
242 const int32x4_t b = vdupq_n_s32(_shift);
243
244 /* Up-conversion S16 -> S32 */
245 execute_window_loop(window, [&](const Coordinates & id)
246 {
247 const int16x8x2_t texels =
248 {
249 {
250 vld1q_s16(reinterpret_cast<int16_t *>(input.ptr())),
251 vld1q_s16(reinterpret_cast<int16_t *>(input.ptr()) + 8)
252 }
253 };
254
255 const int32x4x4_t texels_s32 =
256 {
257 {
258 vshlq_s32(vmovl_s16(vget_low_s16(texels.val[0])), b),
259 vshlq_s32(vmovl_s16(vget_high_s16(texels.val[0])), b),
260 vshlq_s32(vmovl_s16(vget_low_s16(texels.val[1])), b),
261 vshlq_s32(vmovl_s16(vget_high_s16(texels.val[1])), b)
262 }
263 };
264
265 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), texels_s32.val[0]);
266 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, texels_s32.val[1]);
267 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 8, texels_s32.val[2]);
268 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 12, texels_s32.val[3]);
269 },
270 input, output);
271 break;
272 }
273 default:
274 ARM_COMPUTE_ERROR("Output data type not supported");
275 }
276 break;
277 }
278 case DataType::U16:
279 {
280 switch(_output->info()->data_type())
281 {
282 case DataType::U8:
283 {
284 const int16x8_t b = vdupq_n_s16(-static_cast<int16_t>(_shift));
285
286 /* Down-conversion U16 -> U8 */
287 if(ConvertPolicy::SATURATE == _policy)
288 {
289 execute_window_loop(window, [&](const Coordinates & id)
290 {
291 const uint16x8x2_t texels =
292 {
293 {
294 vqshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())), b),
295 vqshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8), b)
296 }
297 };
298
299 vst1q_u8(output.ptr(), vcombine_u8(vqmovn_u16(texels.val[0]), vqmovn_u16(texels.val[1])));
300 },
301 input, output);
302 }
303 else
304 {
305 execute_window_loop(window, [&](const Coordinates & id)
306 {
307 const uint16x8x2_t texels =
308 {
309 {
310 vshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())), b),
311 vshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8), b)
312 }
313 };
314
315 vst1q_u8(output.ptr(), vcombine_u8(vmovn_u16(texels.val[0]), vmovn_u16(texels.val[1])));
316 },
317 input, output);
318 }
319 break;
320 }
321 case DataType::U32:
322 {
323 const int32x4_t b = vdupq_n_s32(_shift);
324
325 /* Up-conversion U16 -> U32 */
326 execute_window_loop(window, [&](const Coordinates & id)
327 {
328 const uint16x8x2_t texels =
329 {
330 {
331 vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())),
332 vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8)
333 }
334 };
335
336 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()), vshlq_u32(vmovl_u16(vget_low_u16(texels.val[0])), b));
337 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vshlq_u32(vmovl_u16(vget_high_u16(texels.val[0])), b));
338 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vshlq_u32(vmovl_u16(vget_low_u16(texels.val[1])), b));
339 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vshlq_u32(vmovl_u16(vget_high_u16(texels.val[1])), b));
340 },
341 input, output);
342 break;
343 }
344 default:
345 ARM_COMPUTE_ERROR("Output data type not supported");
346 }
347 break;
348 }
349 case DataType::F32:
350 {
351 switch(_output->info()->data_type())
352 {
353 case DataType::QS8:
354 {
355 const int fixed_point_position = _output->info()->fixed_point_position();
356 /* Down-conversion F32 -> QS8 */
357 execute_window_loop(window, [&](const Coordinates & id)
358 {
359 const float32x4x4_t texels_f32 =
360 {
361 {
362 vld1q_f32(reinterpret_cast<const float *>(input.ptr())),
363 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 4),
364 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 8),
365 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 12)
366 }
367 };
368
369 const qint8x16_t texels_s8 = vcvtq_qs8_f32(texels_f32, fixed_point_position);
370
371 vst1q_s8(reinterpret_cast<int8_t *>(output.ptr()), texels_s8);
372 },
373 input, output);
374 break;
375 }
376 default:
377 ARM_COMPUTE_ERROR("Output data type not supported");
378 }
379 break;
380 }
381 default:
382 ARM_COMPUTE_ERROR("Not supported");
383 }
384}