/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace
{
template <typename TypeInput, typename TypeOutput>
std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_UNUSED(method);
    ARM_COMPUTE_UNUSED(a);
    ARM_COMPUTE_UNUSED(b);
    ARM_COMPUTE_UNUSED(d);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_UNUSED(pretranspose_hint);
    return nullptr;
}
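// Note: the unspecialised template above deliberately returns nullptr, so any
// type combination without a dedicated ACL function falls through to the
// arm_gemm Fallback path selected in create_function_or_arm_gemm() below.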
template <>
std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_UNUSED(method);
    ARM_COMPUTE_UNUSED(a);
    ARM_COMPUTE_UNUSED(b);
    ARM_COMPUTE_UNUSED(d);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_UNUSED(pretranspose_hint);
    switch(method)
    {
#ifdef __aarch64__
        case arm_gemm::GemmMethod::GEMM_NATIVE:
        {
            auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
            kernel->configure(a, b, d, alpha, beta);
            auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
            function->configure(std::move(kernel));
            return std::move(function);
        }
#endif /* __aarch64__ */
        default:
            return nullptr;
    }
}

/** Fallback which runs an arm_gemm kernel in case ACL doesn't have a dedicated function */
template <typename TypeInput, typename TypeOutput>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
    void run() override;
    void prepare() override;
    bool is_configured() const override;

private:
    /** Allocate a workspace tensor.
     *
     * @param[in] workspace_size Size to allocate.
     * @param[in] memory_group   Tensor memory group.
     * @param[in] alignment      Workspace memory alignment.
     */
    void allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment);

    /** Assembly GEMM kernel */
    std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
    /** Optimised NEON kernel */
    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
    /** Input A */
    const ITensor *_a{ nullptr };
    /** Input B */
    const ITensor *_b{ nullptr };
    /** Output */
    ITensor *_d{ nullptr };
    /** GEMM workspace */
    Tensor _workspace{};
    /** Pre-transpose tensor */
    Tensor _pretranspose{};
    /** Prepared flag */
    bool _is_prepared{ false };
};

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
{
    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
    if(_gemm_kernel_asm == nullptr)
    {
        // Configuration not supported: leave the function unconfigured
        return;
    }

    // arm_compute wrapper for the Gemm object (see above)
    std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
    if(workspace_size > 0)
    {
        // Allocate workspace
        const unsigned int alignment = 4096;
        // FIXME: is memory_group ever null ?
        allocate_workspace(workspace_size, &memory_group, alignment);
    }

    // If we disable the code block below then ConvLayer deadlocks when threads > 1 and
    // the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        const int window_size = _gemm_kernel_asm->get_window_size();
        if(window_size < args._maxthreads)
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _a = a;
    _b = b;
    _d = d;
    // Check for pre-transposed support
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment = 128;
        const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
        _pretranspose.allocator()->allocate();
        ARM_COMPUTE_ERROR_ON_NULLPTR(_pretranspose.buffer());
    }
}

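// prepare() performs the one-off B pretranspose into the buffer allocated by
// configure(); once the data has been copied, B is marked as unused so that its
// memory can be reclaimed.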
template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::prepare()
{
    if(!_is_prepared)
    {
        // Pretranspose B if required
        if(_gemm_kernel_asm->B_pretranspose_required())
        {
            const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
            const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer());
            const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);

            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
            _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
            _b->mark_as_unused();
        }

        _is_prepared = true;
    }
}

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment)
{
    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
    _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
    if(memory_group != nullptr)
    {
        memory_group->manage(&_workspace);
    }
    _workspace.allocator()->allocate();
}

template <typename TypeInput, typename TypeOutput>
bool Fallback<TypeInput, TypeOutput>::is_configured() const
{
    return _optimised_kernel != nullptr;
}

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::run()
{
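    // All leading dimensions and strides below are expressed in elements, hence
    // the byte strides from the tensor info are divided by the element size.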
    const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);

    // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
    // the relevant multiple of the row stride.
    const bool is_nhwc           = _a->info()->data_layout() == DataLayout::NHWC;
    const int  stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();

    const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
    const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);

    const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
    const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);

    const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
    const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
    auto       out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    if(_workspace.buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
        const unsigned int window_size = _gemm_kernel_asm->get_window_size();
        unsigned int       num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            num_threads = window_size;
            _gemm_kernel_asm->set_nthreads(num_threads);
        }
    }

    // Prepare assembly kernel
    prepare();

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);

    // Schedule assembly kernel
    NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}

template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
                                 ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
    INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
    const CPUInfo &ci              = NEScheduler::get().cpu_info();
    unsigned int num_threads       = NEScheduler::get().num_threads();

    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);

    // Try to create an ACL function:
    acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint);
    if(acl_function == nullptr)
    {
        // Fall back to an arm_gemm function if ACL doesn't support this method
        auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
        fallback->configure(a, b, d, args, memory_group);
        arm_gemm = std::move(fallback);
    }
}

} // namespace

NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
    : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager))
{
}

Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_UNUSED(pretranspose_hint);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
#ifndef __aarch64__
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 || a->data_type() == DataType::S8 || a->data_type() == DataType::QASYMM8, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::U8, DataType::QASYMM8, DataType::S8, DataType::F16);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a->data_type() == DataType::QASYMM8 || a->data_type() == DataType::U8) && d->data_type() != DataType::U32, "Only U32 output supported for U8 / QASYMM8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
    return Status{};
}
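// Illustrative usage sketch ("memory_manager" is a placeholder for the caller's
// IMemoryManager): configure() silently leaves the object unconfigured when
// validate() fails, so callers should check is_configured() before running.
//
//   NEGEMMAssemblyDispatch dispatch(memory_manager);
//   dispatch.configure(a, b, d, alpha, beta, false /* pretranspose_hint */);
//   if(dispatch.is_configured())
//   {
//       dispatch.run();
//   }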

void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a);
    ARM_COMPUTE_ERROR_ON_NULLPTR(b);
    ARM_COMPUTE_ERROR_ON_NULLPTR(d);

    // If we don't support a combination of data types, silently return: it is the
    // caller's responsibility to check if configure() was successful via is_configured()
    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
    {
        return;
    }

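    // Dispatch on the input data type; the input/output type pairings below match
    // the constraints enforced in validate() above (e.g. U8/QASYMM8 accumulates
    // into U32 and S8 into S32).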
    switch(a->info()->data_type())
    {
        case DataType::F32:
            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
            break;
        case DataType::S8:
            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
            break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            break;
    }
}

void NEGEMMAssemblyDispatch::prepare()
{
    if(_function != nullptr)
    {
        _function->prepare();
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
        _arm_gemm->prepare();
    }
}

bool NEGEMMAssemblyDispatch::is_configured() const
{
    return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr;
}

void NEGEMMAssemblyDispatch::run()
{
    _memory_group.acquire();
    if(_function != nullptr)
    {
        _function->run();
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
        _arm_gemm->run();
    }
    _memory_group.release();
}
} // namespace arm_compute