/*
 * Copyright (c) 2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_CPP_SPLIT_H
#define ARM_COMPUTE_CPP_SPLIT_H

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"

#include "support/ToolchainSupport.h"

#include "arm_compute/runtime/IFunction.h"

namespace arm_compute
{
/** Basic function to split a tensor along a given axis */
template <typename SliceType, typename TensorInterfaceType = ITensor>
class CPPSplit : public IFunction
{
public:
    CPPSplit()
        : _outputs_vector(), _slice_functions(), _num_outputs(0)
    {
    }
    /** Static function to check if given info will lead to a valid configuration of @ref CPPSplit
     *
     * @param[in] input   The input tensor info. Data types supported: All.
     * @param[in] outputs A vector containing the output tensors' info. Data types supported: same as @p input.
     *                    The output tensors should match the input tensor dimensions for all shape dimensions apart
     *                    from the split dimension.
     * @param[in] axis    Axis on which to split the input.
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs, unsigned int axis)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
        ARM_COMPUTE_RETURN_ERROR_ON(axis >= input->num_dimensions());
        ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);

        // Get output shape
        TensorShape output_shape{};
        unsigned int total_output_shape_size = 0;

        // Sum the output sizes and fall back to evenly-sized splits if any are zero
        const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo * info)
        {
            unsigned int output_shape_size = info->tensor_shape().total_size();
            total_output_shape_size += output_shape_size;
            return output_shape_size == 0;
        });

        if(using_split_shapes)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size);
        }
        else
        {
            output_shape = arm_compute::misc::shape_calculator::compute_split_shape(input, axis, outputs.size());
            ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
        }

        // Validate output tensors
        unsigned int axis_offset = 0;
        for(const auto &output : outputs)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
            if(using_split_shapes)
            {
                output_shape = output->tensor_shape();
                ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
            }

            const size_t axis_split_step = output_shape[axis];

            // Start/End coordinates
            Coordinates start_coords;
            Coordinates end_coords;
            for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
            {
                end_coords.set(d, -1);
            }

            // Output auto initialization if not yet initialized
            TensorInfo tmp_output_info = *output->clone();
            auto_init_if_empty(tmp_output_info, input->clone()->set_is_resizable(true).set_tensor_shape(output_shape));

            // Update coordinate on axis
            start_coords.set(axis, axis_offset);
            end_coords.set(axis, axis_offset + axis_split_step);

            ARM_COMPUTE_RETURN_ON_ERROR(SliceType::validate(input, output, start_coords, end_coords));
            axis_offset += axis_split_step;
        }

        return Status{};
    }
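
    // Illustrative usage sketch (not part of the original interface): validating an even split of
    // a 2D tensor into two parts along axis 0. "MySliceFunction" is a hypothetical slice function
    // type and the shapes are made up for the example; leaving the output infos empty selects the
    // evenly-sized split path, so each output is checked against the shape [4, 4].
    //
    //   TensorInfo input_info(TensorShape(8U, 4U), 1, DataType::F32);
    //   TensorInfo out0{}, out1{};
    //   std::vector<ITensorInfo *> outputs_info{ &out0, &out1 };
    //   Status status = CPPSplit<MySliceFunction>::validate(&input_info, outputs_info, 0);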

    /** Initialise the kernel's input and outputs.
     *
     * @param[in]  input   The input tensor. Data types supported: All
     * @param[out] outputs A vector containing the output tensors. Data types supported: Same as @p input.
     *                     The output tensors should match the input tensor dimensions for all shape dimensions apart
     *                     from the split dimension.
     * @param[in]  axis    Axis on which to split the input.
     */
    void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis)
    {
        // Create Slice functions
        _num_outputs = outputs.size();
        _slice_functions.resize(_num_outputs);

        // Extract output tensor info
        std::vector<ITensorInfo *> outputs_info;
        for(auto &output : outputs)
        {
            ARM_COMPUTE_ERROR_ON_NULLPTR(output);
            outputs_info.emplace_back(output->info());
        }

        // If any of the outputs have a zero size, fall back to using evenly-sized output splits
        const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo * info)
        {
            return info->tensor_shape().total_size() == 0;
        });

        // Validate
        ARM_COMPUTE_ERROR_THROW_ON(CPPSplit::validate(input->info(), outputs_info, axis));

        unsigned int axis_offset = 0;
        unsigned int i = 0;

        for(const auto &output_info : outputs_info)
        {
            // Get output shape
            TensorShape output_shape = (outputs_have_sizes ?
                                        output_info->tensor_shape() :
                                        arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs));

            const size_t axis_split_step = output_shape[axis];

            // Start/End coordinates
            Coordinates start_coords;
            Coordinates end_coords;

            for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
            {
                end_coords.set(d, -1);
            }

            // Update coordinate on axis
            start_coords.set(axis, axis_offset);
            end_coords.set(axis, axis_offset + axis_split_step);

            // Configure slice function
            _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);

            // Set valid region from shape
            outputs[i]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));

            // Update axis offset
            axis_offset += axis_split_step;
            ++i;
        }
    }
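
    // Worked example (for illustration only): splitting an input of shape [4, 6] into three
    // outputs along axis 1 gives each output the shape [4, 2]. The slice windows configured above
    // are start=(0, 0)/end=(-1, 2), start=(0, 2)/end=(-1, 4) and start=(0, 4)/end=(-1, 6), where
    // -1 on the non-split dimension means "up to the end of that dimension".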

protected:
    std::vector<TensorInterfaceType *> _outputs_vector;
    std::vector<SliceType>             _slice_functions;
    unsigned int                       _num_outputs;
};
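
// Sketch of how a backend-specific split function is expected to build on CPPSplit (illustrative
// only; "MySliceFunction" and "MyTensorType" are hypothetical stand-ins for a backend's slice
// function and tensor interface). The derived class inherits configure() and validate() and only
// needs to provide run(), dispatching each configured slice function in turn:
//
//   class MySplit : public CPPSplit<MySliceFunction, MyTensorType>
//   {
//   public:
//       void run() override
//       {
//           for(auto &slice : _slice_functions)
//           {
//               slice.run();
//           }
//       }
//   };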

} // namespace arm_compute
#endif /* ARM_COMPUTE_CPP_SPLIT_H */