blob: c801b097b5cc0f52d0864e8373f9066e5a6b3fcd [file] [log] [blame]
/*
 * Copyright (c) 2016-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/Helpers.h"
25
SiCong Li96209c72020-08-21 12:28:30 +010026namespace arm_compute
27{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010028ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info,
29 const TensorShape &dst_shape,
30 InterpolationPolicy interpolate_policy,
31 SamplingPolicy sampling_policy,
32 bool border_undefined)
Diego Lopez Recas00854292018-02-22 13:08:01 +000033{
Georgios Pinitas393fa4c2018-05-08 15:54:53 +010034 const DataLayout data_layout = src_info.data_layout();
35 const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
36 const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
37
38 const float scale_x = static_cast<float>(dst_shape[idx_width]) / src_info.tensor_shape()[idx_width];
39 const float scale_y = static_cast<float>(dst_shape[idx_height]) / src_info.tensor_shape()[idx_height];
Diego Lopez Recas00854292018-02-22 13:08:01 +000040 const float sampling_point = (sampling_policy == SamplingPolicy::CENTER) ? 0.5f : 0.0f;
41
42 // Get input's valid region start and end points
Georgios Pinitas393fa4c2018-05-08 15:54:53 +010043 const int valid_start_in_x = src_info.valid_region().anchor[idx_width];
44 const int valid_start_in_y = src_info.valid_region().anchor[idx_height];
45 const int valid_end_in_x = src_info.valid_region().anchor[idx_width] + src_info.valid_region().shape[idx_width];
46 const int valid_end_in_y = src_info.valid_region().anchor[idx_height] + src_info.valid_region().shape[idx_height];
Diego Lopez Recas00854292018-02-22 13:08:01 +000047
48 // Initialize output's valid region start and end points
49 auto valid_start_out_x = static_cast<int>(valid_start_in_x * scale_x);
50 auto valid_start_out_y = static_cast<int>(valid_start_in_y * scale_y);
Georgios Pinitas393fa4c2018-05-08 15:54:53 +010051 auto valid_end_out_x = std::min<int>(std::ceil(valid_end_in_x * scale_x), dst_shape[idx_width]);
52 auto valid_end_out_y = std::min<int>(std::ceil(valid_end_in_y * scale_y), dst_shape[idx_height]);
Diego Lopez Recas00854292018-02-22 13:08:01 +000053
54 // Handle valid points in case of the bi-linear interpolation
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010055 if (border_undefined)
Diego Lopez Recas00854292018-02-22 13:08:01 +000056 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010057 switch (interpolate_policy)
Diego Lopez Recas00854292018-02-22 13:08:01 +000058 {
59 case InterpolationPolicy::NEAREST_NEIGHBOR:
60 {
61 // (start_out + sampling_point) >= (start_in * scale)
62 // start_out = ceil((start_in * scale) - sampling_point)
63 valid_start_out_x = std::ceil(valid_start_in_x * scale_x - sampling_point);
64 valid_start_out_y = std::ceil(valid_start_in_y * scale_y - sampling_point);
65
66 // (end_out - 1 + sampling_point) < (end_in * scale)
67 // end_out = ceil((end_in * scale) - sampling_point); // <-- ceil(x - 1) strictly less
68 valid_end_out_x = std::ceil(valid_end_in_x * scale_x - sampling_point);
69 valid_end_out_y = std::ceil(valid_end_in_y * scale_y - sampling_point);
70 break;
71 }
72 case InterpolationPolicy::BILINEAR:
73 {
74 // (start_out + sampling_point) >= ((start_in + sampling_point) * scale)
75 // start_out = ceil(((start_in + sampling_point) * scale) - sampling_point)
76 valid_start_out_x = std::ceil((valid_start_in_x + sampling_point) * scale_x - sampling_point);
77 valid_start_out_y = std::ceil((valid_start_in_y + sampling_point) * scale_y - sampling_point);
78
79 // (end_out - 1 + sampling_point) <= ((end_in - 1 + sampling_point) * scale)
80 // end_out = floor(((end_in - 1 + sampling_point) * scale) - sampling_point + 1)
81 valid_end_out_x = std::floor((valid_end_in_x - 1.f + sampling_point) * scale_x - sampling_point + 1.f);
82 valid_end_out_y = std::floor((valid_end_in_y - 1.f + sampling_point) * scale_y - sampling_point + 1.f);
83 break;
84 }
85 case InterpolationPolicy::AREA:
86 break;
87 default:
88 {
89 ARM_COMPUTE_ERROR("Invalid InterpolationPolicy");
90 break;
91 }
92 }
93 }
94
95 // Setup output valid region
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010096 ValidRegion valid_region{Coordinates(), dst_shape, dst_shape.num_dimensions()};
Diego Lopez Recas00854292018-02-22 13:08:01 +000097
Georgios Pinitas393fa4c2018-05-08 15:54:53 +010098 valid_region.anchor.set(idx_width, std::max(0, valid_start_out_x));
99 valid_region.anchor.set(idx_height, std::max(0, valid_start_out_y));
Diego Lopez Recas00854292018-02-22 13:08:01 +0000100
Georgios Pinitas393fa4c2018-05-08 15:54:53 +0100101 valid_region.shape.set(idx_width, std::min<size_t>(valid_end_out_x - valid_start_out_x, dst_shape[idx_width]));
102 valid_region.shape.set(idx_height, std::min<size_t>(valid_end_out_y - valid_start_out_y, dst_shape[idx_height]));
Diego Lopez Recas00854292018-02-22 13:08:01 +0000103
104 return valid_region;
SiCong Li96209c72020-08-21 12:28:30 +0100105}
Giorgio Arenac9fe9fc2021-10-06 12:54:29 +0100106
107const std::map<DataLayout, std::vector<DataLayoutDimension>> &get_layout_map()
108{
109 constexpr DataLayoutDimension W = DataLayoutDimension::WIDTH;
110 constexpr DataLayoutDimension H = DataLayoutDimension::HEIGHT;
111 constexpr DataLayoutDimension C = DataLayoutDimension::CHANNEL;
112 constexpr DataLayoutDimension D = DataLayoutDimension::DEPTH;
113 constexpr DataLayoutDimension N = DataLayoutDimension::BATCHES;
114
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100115 static const std::map<DataLayout, std::vector<DataLayoutDimension>> layout_map = {
116 {DataLayout::NDHWC, {C, W, H, D, N}},
117 {DataLayout::NCDHW, {W, H, D, C, N}},
118 {DataLayout::NHWC, {C, W, H, N}},
119 {DataLayout::NCHW, {W, H, C, N}}};
Giorgio Arenac9fe9fc2021-10-06 12:54:29 +0100120
121 return layout_map;
122}
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100123} // namespace arm_compute