blob: 44becf5a4b1e6df9a3c861153107378411a35efa [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NETransposeKernel.h"
25
Gian Marco5420b282017-11-29 10:41:38 +000026#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
Gian Marco5420b282017-11-29 10:41:38 +000030#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/Validate.h"
32
33#include <arm_neon.h>
34
35using namespace arm_compute;
36
37namespace arm_compute
38{
39class Coordinates;
40} // namespace arm_compute
41
42namespace
43{
44void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &window)
45{
46 Window window_out(window);
47 window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
48 window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
49
50 Iterator input(in, window);
51 Iterator output(out, window_out);
52
53 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
54 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
55
56 execute_window_loop(window, [&](const Coordinates & id)
57 {
58 const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 0 * input_stride_in_bytes));
59 const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 1 * input_stride_in_bytes));
60 const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 2 * input_stride_in_bytes));
61 const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 3 * input_stride_in_bytes));
62 const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 4 * input_stride_in_bytes));
63 const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 5 * input_stride_in_bytes));
64 const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 6 * input_stride_in_bytes));
65 const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 7 * input_stride_in_bytes));
66
67 // Transpose 2x2
68 const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
69 const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
70 const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
71 const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
72
73 // Transpose 4x4
74 const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
75 const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
76 const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
77 const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
78
79 // Transpose 8x8
80 const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
81 const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
82 const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
83 const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
84
85 // Compute destination address
86 const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
87
88 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
89 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
90 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
91 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
92 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
93 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
94 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
95 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
96 },
97 input, output);
98}
99
100void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &window)
101{
102 Window window_out(window);
103 window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
104 window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
105
106 Iterator input(in, window);
107 Iterator output(out, window_out);
108
109 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
110 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
111
112 execute_window_loop(window, [&](const Coordinates & id)
113 {
114 const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes));
115 const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes));
116 const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes));
117 const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes));
118
119 // Transpose 2x2
120 const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
121 const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
122
123 // Transpose 4x4
124 const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
125 const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
126
127 // Compute destination address
128 const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
129
130 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0]));
131 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0]));
132 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1]));
133 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1]));
134 },
135 input, output);
136}
137
138void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
139{
140 Window window_out(window);
141 window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
142 window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
143
144 Iterator input(in, window);
145 Iterator output(out, window_out);
146
147 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
148 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
149
150 execute_window_loop(window, [&](const Coordinates & id)
151 {
152 const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes));
153 const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes));
154 const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes));
155 const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes));
156
157 // Transpose 2x2
158 const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
159 const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
160 const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
161 const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
162
163 // Compute destination address
164 const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
165
166 // Swap block 01 with block 10 and store
167 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
168 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
169 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
170 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
171 },
172 input, output);
173}
174} // namespace
175
176NETransposeKernel::NETransposeKernel()
177 : _func(nullptr), _input(nullptr), _output(nullptr)
178{
179}
180
181void NETransposeKernel::configure(const ITensor *input, ITensor *output)
182{
Gian Marco Iodice2bbd9642017-07-04 16:46:32 +0100183 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::U16, DataType::S16, DataType::QS16, DataType::U32, DataType::S32, DataType::F16,
184 DataType::F32);
Gian Marco Iodiceec8b45e2017-06-22 13:00:39 +0100185 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100186
187 TensorShape output_shape{ input->info()->tensor_shape() };
188 const size_t w_out = input->info()->dimension(1);
189 const size_t h_out = input->info()->dimension(0);
190 output_shape.set(0, w_out);
191 output_shape.set(1, h_out);
192
193 // Output tensor auto inizialitation if not yet initialized
194 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
195
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100196 ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
Gian Marco Iodiceec8b45e2017-06-22 13:00:39 +0100197 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
198 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100199
200 _input = input;
201 _output = output;
202
203 unsigned int num_elems_processed_per_iteration = 0;
204
205 switch(input->info()->element_size())
206 {
207 case 1:
208 _func = &transpose_8bit_elements;
209 num_elems_processed_per_iteration = 8;
210 break;
211 case 2:
212 _func = &transpose_16bit_elements;
213 num_elems_processed_per_iteration = 4;
214 break;
215 case 4:
216 _func = &transpose_32bit_elements;
217 num_elems_processed_per_iteration = 4;
218 break;
219 default:
220 ARM_COMPUTE_ERROR("Element size not supported");
221 break;
222 }
223
224 // Configure kernel window
Gian Marco5420b282017-11-29 10:41:38 +0000225 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration, num_elems_processed_per_iteration));
226
227 // TODO (COMPMID-708): Replace AccessWindowStatic with AccessWindowTranspose
228 AccessWindowStatic output_access(output->info(), 0, 0, ceil_to_multiple(output->info()->dimension(0), num_elems_processed_per_iteration), ceil_to_multiple(output->info()->dimension(1),
229 num_elems_processed_per_iteration));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100230
231 update_window_and_padding(win,
232 AccessWindowRectangle(input->info(), 0, 0, num_elems_processed_per_iteration, num_elems_processed_per_iteration),
233 output_access);
234
235 output_access.set_valid_region(win, input->info()->valid_region());
236
237 INEKernel::configure(win);
238}
239
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100240void NETransposeKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100241{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100242 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100243 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
244 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
245 ARM_COMPUTE_ERROR_ON(_func == nullptr);
246
247 (*_func)(_input, _output, window);
248}