blob: 46a2d80d10167ce01d6d29bd33e62b1cb1a6beee [file] [log] [blame]
Michalis Spyrou55b3d122018-05-09 09:59:23 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLWidthConcatenateLayer.h"
25
26#include "arm_compute/core/CL/ICLTensor.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/TensorInfo.h"
30#include "arm_compute/core/Types.h"
31#include "arm_compute/core/utils/misc/ShapeCalculator.h"
32#include "arm_compute/runtime/CL/CLScheduler.h"
33#include "support/ToolchainSupport.h"
34
35using namespace arm_compute;
36
37CLWidthConcatenateLayer::CLWidthConcatenateLayer() // NOLINT
38 : _concat_kernels_vector(),
Michele Di Giorgio27400b92018-11-01 13:44:05 +000039 _concat_x2_kernel(),
40 _concat_x4_kernel(),
Michalis Spyrou55b3d122018-05-09 09:59:23 +010041 _num_inputs(0)
42{
43}
44
45Status CLWidthConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output) // NOLINT
46{
Michele Di Giorgio27400b92018-11-01 13:44:05 +000047 const unsigned int num_inputs = inputs_vector.size();
48
Michalis Spyrou55b3d122018-05-09 09:59:23 +010049 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
Michele Di Giorgio27400b92018-11-01 13:44:05 +000050 ARM_COMPUTE_RETURN_ERROR_ON(num_inputs < 2);
Michalis Spyrou55b3d122018-05-09 09:59:23 +010051
52 // Output auto inizialitation if not yet initialized
53 TensorInfo tmp_output_info = *output->clone();
54 TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010055 auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type());
Michalis Spyrou55b3d122018-05-09 09:59:23 +010056
Michele Di Giorgio27400b92018-11-01 13:44:05 +000057 switch(num_inputs)
Michalis Spyrou55b3d122018-05-09 09:59:23 +010058 {
Michele Di Giorgio27400b92018-11-01 13:44:05 +000059 case 2:
60 // Validate WidthConcatenate2Tensors kernels if there are 2 inputs
61 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(inputs_vector[0], inputs_vector[1]);
62 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(inputs_vector[0], inputs_vector[1], &tmp_output_info));
63 break;
64 case 4:
65 // Validate WidthConcatenate4Tensors kernels if there are 4 inputs
66 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(inputs_vector[0], inputs_vector[1], inputs_vector[2], inputs_vector[3]);
67 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate4TensorsKernel::validate(inputs_vector[0], inputs_vector[1], inputs_vector[2], inputs_vector[3], &tmp_output_info));
68 break;
69 default:
70 unsigned int width_offset = 0;
71 // Validate generic case of WidthConcatenate kernel
72 for(const auto &input : inputs_vector)
73 {
74 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
75 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenateLayerKernel::validate(input, width_offset, &tmp_output_info));
76 width_offset += input->dimension(0);
77 }
78 break;
Michalis Spyrou55b3d122018-05-09 09:59:23 +010079 }
80
81 return Status{};
82}
83
84void CLWidthConcatenateLayer::configure(std::vector<ICLTensor *> inputs_vector, ICLTensor *output) // NOLINT
85{
86 _num_inputs = inputs_vector.size();
87
88 std::vector<ITensorInfo *> inputs_vector_info;
89 for(unsigned int i = 0; i < _num_inputs; i++)
90 {
91 inputs_vector_info.emplace_back(inputs_vector.at(i)->info());
92 }
93 TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
94
95 // Output auto inizialitation if not yet initialized
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010096 auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type());
Michele Di Giorgio27400b92018-11-01 13:44:05 +000097
Michalis Spyrou55b3d122018-05-09 09:59:23 +010098 ARM_COMPUTE_ERROR_THROW_ON(CLWidthConcatenateLayer::validate(inputs_vector_info, output->info()));
99
Michele Di Giorgio27400b92018-11-01 13:44:05 +0000100 switch(_num_inputs)
Michalis Spyrou55b3d122018-05-09 09:59:23 +0100101 {
Michele Di Giorgio27400b92018-11-01 13:44:05 +0000102 case 2:
103 // Configure WidthConcatenate2Tensors kernel
104 _concat_x2_kernel.configure(inputs_vector.at(0), inputs_vector.at(1), output);
105 break;
106 case 4:
107 // Configure WidthConcatenate4Tensors kernel
108 _concat_x4_kernel.configure(inputs_vector.at(0), inputs_vector.at(1), inputs_vector.at(2), inputs_vector.at(3), output);
109 break;
110 default:
111 // Configure generic case WidthConcatenate kernels
112 _concat_kernels_vector = arm_compute::support::cpp14::make_unique<CLWidthConcatenateLayerKernel[]>(_num_inputs);
113
114 unsigned int width_offset = 0;
115 for(unsigned int i = 0; i < _num_inputs; ++i)
116 {
117 _concat_kernels_vector[i].configure(inputs_vector.at(i), width_offset, output);
118 width_offset += inputs_vector.at(i)->info()->dimension(0);
119 }
120 break;
Michalis Spyrou55b3d122018-05-09 09:59:23 +0100121 }
122}
123
124void CLWidthConcatenateLayer::run()
125{
126 cl::CommandQueue q = CLScheduler::get().queue();
127
Michele Di Giorgio27400b92018-11-01 13:44:05 +0000128 switch(_num_inputs)
Michalis Spyrou55b3d122018-05-09 09:59:23 +0100129 {
Michele Di Giorgio27400b92018-11-01 13:44:05 +0000130 case 2:
131 CLScheduler::get().enqueue(_concat_x2_kernel, true);
132 break;
133 case 4:
134 CLScheduler::get().enqueue(_concat_x4_kernel, true);
135 break;
136 default:
137 for(unsigned int i = 0; i < _num_inputs; ++i)
138 {
139 CLScheduler::get().enqueue(_concat_kernels_vector[i], true);
140 }
141 break;
Michalis Spyrou55b3d122018-05-09 09:59:23 +0100142 }
143}