blob: a4bce5da5a7ee04b15b22bcfddcff666e3de3b44 [file] [log] [blame]
Georgios Pinitas77589b52018-08-21 14:41:35 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/utils/helpers/tensor_transform.h"
25
26namespace arm_compute
27{
28namespace helpers
29{
30namespace tensor_transform
31{
Georgios Pinitasc1a72452018-08-24 11:25:32 +010032Coordinates slice_absolute_end_coords(TensorShape input_shape, Coordinates ends)
33{
34 // Create end mask
35 int32_t end_mask = 0;
36 for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
37 {
38 if(ends[i] < 0)
39 {
40 end_mask |= 1 << i;
41 }
42 }
43 // Get unit strides
44 const BiStrides unit_strides = strided_slice_strides(input_shape, BiStrides());
45
46 return strided_slice_absolute_end_coords(input_shape, Coordinates(), ends, unit_strides, end_mask);
47}
48
49TensorShape compute_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends_abs)
50{
51 // Get unit strides
52 const BiStrides unit_strides = strided_slice_strides(input_shape, BiStrides());
53 return compute_strided_slice_output_shape(input_shape, starts, ends_abs, unit_strides);
54}
55
Georgios Pinitas77589b52018-08-21 14:41:35 +010056Coordinates strided_slice_absolute_start_coords(TensorShape input_shape, Coordinates starts, Coordinates strides, int32_t begin_mask)
57{
58 Coordinates starts_abs;
59 for(unsigned int i = 0; i < starts.num_dimensions(); ++i)
60 {
61 // Get start index
62 int start_i = starts[i];
63
64 // Reset in case of begin mask present
65 if((begin_mask & 1 << i) != 0)
66 {
67 start_i = strides[i] > 0 ? std::numeric_limits<int>::lowest() : std::numeric_limits<int>::max();
68 }
69
70 // Account negative start points
71 const int dim_size = input_shape[i];
72 if(start_i < 0)
73 {
74 start_i += dim_size;
75 }
76
77 // Final clamp
78 start_i = utility::clamp(start_i, 0, dim_size - 1);
79 starts_abs.set(i, start_i);
80 }
81
82 // Fill remaining
83 for(unsigned int i = starts_abs.num_dimensions(); i < input_shape.num_dimensions(); ++i)
84 {
85 starts_abs.set(i, 0);
86 }
87
88 return starts_abs;
89}
90
91Coordinates strided_slice_absolute_end_coords(TensorShape input_shape, Coordinates starts_abs, Coordinates ends, Coordinates strides,
92 int32_t end_mask, int32_t shrink_axis_mask)
93{
94 Coordinates ends_abs;
95 for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
96 {
97 // Get end index
98 int stop_i = ends[i];
99
100 // Shrink dimension
101 if((shrink_axis_mask & (1 << i)) != 0)
102 {
103 stop_i = starts_abs[i] + 1;
104 }
105
106 // Reset in case of begin mask present
107 if((end_mask & 1 << i) != 0)
108 {
109 stop_i = (strides[i] > 0) ? std::numeric_limits<int>::max() : std::numeric_limits<int>::lowest();
110 }
111
112 // Account negative end points
113 const int dim_size = input_shape[i];
114 if(stop_i < 0)
115 {
116 stop_i += dim_size;
117 }
118
119 // Final clamp
120 stop_i = (strides[i] > 0) ? utility::clamp(stop_i, 0, dim_size) : utility::clamp(stop_i, -1, dim_size - 1);
121 ends_abs.set(i, stop_i);
122 }
123
124 // Fill remaining ends
125 for(unsigned int i = ends_abs.num_dimensions(); i < input_shape.num_dimensions(); ++i)
126 {
127 ends_abs.set(i, input_shape[i]);
128 }
129
130 return ends_abs;
131}
132
133Coordinates strided_slice_strides(TensorShape input_shape, Coordinates strides)
134{
135 for(unsigned int i = strides.num_dimensions(); i < input_shape.num_dimensions(); ++i)
136 {
137 strides.set(i, 1);
138 }
139 return strides;
140}
141
142TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts_abs, Coordinates ends_abs, Coordinates final_strides)
143{
144 TensorShape output_shape = input_shape;
145 for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
146 {
147 const int stride_i = final_strides[i];
148 const int range = ends_abs[i] - starts_abs[i];
149 if((range == 0) || // Zero range
150 (range < 0 && stride_i >= 0) || // Negative range with positive stride
151 (range > 0 && stride_i <= 0)) // Positive range with negative stride
152 {
153 output_shape.set(i, 0);
154 return output_shape;
155 }
156 else
157 {
158 int dim = range / stride_i + (range % stride_i != 0 ? 1 : 0);
159 output_shape.set(i, dim);
160 }
161 }
162 return output_shape;
163}
164} // namespace tensor_transform
165} // namespace helpers
166} // namespace arm_compute