blob: f4710fab8483d74c8f532f9ac24424d515142423 [file] [log] [blame]
Anthony Barbier71d9b572018-07-06 17:05:59 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
25
Anthony Barbierc8e84b52018-07-17 16:48:42 +010026#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
Anthony Barbier71d9b572018-07-06 17:05:59 +010027#include "arm_compute/runtime/NEON/NEScheduler.h"
Anthony Barbierc8e84b52018-07-17 16:48:42 +010028#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
Anthony Barbier71d9b572018-07-06 17:05:59 +010029
Anthony Barbierc8e84b52018-07-17 16:48:42 +010030namespace arm_compute
31{
/** Constructor: no backend is selected yet (selection happens in configure()).
 *
 * @param[in] memory_manager (Optional) Memory manager handed to the memory group used for workspace tensors; may be nullptr.
 */
template <typename TypeInput, typename TypeOutput>
NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
    : _function(nullptr), _arm_gemm(), _memory_group(std::move(memory_manager))
{
}
37
Anthony Barbierc8e84b52018-07-17 16:48:42 +010038template <>
39bool NEGEMMAssemblyDispatch<float, float>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
40{
41 ARM_COMPUTE_UNUSED(method);
42 ARM_COMPUTE_UNUSED(a);
43 ARM_COMPUTE_UNUSED(b);
44 ARM_COMPUTE_UNUSED(d);
45 ARM_COMPUTE_UNUSED(alpha);
46 ARM_COMPUTE_UNUSED(beta);
47 ARM_COMPUTE_UNUSED(pretranspose_hint);
48 switch(method)
49 {
50#ifdef __aarch64__
51 case arm_gemm::GemmMethod::GEMM_NATIVE:
52 {
53 auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
54 kernel->configure(a, b, d, alpha, beta);
55 auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
56 function->configure(std::move(kernel));
57 _function = std::move(function);
58 return true;
59 }
60#endif /* __aarch64__ */
61 default:
62 return false;
63 }
64}
65
66template <typename TypeInput, typename TypeOutput>
67bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
68{
69 ARM_COMPUTE_UNUSED(method);
70 ARM_COMPUTE_UNUSED(a);
71 ARM_COMPUTE_UNUSED(b);
72 ARM_COMPUTE_UNUSED(d);
73 ARM_COMPUTE_UNUSED(alpha);
74 ARM_COMPUTE_UNUSED(beta);
75 ARM_COMPUTE_UNUSED(pretranspose_hint);
76 return false;
77}
78
Anthony Barbier71d9b572018-07-06 17:05:59 +010079template <typename TypeInput, typename TypeOutput>
80void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
81{
Anthony Barbierc8e84b52018-07-17 16:48:42 +010082 INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
83 const CPUInfo &ci = NEScheduler::get().cpu_info();
84 unsigned int num_threads = NEScheduler::get().num_threads();
85
86 arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
87
88 //Try to create an ACL function:
89 if(!create_function(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint))
90 {
91 //Fallback onto arm_gemm function if ACL doesn't support this method.
92 _arm_gemm.configure(a, b, d, args, _memory_group);
93 }
Anthony Barbier71d9b572018-07-06 17:05:59 +010094}
95
96template <typename TypeInput, typename TypeOutput>
97void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::prepare()
98{
99 if(_function != nullptr)
100 {
101 _function->prepare();
102 }
103 else
104 {
105 _arm_gemm.prepare();
106 }
107}
108
109template <typename TypeInput, typename TypeOutput>
110bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::is_configured() const
111{
112 return _arm_gemm.is_configured() || _function != nullptr;
113}
114
115template <typename TypeInput, typename TypeOutput>
116void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::run()
117{
118 _memory_group.acquire();
119 if(_function != nullptr)
120 {
121 _function->run();
122 }
123 else
124 {
125 _arm_gemm.run();
126 }
127 _memory_group.release();
128}
129
130#ifndef __aarch64__
131template <>
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100132void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
Anthony Barbier71d9b572018-07-06 17:05:59 +0100133{
134 // arm_gemm::gemm for 8bit only exists for aarch64
135 ARM_COMPUTE_UNUSED(a);
136 ARM_COMPUTE_UNUSED(b);
137 ARM_COMPUTE_UNUSED(d);
138 ARM_COMPUTE_UNUSED(alpha);
139 ARM_COMPUTE_UNUSED(beta);
140 ARM_COMPUTE_UNUSED(pretranspose_hint);
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100141 ARM_COMPUTE_ERROR("Not supported for this architecture");
Anthony Barbier71d9b572018-07-06 17:05:59 +0100142}
143
144template <>
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100145void NEGEMMAssemblyDispatch<int8_t, int32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
Anthony Barbier71d9b572018-07-06 17:05:59 +0100146{
147 // arm_gemm::gemm for 8bit only exists for aarch64
148 ARM_COMPUTE_UNUSED(a);
149 ARM_COMPUTE_UNUSED(b);
150 ARM_COMPUTE_UNUSED(d);
151 ARM_COMPUTE_UNUSED(alpha);
152 ARM_COMPUTE_UNUSED(beta);
153 ARM_COMPUTE_UNUSED(pretranspose_hint);
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100154 ARM_COMPUTE_ERROR("Not supported for this architecture");
Anthony Barbier71d9b572018-07-06 17:05:59 +0100155}
Anthony Barbier67fd4e82018-07-13 09:09:46 +0100156
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100157template <>
158void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<uint32_t> &args, MemoryGroup &memory_group)
159{
160 // arm_gemm::gemm for 8bit only exists for aarch64
161 ARM_COMPUTE_UNUSED(a);
162 ARM_COMPUTE_UNUSED(b);
163 ARM_COMPUTE_UNUSED(d);
164 ARM_COMPUTE_UNUSED(args);
165 ARM_COMPUTE_UNUSED(memory_group);
166 ARM_COMPUTE_ERROR("Not supported for this architecture");
167}
168
169template <>
170void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<int32_t> &args, MemoryGroup &memory_group)
171{
172 // arm_gemm::gemm for 8bit only exists for aarch64
173 ARM_COMPUTE_UNUSED(a);
174 ARM_COMPUTE_UNUSED(b);
175 ARM_COMPUTE_UNUSED(d);
176 ARM_COMPUTE_UNUSED(args);
177 ARM_COMPUTE_UNUSED(memory_group);
178 ARM_COMPUTE_ERROR("Not supported for this architecture");
179}
Anthony Barbier71d9b572018-07-06 17:05:59 +0100180#endif // aarch64
/** Configure the arm_gemm based fallback.
 *
 * Asks arm_gemm for a kernel matching @p args; if none is available the
 * Fallback is left unconfigured (is_configured() stays false). Otherwise the
 * kernel is wrapped in an ACL kernel, workspace is allocated through
 * @p memory_group, and — if the kernel requires it — a pretranspose buffer
 * for B is allocated immediately (filled later in prepare()).
 */
template <typename TypeInput, typename TypeOutput>
void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
{
    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
    if(_gemm_kernel_asm == nullptr)
    {
        //configuration not supported: Leave function unconfigured:
        return;
    }

    // arm_compute wrapper for the Gemm object (see above)
    std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
    if(workspace_size > 0)
    {
        // Allocate workspace (4k alignment; presumably a page-size requirement of the assembly kernels — TODO confirm)
        const unsigned int alignment = 4096;
        //FIXME: is memory_group ever null ?
        allocate_workspace(workspace_size, &memory_group, alignment);
    }

    //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
    //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        // Clamp the thread count to the kernel's window size so no thread is
        // left without work (see deadlock note above).
        const int window_size = _gemm_kernel_asm->get_window_size();
        if(window_size < args._maxthreads)
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    // Keep raw (non-owning) views of the operand tensors for run()/prepare().
    _a = a;
    _b = b;
    _d = d;
    // Check for pre-transposed support
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment = 128;
        const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
        _pretranspose.allocator()->allocate();
        ARM_COMPUTE_ERROR_ON_NULLPTR(_pretranspose.buffer());
    }
}
229
230template <typename TypeInput, typename TypeOutput>
231void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::prepare()
232{
233 if(!_is_prepared)
234 {
235 // Pretranspose B if required
236 if(_gemm_kernel_asm->B_pretranspose_required())
237 {
238 const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
239 const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
240 const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
241
242 ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
243 _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
244 _b->mark_as_unused();
245 }
246
247 _is_prepared = true;
248 }
249}
250
251template <typename TypeInput, typename TypeOutput>
252void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment)
253{
254 ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
255 _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
256 if(memory_group != nullptr)
257 {
258 memory_group->manage(&_workspace);
259 }
260 _workspace.allocator()->allocate();
261}
262
263template <typename TypeInput, typename TypeOutput>
264bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::is_configured() const
265{
266 return _optimised_kernel != nullptr;
267}
268
/** Execute the arm_gemm kernel.
 *
 * Recomputes all leading dimensions / batch / multi strides from the tensors'
 * current strides (in elements, hence the division by the element size),
 * hands workspace and array pointers to the arm_gemm object, runs prepare()
 * (no-op after the first call), then schedules the wrapper kernel.
 */
template <typename TypeInput, typename TypeOutput>
void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::run()
{
    // Leading dimensions (row strides) in elements.
    const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);

    // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
    // the relevant multiple of the row stride.
    const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC;
    const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();

    const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
    const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);

    // "multi" strides step over the 4th tensor dimension (B has no batch dimension, so its z stride is the multi stride).
    const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
    const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);

    const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
    const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
    auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    if(_workspace.buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
        // Clamp the thread count to the kernel window size (mirrors the clamp done in configure()).
        const unsigned int window_size = _gemm_kernel_asm->get_window_size();
        unsigned int num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            num_threads = window_size;
            _gemm_kernel_asm->set_nthreads(num_threads);
        }
    }

    // Prepare assembly kernel (one-off: pretransposes B if required)
    prepare();

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);

    // Schedule assembly kernel
    NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
314
// Explicit instantiations for the input/output type pairs dispatched by the NEON GEMM functions.
template class NEGEMMAssemblyDispatch<float, float>;
template class NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
template class NEGEMMAssemblyDispatch<int8_t, int32_t>;
318} //namespace arm_compute