blob: 1a7f58bb08e0716fd6b19f2d3b8223d518c757e9 [file] [log] [blame]
David Svantesson3b162e52023-03-28 14:13:32 +00001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
David Svantessoncd8b40d2023-05-02 13:05:36 +000024#if defined(__aarch64__)
David Svantesson3b162e52023-03-28 14:13:32 +000025
26#include "src/core/NEON/kernels/NEReorderKernel.h"
27#include "src/common/utils/Log.h"
28#include "src/core/NEON/kernels/arm_gemm/transform.hpp"
29#include "arm_compute/core/Helpers.h"
30#include "arm_compute/core/Validate.h"
31
32namespace arm_compute
33{
34
35void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
36{
37 ARM_COMPUTE_UNUSED(info);
38 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
39 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
40 switch(_input->info()->data_type())
41 {
42 case DataType::F32:
43 {
44 const int ksize_rows_elements = _xmax * _ksize;
45 const int jump_rows = ksize_rows_elements * window.x().start();
46 const int k_start = window.x().start() * _ksize;
47 const int k_end = std::min(window.x().end() * _ksize, _kmax);
48 const int stride = _kmax;
49 if(k_start < k_end)
50 {
51
52 switch(_output_wf)
53 {
54 case WeightFormat::OHWIo4:
55 {
56 arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(reinterpret_cast<float *>(_output->buffer()) + jump_rows, reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
57 break;
58 }
59#if defined(ARM_COMPUTE_ENABLE_SVE)
60 case WeightFormat::OHWIo8:
61 {
62 arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(reinterpret_cast<float *>(_output->buffer()) + jump_rows, reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
63 break;
64 }
65#endif /* ARM_COMPUTE_ENABLE_SVE */
66 default:
67 {
68 ARM_COMPUTE_ERROR("Unsupported data type!");
69 break;
70 }
71 }
72 }
73 break;
74 }
75 default:
76 ARM_COMPUTE_ERROR("Unsupported data type!");
77 }
78}
79
80NEReorderKernel::NEReorderKernel()
81 : _input(nullptr), _output(nullptr), _ksize(0), _kmax(0), _xmax(0), _input_wf(WeightFormat::ANY), _output_wf(WeightFormat::ANY)
82{
83}
84
85void NEReorderKernel::configure(const ITensor *input, ITensor *output, arm_compute::WeightFormat input_wf, arm_compute::WeightFormat output_wf)
86{
87 ARM_COMPUTE_LOG_PARAMS(input, output, input_wf, output_wf);
88 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
89 ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), input_wf, output_wf));
90
91 // Set variables
92 _input = input;
93 _output = output;
94 _input_wf = input_wf;
95 _output_wf = output_wf;
96
97 // Setting parameters for transform
98 auto dims = input->info()->num_dimensions();
99 switch(dims)
100 {
101 case 2:
102 {
103 _xmax = input->info()->dimension(0); // Number of columns in input matrix
104 _kmax = input->info()->dimension(1); // Number of rows in input matrix
105 break;
106 }
107 case 4:
108 {
109 _xmax = input->info()->dimension(2); // Number of columns in input matrix
110 _kmax = input->info()->dimension(3); // Number of rows in input matrix
111 break;
112 }
113 default:
114 {
115 ARM_COMPUTE_ERROR("Only 2 or 4 dimensions supported.");
116 }
117 }
118
119 // Configure kernel window
120 // Window size is set by rows / _ksize
121 Window win;
122 int window_size = 0;
123 switch(_output_wf)
124 {
125#if defined(ARM_COMPUTE_ENABLE_SVE)
126 case WeightFormat::OHWIo8:
127 {
128 _ksize = 8;
129 window_size = _kmax / _ksize;
130 break;
131 }
132#endif /* ARM_COMPUTE_ENABLE_SVE */
133 case WeightFormat::OHWIo4:
134 {
135 _ksize = 4;
136 window_size = _kmax / _ksize;
137 break;
138 }
139 default:
140 {
141 ARM_COMPUTE_ERROR("Unsupported weight format.");
142 break;
143 }
144 }
145 if(_kmax % _ksize != 0)
146 {
147 window_size += 1;
148 }
149
150 win.set(Window::DimX, Window::Dimension(0, window_size, 1));
151
152 INEKernel::configure(win);
153}
154
155Status NEReorderKernel::validate(const ITensorInfo *input, const ITensorInfo *output, arm_compute::WeightFormat input_wf, arm_compute::WeightFormat output_wf)
156{
157 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
158 ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
159 if(output->tensor_shape().total_size() != 0)
160 {
161 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
162 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
163 // Only input WeightFormat OHWI supported
164 ARM_COMPUTE_RETURN_ERROR_ON(input_wf != arm_compute::WeightFormat::OHWI);
165 int input_x_dim;
166 int input_k_dim;
167 int output_x_dim;
168 int output_k_dim;
169 auto dims = output->num_dimensions();
170 switch(dims)
171 {
172 case 2:
173 {
174 input_x_dim = input->dimension(0); // Number of columns in input matrix
175 input_k_dim = input->dimension(1); // Number of rows in input matrix
176 output_x_dim = output->dimension(0); // Number of columns in output matrix
177 output_k_dim = output->dimension(1); // Number of rows in output matrix
178 break;
179 }
180 case 4:
181 {
182 input_x_dim = input->dimension(2); // Number of columns in input matrix
183 input_k_dim = input->dimension(3); // Number of rows in input matrix
184 output_x_dim = output->dimension(2); // Number of columns in output matrix
185 output_k_dim = output->dimension(3); // Number of rows in output matrix
186 break;
187 }
188 default:
189 {
190 ARM_COMPUTE_RETURN_ERROR_MSG("Only 2 or 4 dimensions supported.");
191 }
192 }
193
194 int ksize;
195 switch(output_wf)
196 {
197 case WeightFormat::OHWIo8:
198 {
199 ksize = 8;
200 break;
201 }
202 case WeightFormat::OHWIo4:
203 {
204 ksize = 4;
205 break;
206 }
207 default:
208 {
209 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported weight format.");
210 break;
211 }
212 }
213
214 // output k_dim needs to be same as input but multiple of ksize
215 int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize);
216 ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_kdim != output_k_dim);
217 // output x_dim needs to be same as input
218 ARM_COMPUTE_RETURN_ERROR_ON(input_x_dim != output_x_dim);
219
220 }
221 return Status{};
222}
223
224} // namespace arm_compute
David Svantessoncd8b40d2023-05-02 13:05:36 +0000225
226#endif // defined(__aarch64__)