blob: a990e9068e75fe73f3dca901211511cda6677559 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NETransposeKernel.h"
25
26#include "arm_compute/core/AccessWindowTranspose.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/Validate.h"
31
32#include <arm_neon.h>
33
34using namespace arm_compute;
35
36namespace arm_compute
37{
38class Coordinates;
39} // namespace arm_compute
40
41namespace
42{
43void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &window)
44{
45 Window window_out(window);
46 window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
47 window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
48
49 Iterator input(in, window);
50 Iterator output(out, window_out);
51
52 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
53 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
54
55 execute_window_loop(window, [&](const Coordinates & id)
56 {
57 const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 0 * input_stride_in_bytes));
58 const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 1 * input_stride_in_bytes));
59 const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 2 * input_stride_in_bytes));
60 const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 3 * input_stride_in_bytes));
61 const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 4 * input_stride_in_bytes));
62 const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 5 * input_stride_in_bytes));
63 const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 6 * input_stride_in_bytes));
64 const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + 7 * input_stride_in_bytes));
65
66 // Transpose 2x2
67 const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
68 const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
69 const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
70 const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
71
72 // Transpose 4x4
73 const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
74 const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
75 const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
76 const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
77
78 // Transpose 8x8
79 const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
80 const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
81 const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
82 const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
83
84 // Compute destination address
85 const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
86
87 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
88 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
89 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
90 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
91 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
92 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
93 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
94 vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
95 },
96 input, output);
97}
98
99void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &window)
100{
101 Window window_out(window);
102 window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
103 window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
104
105 Iterator input(in, window);
106 Iterator output(out, window_out);
107
108 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
109 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
110
111 execute_window_loop(window, [&](const Coordinates & id)
112 {
113 const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes));
114 const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes));
115 const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes));
116 const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes));
117
118 // Transpose 2x2
119 const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
120 const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
121
122 // Transpose 4x4
123 const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
124 const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
125
126 // Compute destination address
127 const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
128
129 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0]));
130 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0]));
131 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1]));
132 vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1]));
133 },
134 input, output);
135}
136
137void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
138{
139 Window window_out(window);
140 window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
141 window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
142
143 Iterator input(in, window);
144 Iterator output(out, window_out);
145
146 const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
147 const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
148
149 execute_window_loop(window, [&](const Coordinates & id)
150 {
151 const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes));
152 const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes));
153 const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes));
154 const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes));
155
156 // Transpose 2x2
157 const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
158 const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
159 const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
160 const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
161
162 // Compute destination address
163 const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
164
165 // Swap block 01 with block 10 and store
166 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
167 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
168 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
169 vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
170 },
171 input, output);
172}
173} // namespace
174
175NETransposeKernel::NETransposeKernel()
176 : _func(nullptr), _input(nullptr), _output(nullptr)
177{
178}
179
180void NETransposeKernel::configure(const ITensor *input, ITensor *output)
181{
182 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16, DataType::F32);
Gian Marco Iodiceec8b45e2017-06-22 13:00:39 +0100183 ARM_COMPUTE_ERROR_ON_NULLPTR(output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100184
185 TensorShape output_shape{ input->info()->tensor_shape() };
186 const size_t w_out = input->info()->dimension(1);
187 const size_t h_out = input->info()->dimension(0);
188 output_shape.set(0, w_out);
189 output_shape.set(1, h_out);
190
191 // Output tensor auto inizialitation if not yet initialized
192 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
193
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100194 ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
Gian Marco Iodiceec8b45e2017-06-22 13:00:39 +0100195 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
196 ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100197
198 _input = input;
199 _output = output;
200
201 unsigned int num_elems_processed_per_iteration = 0;
202
203 switch(input->info()->element_size())
204 {
205 case 1:
206 _func = &transpose_8bit_elements;
207 num_elems_processed_per_iteration = 8;
208 break;
209 case 2:
210 _func = &transpose_16bit_elements;
211 num_elems_processed_per_iteration = 4;
212 break;
213 case 4:
214 _func = &transpose_32bit_elements;
215 num_elems_processed_per_iteration = 4;
216 break;
217 default:
218 ARM_COMPUTE_ERROR("Element size not supported");
219 break;
220 }
221
222 // Configure kernel window
223 Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration, num_elems_processed_per_iteration));
224 AccessWindowTranspose output_access(output->info(), 0, 0, num_elems_processed_per_iteration, num_elems_processed_per_iteration);
225
226 update_window_and_padding(win,
227 AccessWindowRectangle(input->info(), 0, 0, num_elems_processed_per_iteration, num_elems_processed_per_iteration),
228 output_access);
229
230 output_access.set_valid_region(win, input->info()->valid_region());
231
232 INEKernel::configure(win);
233}
234
235void NETransposeKernel::run(const Window &window)
236{
237 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
238 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
239 ARM_COMPUTE_ERROR_ON(_func == nullptr);
240
241 (*_func)(_input, _output, window);
242}