blob: 3c1a94df7414b878f820d89a4f9444a0db177d25 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2016, 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEDepthConvertKernel.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/ITensor.h"
29#include "arm_compute/core/NEON/NEFixedPoint.h"
30#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Validate.h"
32
33#include <arm_neon.h>
34
35using namespace arm_compute;
36
37namespace arm_compute
38{
39class Coordinates;
40} // namespace arm_compute
41
42NEDepthConvertKernel::NEDepthConvertKernel()
43 : _policy(), _shift(0)
44{
45}
46
47void NEDepthConvertKernel::configure(const ITensor *input, ITensor *output, ConvertPolicy policy, uint32_t shift)
48{
Georgios Pinitas21efeb42017-07-04 12:47:17 +010049 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::U16, DataType::QS16, DataType::F32);
50 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::U16, DataType::QS16, DataType::U32, DataType::S32, DataType::F32);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010051 ARM_COMPUTE_ERROR_ON(shift >= 8);
52 ARM_COMPUTE_ERROR_ON(input == output);
53 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == output->info()->data_type(), "Input and output data_types must be different");
54
Anthony Barbier6ff3b192017-09-04 18:44:23 +010055 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U8 && (output->info()->data_type() != DataType::S16 && output->info()->data_type() != DataType::U16
56 && output->info()->data_type() != DataType::S32),
57 "Only data_types supported [in] U8 -> [out] U16, S16, S32");
58
Georgios Pinitas21efeb42017-07-04 12:47:17 +010059 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS8 && output->info()->data_type() != DataType::F32,
60 "Only data_types supported [in] QS8 -> [out] F32");
61
Anthony Barbier6ff3b192017-09-04 18:44:23 +010062 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U32),
63 "Only data_types supported [in] U16 -> [out] U8, U32");
64
65 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::S16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::S32),
66 "Only data_types supported [in] S16 -> [out] U8, S32");
67
Georgios Pinitas21efeb42017-07-04 12:47:17 +010068 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS16 && output->info()->data_type() != DataType::F32,
69 "Only data_types supported [in] QS16 -> [out] F32");
70
71 ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::F32 && (output->info()->data_type() != DataType::QS8 && output->info()->data_type() != DataType::QS16),
72 "Only data_types supported [in] F32 -> [out] QS8, QS16");
73
74 // Auto initialize output shape if not initialized (We can only auto-configure the shape, datatype must be given)
75 set_shape_if_empty(*output->info(), input->info()->tensor_shape());
76
77 ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010078
79 _policy = policy;
80 _shift = shift;
81
82 constexpr unsigned int num_elems_processed_per_iteration = 16;
83 INESimpleKernel::configure(input, output, num_elems_processed_per_iteration);
84}
85
86void NEDepthConvertKernel::run(const Window &window)
87{
88 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
89 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INESimpleKernel::window(), window);
90 ARM_COMPUTE_ERROR_ON(nullptr == _input);
91 ARM_COMPUTE_ERROR_ON(nullptr == _output);
92 ARM_COMPUTE_ERROR_ON(_input == _output);
93
94 Iterator input(_input, window);
95 Iterator output(_output, window);
96
97 switch(_input->info()->data_type())
98 {
99 case DataType::QS8:
100 {
101 const int fixed_point_position = _input->info()->fixed_point_position();
102
103 switch(_output->info()->data_type())
104 {
105 case DataType::F32:
106 {
107 /* Up-conversion QS8 -> F32 */
108 execute_window_loop(window, [&](const Coordinates & id)
109 {
110 const int8x16_t texels_s8 = vld1q_s8(reinterpret_cast<const int8_t *>(input.ptr()));
111
112 float32x4x2_t texels_low = vcvt_f32_qs8(vget_low_s8(texels_s8), fixed_point_position);
113 float32x4x2_t texels_high = vcvt_f32_qs8(vget_high_s8(texels_s8), fixed_point_position);
114
115 vst1q_f32(reinterpret_cast<float *>(output.ptr()), texels_low.val[0]);
116 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4, texels_low.val[1]);
117 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8, texels_high.val[0]);
118 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12, texels_high.val[1]);
119 },
120 input, output);
121 break;
122 }
123 default:
124 ARM_COMPUTE_ERROR("Output data type not supported");
125 }
126 break;
127 }
128 case DataType::U8:
129 {
130 const int16x8_t b = vdupq_n_s16(_shift);
131
132 switch(_output->info()->data_type())
133 {
134 case DataType::S16:
135 {
136 /* Up-conversion U8 -> S16 */
137 execute_window_loop(window, [&](const Coordinates & id)
138 {
139 const uint8x16_t texels_u8 = vld1q_u8(input.ptr());
140
141 const int16x8x2_t texels =
142 {
143 {
144 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), b),
145 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))), b)
146 }
147 };
148
149 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), texels.val[0]);
150 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, texels.val[1]);
151 },
152 input, output);
153 break;
154 }
155 case DataType::S32:
156 {
157 /* Up-conversion U8 -> S32 */
158 execute_window_loop(window, [&](const Coordinates & id)
159 {
160 const uint8x16_t texels_u8 = vld1q_u8(input.ptr());
161
162 const int16x8x2_t texels =
163 {
164 {
165 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), b),
166 vshlq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))), b)
167 }
168 };
169
170 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), vmovl_s16(vget_low_s16(texels.val[0])));
171 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, vmovl_s16(vget_high_s16(texels.val[0])));
172 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 8, vmovl_s16(vget_low_s16(texels.val[1])));
173 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 12, vmovl_s16(vget_high_s16(texels.val[1])));
174 },
175 input, output);
176 break;
177 }
178 case DataType::U16:
179 {
180 /* Up-conversion U8 -> U16 */
181 execute_window_loop(window, [&](const Coordinates & id)
182 {
183 const uint8x16_t texels_u8 = vld1q_u8(input.ptr());
184
185 const uint16x8x2_t texels =
186 {
187 {
188 vshlq_u16(vmovl_u8(vget_low_u8(texels_u8)), b),
189 vshlq_u16(vmovl_u8(vget_high_u8(texels_u8)), b)
190 }
191 };
192
193 vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()), texels.val[0]);
194 vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()) + 8, texels.val[1]);
195 },
196 input, output);
197 break;
198 }
199 default:
200 ARM_COMPUTE_ERROR("Output data type not supported");
201 }
202 break;
203 }
204 case DataType::S16:
205 {
206 switch(_output->info()->data_type())
207 {
208 case DataType::U8:
209 {
210 const int16x8_t b = vdupq_n_s16(-static_cast<int16_t>(_shift));
211
212 /* Down-conversion S16 -> U8 */
213 if(ConvertPolicy::SATURATE == _policy)
214 {
215 execute_window_loop(window, [&](const Coordinates & id)
216 {
217 const int16x8x2_t texels =
218 {
219 {
220 vqshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr())), b),
221 vqshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr()) + 8), b)
222 }
223 };
224
225 vst1q_u8(output.ptr(), vcombine_u8(vqmovun_s16(texels.val[0]), vqmovun_s16(texels.val[1])));
226 },
227 input, output);
228 }
229 else
230 {
231 execute_window_loop(window, [&](const Coordinates & id)
232 {
233 const int16x8x2_t texels =
234 {
235 {
236 vshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr())), b),
237 vshlq_s16(vld1q_s16(reinterpret_cast<int16_t *>(input.ptr()) + 8), b)
238 }
239 };
240
241 vst1q_u8(output.ptr(), vcombine_u8(vmovn_u16(vreinterpretq_u16_s16(texels.val[0])),
242 vmovn_u16(vreinterpretq_u16_s16(texels.val[1]))));
243 },
244 input, output);
245 }
246 break;
247 }
248 case DataType::S32:
249 {
250 const int32x4_t b = vdupq_n_s32(_shift);
251
252 /* Up-conversion S16 -> S32 */
253 execute_window_loop(window, [&](const Coordinates & id)
254 {
255 const int16x8x2_t texels =
256 {
257 {
258 vld1q_s16(reinterpret_cast<int16_t *>(input.ptr())),
259 vld1q_s16(reinterpret_cast<int16_t *>(input.ptr()) + 8)
260 }
261 };
262
263 const int32x4x4_t texels_s32 =
264 {
265 {
266 vshlq_s32(vmovl_s16(vget_low_s16(texels.val[0])), b),
267 vshlq_s32(vmovl_s16(vget_high_s16(texels.val[0])), b),
268 vshlq_s32(vmovl_s16(vget_low_s16(texels.val[1])), b),
269 vshlq_s32(vmovl_s16(vget_high_s16(texels.val[1])), b)
270 }
271 };
272
273 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), texels_s32.val[0]);
274 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, texels_s32.val[1]);
275 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 8, texels_s32.val[2]);
276 vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 12, texels_s32.val[3]);
277 },
278 input, output);
279 break;
280 }
281 default:
282 ARM_COMPUTE_ERROR("Output data type not supported");
283 }
284 break;
285 }
286 case DataType::U16:
287 {
288 switch(_output->info()->data_type())
289 {
290 case DataType::U8:
291 {
292 const int16x8_t b = vdupq_n_s16(-static_cast<int16_t>(_shift));
293
294 /* Down-conversion U16 -> U8 */
295 if(ConvertPolicy::SATURATE == _policy)
296 {
297 execute_window_loop(window, [&](const Coordinates & id)
298 {
299 const uint16x8x2_t texels =
300 {
301 {
302 vqshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())), b),
303 vqshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8), b)
304 }
305 };
306
307 vst1q_u8(output.ptr(), vcombine_u8(vqmovn_u16(texels.val[0]), vqmovn_u16(texels.val[1])));
308 },
309 input, output);
310 }
311 else
312 {
313 execute_window_loop(window, [&](const Coordinates & id)
314 {
315 const uint16x8x2_t texels =
316 {
317 {
318 vshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())), b),
319 vshlq_u16(vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8), b)
320 }
321 };
322
323 vst1q_u8(output.ptr(), vcombine_u8(vmovn_u16(texels.val[0]), vmovn_u16(texels.val[1])));
324 },
325 input, output);
326 }
327 break;
328 }
329 case DataType::U32:
330 {
331 const int32x4_t b = vdupq_n_s32(_shift);
332
333 /* Up-conversion U16 -> U32 */
334 execute_window_loop(window, [&](const Coordinates & id)
335 {
336 const uint16x8x2_t texels =
337 {
338 {
339 vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())),
340 vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8)
341 }
342 };
343
344 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()), vshlq_u32(vmovl_u16(vget_low_u16(texels.val[0])), b));
345 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()) + 4, vshlq_u32(vmovl_u16(vget_high_u16(texels.val[0])), b));
346 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vshlq_u32(vmovl_u16(vget_low_u16(texels.val[1])), b));
347 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vshlq_u32(vmovl_u16(vget_high_u16(texels.val[1])), b));
348 },
349 input, output);
350 break;
351 }
352 default:
353 ARM_COMPUTE_ERROR("Output data type not supported");
354 }
355 break;
356 }
Georgios Pinitas21efeb42017-07-04 12:47:17 +0100357 case DataType::QS16:
358 {
359 const int fixed_point_position = _input->info()->fixed_point_position();
360
361 switch(_output->info()->data_type())
362 {
363 case DataType::F32:
364 {
365 /* Up-conversion QS16 -> F32 */
366 execute_window_loop(window, [&](const Coordinates & id)
367 {
368 const int16x8x2_t texels =
369 {
370 {
371 vld1q_s16(reinterpret_cast<qint16_t *>(input.ptr())),
372 vld1q_s16(reinterpret_cast<qint16_t *>(input.ptr()) + 8)
373 }
374 };
375
376 vst1q_f32(reinterpret_cast<float *>(output.ptr()), vcvt_f32_qs16(vget_low_s16(texels.val[0]), fixed_point_position));
377 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4, vcvt_f32_qs16(vget_high_s16(texels.val[0]), fixed_point_position));
378 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8, vcvt_f32_qs16(vget_low_s16(texels.val[1]), fixed_point_position));
379 vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12, vcvt_f32_qs16(vget_high_s16(texels.val[1]), fixed_point_position));
380 },
381 input, output);
382 break;
383 }
384 default:
385 ARM_COMPUTE_ERROR("Output data type not supported");
386 }
387 break;
388 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100389 case DataType::F32:
390 {
391 switch(_output->info()->data_type())
392 {
393 case DataType::QS8:
394 {
395 const int fixed_point_position = _output->info()->fixed_point_position();
396 /* Down-conversion F32 -> QS8 */
397 execute_window_loop(window, [&](const Coordinates & id)
398 {
399 const float32x4x4_t texels_f32 =
400 {
401 {
402 vld1q_f32(reinterpret_cast<const float *>(input.ptr())),
403 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 4),
404 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 8),
405 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 12)
406 }
407 };
408
Georgios Pinitas21efeb42017-07-04 12:47:17 +0100409 const qint8x16_t texels_s8 = vqcvtq_qs8_f32(texels_f32, fixed_point_position);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100410
411 vst1q_s8(reinterpret_cast<int8_t *>(output.ptr()), texels_s8);
412 },
413 input, output);
414 break;
415 }
Georgios Pinitas21efeb42017-07-04 12:47:17 +0100416 case DataType::QS16:
417 {
418 const int fixed_point_position = _output->info()->fixed_point_position();
419 /* Down-conversion F32 -> QS16 */
420 execute_window_loop(window, [&](const Coordinates & id)
421 {
422 const float32x4x2_t texels_f32_1 =
423 {
424 {
425 vld1q_f32(reinterpret_cast<const float *>(input.ptr())),
426 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 4),
427 }
428 };
429 const float32x4x2_t texels_f32_2 =
430 {
431 {
432 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 8),
433 vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 12)
434 }
435 };
436
437 vst1q_s16(reinterpret_cast<qint16_t *>(output.ptr()), vqcvtq_qs16_f32(texels_f32_1, fixed_point_position));
438 vst1q_s16(reinterpret_cast<qint16_t *>(output.ptr()) + 8, vqcvtq_qs16_f32(texels_f32_2, fixed_point_position));
439 },
440 input, output);
441 break;
442 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100443 default:
444 ARM_COMPUTE_ERROR("Output data type not supported");
445 }
446 break;
447 }
448 default:
449 ARM_COMPUTE_ERROR("Not supported");
450 }
451}