blob: 226d6e0c9ca488c96dbb98e44f810036755e11ca [file] [log] [blame]
Georgios Pinitasd9769582017-08-03 10:19:40 +01001/*
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +01002 * Copyright (c) 2017-2020 ARM Limited.
Georgios Pinitasd9769582017-08-03 10:19:40 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Giorgio Arena04a8f8c2017-11-23 11:45:24 +000024#include "arm_compute/core/NEON/kernels/NEL2NormalizeLayerKernel.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010025
26#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
28#include "arm_compute/core/ITensor.h"
29#include "arm_compute/core/NEON/NEMath.h"
30#include "arm_compute/core/TensorInfo.h"
31#include "arm_compute/core/Utils.h"
32#include "arm_compute/core/Validate.h"
33#include "arm_compute/core/Window.h"
34
Michalis Spyrou2897e612018-11-20 18:38:29 +000035#include "arm_compute/core/NEON/wrapper/wrapper.h"
Georgios Pinitasd9769582017-08-03 10:19:40 +010036#include <arm_neon.h>
37#include <cmath>
38
Michalis Spyrou2897e612018-11-20 18:38:29 +000039namespace arm_compute
40{
Georgios Pinitasd9769582017-08-03 10:19:40 +010041namespace
42{
/** Modulus handed to wrap_around() when turning the user-supplied axis into the actual normalization axis. */
constexpr int max_input_tensor_dim{ 3 };
44
Michalis Spyrou2897e612018-11-20 18:38:29 +000045template <typename T, int S>
Georgios Pinitasd9769582017-08-03 10:19:40 +010046void l2_normalize_X(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
47{
Michalis Spyrou2897e612018-11-20 18:38:29 +000048 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
49
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010050 const int window_step_x = 16 / data_size_from_type(in->info()->data_type());
51 const auto window_start_x = static_cast<int>(window.x().start());
52 const auto window_end_x = static_cast<int>(window.x().end());
Georgios Pinitasd9769582017-08-03 10:19:40 +010053
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010054 Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
55 win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
Georgios Pinitasd9769582017-08-03 10:19:40 +010056
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010057 Iterator input_it(in, win_collapsed);
58 Iterator sum_it(sum, win_collapsed);
59 Iterator output_it(out, win_collapsed);
60
61 execute_window_loop(win_collapsed, [&](const Coordinates &)
Georgios Pinitasd9769582017-08-03 10:19:40 +010062 {
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010063 const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr());
64 const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
Georgios Pinitasd9769582017-08-03 10:19:40 +010065
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010066 const T sum_value = *reinterpret_cast<const T *>(sum_it.ptr());
67 const T norm_value = static_cast<T>(1.f) / std::sqrt(std::max(sum_value, static_cast<T>(epsilon)));
68 const auto vec_norm_value = wrapper::vdup_n(norm_value, ExactTagType{});
Georgios Pinitasd9769582017-08-03 10:19:40 +010069
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010070 // Compute elements over vector steps
71 int x = window_start_x;
72 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Georgios Pinitasd9769582017-08-03 10:19:40 +010073 {
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010074 wrapper::vstore(out_ptr + x, wrapper::vmul(wrapper::vloadq(in_ptr + x), vec_norm_value));
75 }
Georgios Pinitasd9769582017-08-03 10:19:40 +010076
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010077 // Compute left-over elements
78 for(; x < window_end_x; ++x)
79 {
80 out_ptr[x] = in_ptr[x] * norm_value;
81 }
82 },
83 input_it, sum_it, output_it);
Georgios Pinitasd9769582017-08-03 10:19:40 +010084}
John Richardson73d4aef2018-05-08 14:34:33 +010085
Michalis Spyrou2897e612018-11-20 18:38:29 +000086template <typename T, int S>
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010087void l2_normalize_YZ(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window, size_t axis)
Michalis Spyrou2897e612018-11-20 18:38:29 +000088{
Michalis Spyrou2897e612018-11-20 18:38:29 +000089 using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
90
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010091 const int window_step_x = 16 / data_size_from_type(in->info()->data_type());
92 const auto window_start_x = static_cast<int>(window.x().start());
93 const auto window_end_x = static_cast<int>(window.x().end());
Michalis Spyrou2897e612018-11-20 18:38:29 +000094
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010095 Window win = window;
96 win.set(Window::DimX, Window::Dimension(0, 1, 1));
Michalis Spyrou2897e612018-11-20 18:38:29 +000097
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +010098 Window window_sum(win);
99 window_sum.set(axis, Window::Dimension(0, 0, 0));
100
101 Iterator input_it(in, win);
102 Iterator sum_it(sum, window_sum);
103 Iterator output_it(out, win);
104
105 const auto vec_eps = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{});
106
107 execute_window_loop(win, [&](const Coordinates &)
Michalis Spyrou2897e612018-11-20 18:38:29 +0000108 {
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100109 const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr());
110 const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr());
111 const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
Michalis Spyrou2897e612018-11-20 18:38:29 +0000112
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100113 // Compute elements over vector steps
114 int x = window_start_x;
115 for(; x <= (window_end_x - window_step_x); x += window_step_x)
Michalis Spyrou2897e612018-11-20 18:38:29 +0000116 {
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100117 const auto vec_norm_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr + x), vec_eps));
118 wrapper::vstore(out_ptr + x, wrapper::vmul(wrapper::vloadq(in_ptr + x), vec_norm_value));
119 }
Michalis Spyrou2897e612018-11-20 18:38:29 +0000120
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100121 // Compute left-over elements
122 for(; x < window_end_x; ++x)
Michalis Spyrou2897e612018-11-20 18:38:29 +0000123 {
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100124 const T norm_value = static_cast<T>(1.f) / std::sqrt(std::max(sum_ptr[x], static_cast<T>(epsilon)));
125 out_ptr[x] = in_ptr[x] * norm_value;
126 }
127 },
128 input_it, sum_it, output_it);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000129}
130
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100131Status validate_arguments(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, int axis, float epsilon)
John Richardson73d4aef2018-05-08 14:34:33 +0100132{
133 ARM_COMPUTE_UNUSED(epsilon);
134
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100135 const uint32_t actual_axis = wrap_around(axis, max_input_tensor_dim);
John Richardson73d4aef2018-05-08 14:34:33 +0100136 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
137 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000138 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100139 ARM_COMPUTE_RETURN_ERROR_ON_MSG(actual_axis > 2, "Actual axis greater than 2 is not supported");
140 ARM_COMPUTE_RETURN_ERROR_ON_MSG(actual_axis >= TensorShape::num_max_dimensions, "Actual normalization axis greater than max number of dimensions");
John Richardson73d4aef2018-05-08 14:34:33 +0100141
142 // Reduce shape on axis
143 TensorShape sum_shape = input->tensor_shape();
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100144 sum_shape.set(actual_axis, 1);
John Richardson73d4aef2018-05-08 14:34:33 +0100145 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(sum->tensor_shape(), sum_shape);
146
147 if(output->total_size() != 0)
148 {
149 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
150 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
151 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(input->tensor_shape(), output->tensor_shape());
Michalis Spyrou2897e612018-11-20 18:38:29 +0000152 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
John Richardson73d4aef2018-05-08 14:34:33 +0100153 }
154
155 return Status{};
156}
157
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100158std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
John Richardson73d4aef2018-05-08 14:34:33 +0100159{
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100160 Window win = calculate_max_window(*input, Steps());
John Richardson73d4aef2018-05-08 14:34:33 +0100161
162 // Output auto initialization if not yet initialized
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100163 auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type());
John Richardson73d4aef2018-05-08 14:34:33 +0100164
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100165 // NEL2NormalizeLayerKernel doesn't need padding so update_window_and_padding() can be skipped
166 Coordinates coord;
167 coord.set_num_dimensions(output->num_dimensions());
168 output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
John Richardson73d4aef2018-05-08 14:34:33 +0100169
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100170 return std::make_tuple(Status{}, win);
John Richardson73d4aef2018-05-08 14:34:33 +0100171}
Georgios Pinitasd9769582017-08-03 10:19:40 +0100172} // namespace
173
Giorgio Arena04a8f8c2017-11-23 11:45:24 +0000174NEL2NormalizeLayerKernel::NEL2NormalizeLayerKernel()
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100175 : _input(nullptr), _sum(nullptr), _output(nullptr), _actual_axis(0), _epsilon(1e-12)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100176{
177}
178
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100179void NEL2NormalizeLayerKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output, int axis, float epsilon)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100180{
181 ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
John Richardson73d4aef2018-05-08 14:34:33 +0100182 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), sum->info(), output->info(), axis, epsilon));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100183
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100184 _input = input;
185 _sum = sum;
186 _output = output;
187 _actual_axis = wrap_around(axis, max_input_tensor_dim);
188 _epsilon = epsilon;
Georgios Pinitasd9769582017-08-03 10:19:40 +0100189
190 // Configure kernel window
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100191 auto win_config = validate_and_configure_window(_input->info(), _output->info());
John Richardson73d4aef2018-05-08 14:34:33 +0100192 ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100193
John Richardson73d4aef2018-05-08 14:34:33 +0100194 INEKernel::configure(std::get<1>(win_config));
195}
Georgios Pinitasd9769582017-08-03 10:19:40 +0100196
Manuel Bottini4b5c5882019-05-14 10:38:30 +0100197Status NEL2NormalizeLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, int axis, float epsilon)
John Richardson73d4aef2018-05-08 14:34:33 +0100198{
199 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, sum, output, axis, epsilon));
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100200 ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get())));
Georgios Pinitasd9769582017-08-03 10:19:40 +0100201
John Richardson73d4aef2018-05-08 14:34:33 +0100202 return Status{};
Georgios Pinitasd9769582017-08-03 10:19:40 +0100203}
204
Giorgio Arena04a8f8c2017-11-23 11:45:24 +0000205void NEL2NormalizeLayerKernel::run(const Window &window, const ThreadInfo &info)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100206{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100207 ARM_COMPUTE_UNUSED(info);
Georgios Pinitasd9769582017-08-03 10:19:40 +0100208 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
209 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
210
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100211 if(_actual_axis > 2)
Georgios Pinitasd9769582017-08-03 10:19:40 +0100212 {
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100213 ARM_COMPUTE_ERROR("Unsupported normalization axis");
214 }
215
216 switch(_input->info()->data_type())
217 {
218 case DataType::F32:
219 (_actual_axis == Window::DimX) ? l2_normalize_X<float, 4>(_input, _sum, _output, _epsilon, window) : l2_normalize_YZ<float, 4>(_input, _sum, _output, _epsilon, window, _actual_axis);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000220 break;
Michalis Spyrou2897e612018-11-20 18:38:29 +0000221#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100222 case DataType::F16:
223 (_actual_axis == Window::DimX) ? l2_normalize_X<float16_t, 8>(_input, _sum, _output, _epsilon, window) : l2_normalize_YZ<float16_t, 8>(_input, _sum, _output, _epsilon, window, _actual_axis);
Michalis Spyrou2897e612018-11-20 18:38:29 +0000224 break;
Michalis Spyrou2897e612018-11-20 18:38:29 +0000225#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Georgios Pinitasd9769582017-08-03 10:19:40 +0100226 default:
Georgios Pinitas6cb26ce2020-06-24 17:20:23 +0100227 ARM_COMPUTE_ERROR("Not implemented");
Georgios Pinitasd9769582017-08-03 10:19:40 +0100228 }
229}
Michalis Spyrou2897e612018-11-20 18:38:29 +0000230} // namespace arm_compute