blob: 2c9ad923aae882c9c1125e217ca02d3c4f182ee3 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Isabella Gottardie6630e42018-01-18 15:50:39 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h"
25
26#include "arm_compute/core/Dimensions.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/Types.h"
31#include "arm_compute/core/Validate.h"
32
33using namespace arm_compute;
34
35namespace
36{
Michalis Spyroue2503892018-04-23 15:17:31 +010037template <typename T, bool is_nhwc>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window)
39{
Michalis Spyroue2503892018-04-23 15:17:31 +010040 DataLayout data_layout = input->info()->data_layout();
41 const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
42 const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
43 const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
44 const unsigned int kernel_size_x = input->info()->dimension(idx_width);
45 const unsigned int kernel_size_y = input->info()->dimension(idx_height);
46 const unsigned int kernel_depth = input->info()->dimension(idx_channel);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047 const unsigned int input_stride_x = input->info()->strides_in_bytes().x();
48 const unsigned int input_stride_y = input->info()->strides_in_bytes().y();
49 const unsigned int input_stride_z = input->info()->strides_in_bytes().z();
50 const unsigned int output_stride_y = output->info()->strides_in_bytes().y();
51
52 // Create iterators
53 Iterator in(input, window);
54 execute_window_loop(window, [&](const Coordinates & id)
55 {
56 // Get column index
57 const int kernel_idx = id[3];
58 const int kernel_idz = id[4];
59
60 // Setup pointers
61 const uint8_t *tmp_input_ptr = in.ptr();
62 uint8_t *tmp_output_ptr = output->ptr_to_element(Coordinates(kernel_idx, 0, kernel_idz));
63 const uint8_t *curr_input_row_ptr = tmp_input_ptr;
64 const uint8_t *curr_input_depth_ptr = tmp_input_ptr;
65
66 // Linearize volume
67 for(unsigned int d = 0; d < kernel_depth; ++d)
68 {
Gian Marco Iodice7b06cde2017-06-21 08:54:02 +010069 for(unsigned int j = 0; j < kernel_size_y; ++j)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010070 {
Gian Marco Iodice7b06cde2017-06-21 08:54:02 +010071 for(unsigned int i = 0; i < kernel_size_x; ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010072 {
73 *(reinterpret_cast<T *>(tmp_output_ptr)) = *(reinterpret_cast<const T *>(tmp_input_ptr));
Michalis Spyroue2503892018-04-23 15:17:31 +010074 tmp_input_ptr += is_nhwc ? input_stride_y : input_stride_x;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010075 tmp_output_ptr += output_stride_y;
76 }
Michalis Spyroue2503892018-04-23 15:17:31 +010077 curr_input_row_ptr += is_nhwc ? input_stride_z : input_stride_y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010078 tmp_input_ptr = curr_input_row_ptr;
79 }
Michalis Spyroue2503892018-04-23 15:17:31 +010080 curr_input_depth_ptr += is_nhwc ? input_stride_x : input_stride_z;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010081 curr_input_row_ptr = curr_input_depth_ptr;
82 tmp_input_ptr = curr_input_depth_ptr;
83 }
84
85 // Add bias
86 if(bias != nullptr)
87 {
88 *(reinterpret_cast<T *>(tmp_output_ptr)) = *(reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(kernel_idx, kernel_idz))));
89 }
90 },
91 in);
92}
Giorgio Arena7c23ad02017-11-30 15:08:38 +000093
94TensorShape get_output_shape(const ITensorInfo *input, bool has_bias)
95{
96 TensorShape output_shape{ input->tensor_shape() };
97
98 output_shape.collapse(3);
99 const size_t tmp_dim = output_shape[0];
100 output_shape.set(0, output_shape[1]);
101 output_shape.set(1, tmp_dim + (has_bias ? 1 : 0));
102
103 return output_shape;
104}
105
106Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
107{
Anthony Barbiereaefd002018-07-20 17:49:35 +0100108 //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100109 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000110 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
111
112 if(biases != nullptr)
113 {
Isabella Gottardie6630e42018-01-18 15:50:39 +0000114 ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()));
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000115 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000116 ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->num_dimensions() != 1));
117 ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 5) && (biases->num_dimensions() != 2));
118 ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->dimension(0) != input->tensor_shape()[3]));
119 ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 5) && (biases->dimension(0) != input->tensor_shape()[3] || biases->dimension(1) != input->tensor_shape()[4]));
120 }
121
122 // Checks performed when output is configured
123 if(output->total_size() != 0)
124 {
125 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), get_output_shape(input, biases != nullptr));
126 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000127 }
128
129 return Status{};
130}
131
132std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
133{
134 Window window = calculate_max_window(*input, Steps());
135 window.set(Window::DimX, Window::Dimension(0, input->dimension(0), input->dimension(0)));
136 window.set(Window::DimY, Window::Dimension(0, input->dimension(1), input->dimension(1)));
137 window.set(Window::DimZ, Window::Dimension(0, input->dimension(2), input->dimension(2)));
138
139 // The NEConvolutionLayerWeightsReshapeKernel doesn't need padding so update_window_and_padding() can be skipped
140 output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
141
142 return std::make_pair(Status{}, window);
143}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100144} // namespace
145
146NEWeightsReshapeKernel::NEWeightsReshapeKernel()
147 : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr)
148{
149}
150
151void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias, ITensor *output)
152{
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000153 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100154
Gian Marco Iodice5cb4c422017-06-23 10:38:25 +0100155 // Output tensor auto inizialitation if not yet initialized
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000156 auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(get_output_shape(input->info(), (bias != nullptr))));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100157
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000158 // Perform validation step
159 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(),
160 (bias != nullptr) ? bias->info() : nullptr,
161 output->info()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100162
163 _input = input;
164 _bias = bias;
165 _output = output;
166
Michalis Spyroue2503892018-04-23 15:17:31 +0100167 const DataLayout data_layout = input->info()->data_layout();
168 const bool is_nhwc = data_layout == DataLayout::NHWC;
169
Gian Marco Iodice2bbd9642017-07-04 16:46:32 +0100170 switch(_input->info()->element_size())
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100171 {
Gian Marco Iodice2bbd9642017-07-04 16:46:32 +0100172 case 4:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100173 {
Michalis Spyroue2503892018-04-23 15:17:31 +0100174 _func = is_nhwc ? &weights_reshape<uint32_t, true> : &weights_reshape<uint32_t, false>;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100175 break;
176 }
Gian Marco Iodice2bbd9642017-07-04 16:46:32 +0100177 case 2:
Pablo Tello659abc02017-06-22 16:00:16 +0100178 {
Michalis Spyroue2503892018-04-23 15:17:31 +0100179 _func = is_nhwc ? &weights_reshape<uint16_t, true> : &weights_reshape<uint16_t, false>;
Pablo Tello659abc02017-06-22 16:00:16 +0100180 break;
181 }
Gian Marco Iodice2bbd9642017-07-04 16:46:32 +0100182 case 1:
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100183 {
Michalis Spyroue2503892018-04-23 15:17:31 +0100184 _func = is_nhwc ? &weights_reshape<uint8_t, true> : &weights_reshape<uint8_t, false>;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100185 break;
186 }
187 default:
188 {
Gian Marco Iodice2bbd9642017-07-04 16:46:32 +0100189 ARM_COMPUTE_ERROR_ON("Element size not supported");
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100190 break;
191 }
192 }
193
194 // Configure kernel
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000195 auto win_config = validate_and_configure_window(input->info(), output->info());
196 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
197 INEKernel::configure(win_config.second);
198}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100199
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000200Status NEWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
201{
202 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, biases, output));
203 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204
Giorgio Arena7c23ad02017-11-30 15:08:38 +0000205 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206}
207
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100208void NEWeightsReshapeKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100210 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100211 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
212 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
213
214 (*_func)(_input, _bias, _output, window);
215}