blob: 915ea754318674091717acf76a1c3af9fc54e51c [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h"
25
26#include "arm_compute/core/Error.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/ITensor.h"
Gian Marco Iodice13edbff2017-06-26 17:20:16 +010029#include "arm_compute/core/Size2D.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Types.h"
32#include "arm_compute/core/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/CPP/Validate.h"
34#include "src/core/helpers/AutoConfiguration.h"
35#include "src/core/helpers/WindowHelpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
38
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039#include <arm_neon.h>
40#include <cstddef>
41#include <cstdint>
42#include <cstring>
43#include <tuple>
44
45using namespace arm_compute;
Giorgio Arena368e6352018-08-20 15:06:07 +010046using namespace misc::shape_calculator;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047
48namespace
49{
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000050Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
Giorgio Arena368e6352018-08-20 15:06:07 +010051 bool has_bias, const Size2D &dilation, unsigned int num_groups)
Georgios Pinitasd912fd82017-11-27 21:00:13 +000052{
Anthony Barbiereaefd002018-07-20 17:49:35 +010053 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Georgios Pinitasc7b183a2020-03-06 18:12:09 +000054 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
Georgios Pinitas6e1791b2019-12-02 19:01:25 +000055 ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized(input->data_type()) && has_bias);
Alex Gilday7da29b62018-03-23 14:16:00 +000056 ARM_COMPUTE_RETURN_ERROR_ON((dilation.x() < 1) || (dilation.y() < 1));
Giorgio Arena0f170392018-07-18 16:13:12 +010057 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Number of groups greater than one are not supported on NEON");
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000058
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +010059 if(output->total_size() > 0)
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +000060 {
Giorgio Arena368e6352018-08-20 15:06:07 +010061 TensorInfo expected_output = output->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false));
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +010062 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
63 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Isabella Gottardi0a1090a2019-02-14 18:07:36 +000064 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +010065 }
Giorgio Arena156fcf32018-03-09 15:30:43 +000066
Georgios Pinitas631c41a2017-12-06 11:53:03 +000067 return Status{};
Georgios Pinitasd912fd82017-11-27 21:00:13 +000068}
69
Giorgio Arena368e6352018-08-20 15:06:07 +010070std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
71 bool has_bias, const Size2D &dilation)
72{
73 const unsigned int width_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
74 const unsigned int height_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
75 const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
76
77 std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(input->dimension(width_idx), input->dimension(height_idx),
78 kernel_dims.width, kernel_dims.height,
79 conv_info, dilation);
80
81 // Output tensor auto initialization if not yet initialized
82 auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false)));
83
84 Window win = calculate_max_window(*input, Steps());
85 win.set(width_idx, Window::Dimension(0, convolved_dims.first, 1));
86 win.set(height_idx, Window::Dimension(0, convolved_dims.second, 1));
87 win.set(channel_idx, Window::Dimension(0, 1, 1));
88
89 // The NEIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
90 output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
91
92 return std::make_pair(Status{}, win);
93}
94
/** Linearise one NCHW input patch into a contiguous im2col output row.
 *
 * The output row holds kernel_depth consecutive planes of kernel_width * kernel_height samples. Each
 * plane is written row by row, then column by column. If @p has_bias is set, a single 1 follows the
 * last plane. When @p has_pads is true, samples outside the input are written as @p pad_value. When it
 * is false, no bounds checks are made, so every sampled coordinate must lie inside the input.
 *
 * @tparam T        Element type of the input and output.
 * @tparam has_pads Whether out-of-bounds samples can occur and must be replaced by @p pad_value.
 *
 * @param[in]  in_ptr         Address of input element (x=0, y=0, channel=0); all offsets are relative to it.
 * @param[out] out_ptr        Destination of the output row.
 * @param[in]  has_bias       If true, a 1 is written after the last patch value.
 * @param[in]  top_left_x     X coordinate of the first sample (may be negative when padding is present).
 * @param[in]  top_left_y     Y coordinate of the first sample (may be negative when padding is present).
 * @param[in]  kernel_width   Number of samples per kernel row.
 * @param[in]  kernel_height  Number of kernel rows.
 * @param[in]  kernel_depth   Number of input channels to linearise.
 * @param[in]  input_w        Input width in elements.
 * @param[in]  input_h        Input height in elements.
 * @param[in]  input_stride_x Byte distance between horizontally adjacent elements.
 * @param[in]  input_stride_y Byte distance between consecutive rows.
 * @param[in]  input_stride_z Byte distance between consecutive channels.
 * @param[in]  pad_value      Value written for out-of-bounds samples.
 * @param[in]  dilation_x     Horizontal distance between kernel samples.
 * @param[in]  dilation_y     Vertical distance between kernel samples.
 */
template <typename T, bool has_pads>
inline void linearize_volume_nchw(const uint8_t *const in_ptr,
                                  T                   *out_ptr,
                                  bool                 has_bias,
                                  int                  top_left_x,
                                  int                  top_left_y,
                                  int                  kernel_width,
                                  int                  kernel_height,
                                  int                  kernel_depth,
                                  int                  input_w,
                                  int                  input_h,
                                  int                  input_stride_x,
                                  int                  input_stride_y,
                                  int                  input_stride_z,
                                  int                  pad_value,
                                  int                  dilation_x,
                                  int                  dilation_y)
{
    const int plane_len = kernel_width * kernel_height;
    const int x_stop    = top_left_x + kernel_width * dilation_x;
    const int y_stop    = top_left_y + kernel_height * dilation_y;

    // True when coordinate c lies outside [0, limit) and padding can occur at all
    const auto outside = [](int c, int limit)
    {
        return has_pads && (c < 0 || c >= limit);
    };
    // Reference to input element (channel, y, x)
    const auto load = [&](int channel, int y, int x) -> const T &
    {
        return *reinterpret_cast<const T *>(in_ptr + (channel * input_stride_z + y * input_stride_y + x * input_stride_x));
    };

    int channel = 0;

    // Three channels per iteration. The sample positions are walked once and reused for the three
    // output planes, which suits the common first-layer case of three input feature maps.
    for(; channel + 3 <= kernel_depth; channel += 3)
    {
        for(int y = top_left_y; y < y_stop; y += dilation_y)
        {
            const bool row_outside = outside(y, input_h);
            for(int x = top_left_x; x < x_stop; x += dilation_x, ++out_ptr)
            {
                if(row_outside || outside(x, input_w))
                {
                    // The caller passes the quantization offset for quantized data and 0 otherwise
                    out_ptr[0]             = pad_value;
                    out_ptr[plane_len]     = pad_value;
                    out_ptr[2 * plane_len] = pad_value;
                }
                else
                {
                    out_ptr[0]             = load(channel, y, x);
                    out_ptr[plane_len]     = load(channel + 1, y, x);
                    out_ptr[2 * plane_len] = load(channel + 2, y, x);
                }
            }
        }
        // out_ptr only advanced across the first plane; skip the two planes filled alongside it
        out_ptr += 2 * plane_len;
    }

    // Remaining channels, one plane at a time
    for(; channel < kernel_depth; ++channel)
    {
        for(int y = top_left_y; y < y_stop; y += dilation_y)
        {
            if(outside(y, input_h))
            {
                // Whole kernel row is padding. NOTE: memset fills bytes, so this equals pad_value per
                // element only when T is one byte wide or pad_value is 0.
                memset(static_cast<void *>(out_ptr), pad_value, kernel_width * sizeof(T));
                out_ptr += kernel_width;
                continue;
            }
            for(int x = top_left_x; x < x_stop; x += dilation_x, ++out_ptr)
            {
                if(outside(x, input_w))
                {
                    *out_ptr = pad_value;
                }
                else
                {
                    *out_ptr = load(channel, y, x);
                }
            }
        }
    }

    // Append 1 if the convolution layer has biases
    if(has_bias)
    {
        *out_ptr = static_cast<T>(1);
    }
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100192
/** Linearise one NHWC input patch into a contiguous im2col output row.
 *
 * For each sampled (y, x) kernel position, all input_c channel values are copied. The output row is
 * therefore ordered by kernel row, then kernel column, then channel. If @p has_bias is set, a single 1
 * follows. Samples outside the input are filled byte-wise with @p pad_value using memset.
 *
 * @tparam T        Element type of the input and output.
 * @tparam has_pads Not used by this implementation: bounds are always checked.
 *
 * @param[in]  in_ptr         Address of input element (channel=0, x=0, y=0); all offsets are relative to it.
 * @param[out] out_ptr        Destination of the output row.
 * @param[in]  has_bias       If true, a 1 is written after the last patch value.
 * @param[in]  start_x        X coordinate of the first sample (may be negative when padding is present).
 * @param[in]  start_y        Y coordinate of the first sample (may be negative when padding is present).
 * @param[in]  kernel_width   Number of samples per kernel row.
 * @param[in]  kernel_height  Number of kernel rows.
 * @param[in]  input_w        Input width in elements.
 * @param[in]  input_h        Input height in elements.
 * @param[in]  input_c        Number of channels copied per sample.
 * @param[in]  input_stride_y Byte distance between horizontally adjacent pixels.
 * @param[in]  input_stride_z Byte distance between consecutive rows.
 * @param[in]  pad_value      Byte value memset into out-of-bounds samples.
 * @param[in]  dilation_x     Horizontal distance between kernel samples.
 * @param[in]  dilation_y     Vertical distance between kernel samples.
 */
template <typename T, bool has_pads>
inline void linearize_volume_nhwc(const uint8_t *const in_ptr,
                                  T                   *out_ptr,
                                  bool                 has_bias,
                                  int                  start_x,
                                  int                  start_y,
                                  int                  kernel_width,
                                  int                  kernel_height,
                                  int                  input_w,
                                  int                  input_h,
                                  int                  input_c,
                                  int                  input_stride_y,
                                  int                  input_stride_z,
                                  int                  pad_value,
                                  int                  dilation_x,
                                  int                  dilation_y)
{
    const int end_x        = start_x + kernel_width * dilation_x; // one past the last sampled column
    const int end_y        = start_y + kernel_height * dilation_y;
    const int row_elems    = kernel_width * input_c; // output elements produced per kernel row
    const int element_size = static_cast<int>(sizeof(T));

    // A kernel row is one contiguous run of input memory when all of the following hold:
    // there is no horizontal dilation, pixels are densely packed along x, and every sampled x lies
    // in [0, input_w). Because end_x is exclusive, a window ending exactly on the right edge
    // (end_x == input_w) still qualifies for the single-memcpy path.
    const bool row_is_contiguous = (dilation_x == 1) && (input_stride_y == input_c * element_size) && (start_x >= 0) && (end_x <= input_w);

    for(int y = start_y; y < end_y; y += dilation_y)
    {
        if(y < 0 || y >= input_h)
        {
            // The whole kernel row falls outside the input
            memset(static_cast<void *>(out_ptr), pad_value, row_elems * element_size);
            out_ptr += row_elems;
        }
        else if(row_is_contiguous)
        {
            memcpy(out_ptr, reinterpret_cast<const T *>(in_ptr + (y * input_stride_z + start_x * input_stride_y)), row_elems * element_size);
            out_ptr += row_elems;
        }
        else
        {
            // Pixel by pixel: dilated, partially outside horizontally, or not densely packed
            for(int x = start_x; x < end_x; x += dilation_x)
            {
                if(x < 0 || x >= input_w)
                {
                    memset(static_cast<void *>(out_ptr), pad_value, input_c * element_size);
                }
                else
                {
                    memcpy(out_ptr, reinterpret_cast<const T *>(in_ptr + (y * input_stride_z + x * input_stride_y)), input_c * element_size);
                }
                out_ptr += input_c;
            }
        }
    }

    // Append 1 if the convolution layer has biases
    if(has_bias)
    {
        *out_ptr = static_cast<T>(1);
    }
}
262} // namespace
263
264template <typename T, bool has_pads, bool is_nchw>
Giorgio Arena368e6352018-08-20 15:06:07 +0100265void NEIm2ColKernel::run_im2col(const Window &window)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100266{
267 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
268 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
269
Georgios Pinitas329b4d62020-01-15 17:48:20 +0000270 const unsigned int width_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
271 const unsigned int height_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
272 const unsigned int channel_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL);
Giorgio Arena156fcf32018-03-09 15:30:43 +0000273
Giorgio Arena156fcf32018-03-09 15:30:43 +0000274 const int input_w = _input->info()->dimension(width_idx);
275 const int input_h = _input->info()->dimension(height_idx);
Giorgio Arenafb629082018-08-20 18:03:27 +0100276 const int input_c = _input->info()->dimension(channel_idx);
277 const int input_stride_x = _input->info()->strides_in_bytes().x();
278 const int input_stride_y = _input->info()->strides_in_bytes().y();
279 const int input_stride_z = _input->info()->strides_in_bytes().z();
280 const int pad_left = _conv_info.pad_left();
281 const int pad_top = _conv_info.pad_top();
282 const int stride_x = _conv_info.stride().first;
283 const int stride_y = _conv_info.stride().second;
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100284 const int pad_value = is_data_type_quantized(_input->info()->data_type()) ? _input->info()->quantization_info().uniform().offset : 0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100285
Giorgio Arenaf485a102018-04-20 16:06:21 +0100286 Window window_in_out(window);
287 // The first three dimensions of the input and output are increased by the inner loops
288 window_in_out.set(Window::DimX, Window::Dimension(0, 0, 0));
289 window_in_out.set(Window::DimY, Window::Dimension(0, 0, 0));
290 window_in_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100291
292 // Create iterators
Giorgio Arenaf485a102018-04-20 16:06:21 +0100293 Iterator in(_input, window_in_out);
294 Iterator out(_output, window_in_out);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100295
296 execute_window_loop(window, [&](const Coordinates & id)
297 {
Giorgio Arenafb629082018-08-20 18:03:27 +0100298 const int start_w = id[width_idx] * stride_x - pad_left;
299 const int start_h = id[height_idx] * stride_y - pad_top;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100300
301 // Get pointers
302 const uint8_t *const input_ptr = in.ptr();
Giorgio Arenaf485a102018-04-20 16:06:21 +0100303 auto output_ptr = reinterpret_cast<T *>(out.ptr() + (id[width_idx] + id[height_idx] * _convolved_dims.first) * _output->info()->strides_in_bytes().y());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100304
305 // Linearize volume
Giorgio Arenafb629082018-08-20 18:03:27 +0100306 if(is_nchw)
307 {
308 linearize_volume_nchw<T, has_pads>(input_ptr,
309 output_ptr,
310 _has_bias,
311 start_w,
312 start_h,
313 _kernel_width,
314 _kernel_height,
315 input_c,
316 input_w,
317 input_h,
318 input_stride_x,
319 input_stride_y,
320 input_stride_z,
321 pad_value,
322 _dilation.x(),
323 _dilation.y());
324 }
325 else
326 {
327 linearize_volume_nhwc<T, has_pads>(input_ptr,
328 output_ptr,
329 _has_bias,
330 start_w,
331 start_h,
332 _kernel_width,
333 _kernel_height,
334 input_w,
335 input_h,
336 input_c,
337 input_stride_y,
338 input_stride_z,
339 pad_value,
340 _dilation.x(),
341 _dilation.y());
342 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100343 },
344 in, out);
345}
346
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100347NEIm2ColKernel::NEIm2ColKernel()
Georgios Pinitas329b4d62020-01-15 17:48:20 +0000348 : _func(), _input(nullptr), _output(nullptr), _convolved_dims(), _conv_info(), _kernel_width(0), _kernel_height(0), _has_bias(false), _dilation(1U, 1U), _data_layout(DataLayout::UNKNOWN)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100349{
350}
351
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +0000352void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
Giorgio Arena368e6352018-08-20 15:06:07 +0100353 bool has_bias, const Size2D &dilation, unsigned int num_groups)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100354{
Georgios Pinitasd912fd82017-11-27 21:00:13 +0000355 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Giorgio Arena368e6352018-08-20 15:06:07 +0100356 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, num_groups));
Giorgio Arena0f170392018-07-18 16:13:12 +0100357 ARM_COMPUTE_UNUSED(num_groups);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100358
Georgios Pinitas329b4d62020-01-15 17:48:20 +0000359 _data_layout = input->info()->data_layout();
360 const unsigned int width_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
361 const unsigned int height_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
Giorgio Arena156fcf32018-03-09 15:30:43 +0000362
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100363 _input = input;
364 _output = output;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100365 _conv_info = conv_info;
Gian Marco Iodice13edbff2017-06-26 17:20:16 +0100366 _kernel_width = kernel_dims.width;
Alex Gilday7da29b62018-03-23 14:16:00 +0000367 _kernel_height = kernel_dims.height;
368 _dilation = dilation;
Giorgio Arena156fcf32018-03-09 15:30:43 +0000369 _convolved_dims = scaled_dimensions(input->info()->dimension(width_idx), input->info()->dimension(height_idx),
Gian Marco Iodice13edbff2017-06-26 17:20:16 +0100370 _kernel_width, _kernel_height,
Alex Gilday7da29b62018-03-23 14:16:00 +0000371 _conv_info, _dilation);
Gian Marco Iodice13edbff2017-06-26 17:20:16 +0100372 _has_bias = has_bias;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100373
Georgios Pinitas329b4d62020-01-15 17:48:20 +0000374 if(_data_layout == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100375 {
Giorgio Arenafb629082018-08-20 18:03:27 +0100376 switch(_input->info()->data_type())
377 {
378 case DataType::F32:
379 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float, false, true> : &NEIm2ColKernel::run_im2col<float, true, true>;
380 break;
Georgios Pinitasc7b183a2020-03-06 18:12:09 +0000381#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
382 case DataType::BFLOAT16:
383 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<bfloat16, false, true> : &NEIm2ColKernel::run_im2col<bfloat16, true, true>;
384 break;
385#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000386#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenafb629082018-08-20 18:03:27 +0100387 case DataType::F16:
388 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float16_t, false, true> : &NEIm2ColKernel::run_im2col<float16_t, true, true>;
389 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000390#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Georgios Pinitas6e1791b2019-12-02 19:01:25 +0000391 case DataType::QASYMM8_SIGNED:
Giorgio Arenafb629082018-08-20 18:03:27 +0100392 case DataType::QASYMM8:
393 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<qasymm8_t, false, true> : &NEIm2ColKernel::run_im2col<qasymm8_t, true, true>;
394 break;
395 default:
396 ARM_COMPUTE_ERROR("Data type not supported");
397 break;
398 }
399 }
400 else
401 {
402 switch(_input->info()->data_type())
403 {
404 case DataType::F32:
405 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float, false, false> : &NEIm2ColKernel::run_im2col<float, true, false>;
406 break;
Georgios Pinitasc7b183a2020-03-06 18:12:09 +0000407#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
408 case DataType::BFLOAT16:
409 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<bfloat16, false, false> : &NEIm2ColKernel::run_im2col<bfloat16, true, false>;
410 break;
411#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
Giorgio Arenafb629082018-08-20 18:03:27 +0100412#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
413 case DataType::F16:
414 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float16_t, false, false> : &NEIm2ColKernel::run_im2col<float16_t, true, false>;
415 break;
416#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
417 case DataType::QASYMM8:
Georgios Pinitas6e1791b2019-12-02 19:01:25 +0000418 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<uint8_t, false, false> : &NEIm2ColKernel::run_im2col<qasymm8_t, true, false>;
419 break;
420 case DataType::QASYMM8_SIGNED:
421 _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<int8_t, false, false> : &NEIm2ColKernel::run_im2col<qasymm8_t, true, false>;
Giorgio Arenafb629082018-08-20 18:03:27 +0100422 break;
423 default:
424 ARM_COMPUTE_ERROR("Data type not supported");
425 break;
426 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100427 }
428
Giorgio Arena368e6352018-08-20 15:06:07 +0100429 // Configure kernel window
430 auto win_config = validate_and_configure_window(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation);
431 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
432 INEKernel::configure(win_config.second);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100433}
434
Ioan-Cristian Szabob4e3e1c2017-11-30 17:17:17 +0000435Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
Giorgio Arena368e6352018-08-20 15:06:07 +0100436 bool has_bias, const Size2D &dilation, unsigned int num_groups)
Georgios Pinitasd912fd82017-11-27 21:00:13 +0000437{
Giorgio Arena368e6352018-08-20 15:06:07 +0100438 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups));
439 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), kernel_dims, conv_info, has_bias, dilation).first);
Georgios Pinitas631c41a2017-12-06 11:53:03 +0000440 return Status{};
Georgios Pinitasd912fd82017-11-27 21:00:13 +0000441}
442
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100443void NEIm2ColKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100444{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100445 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100446 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
447 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
448
449 (this->*_func)(window);
450}