/*
 * Copyright (c) 2018-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Anthony Barbierc8e84b52018-07-17 16:48:42 +010027#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
Anthony Barbier71d9b572018-07-06 17:05:59 +010028#include "arm_compute/runtime/NEON/NEScheduler.h"
Anthony Barbierc8e84b52018-07-17 16:48:42 +010029#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
Anthony Barbier3d677cc2018-07-23 16:42:59 +010030#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
Anthony Barbier71d9b572018-07-06 17:05:59 +010031
Anthony Barbiereaefd002018-07-20 17:49:35 +010032#include <arm_neon.h>
33
Anthony Barbierc8e84b52018-07-17 16:48:42 +010034namespace arm_compute
35{
Anthony Barbiereaefd002018-07-20 17:49:35 +010036namespace
Anthony Barbier71d9b572018-07-06 17:05:59 +010037{
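/** Create an ACL function wrapping the method selected by arm_gemm, if one exists.
 *
 * Returns nullptr when no ACL function is available for the selected method;
 * the caller then falls back to running the arm_gemm kernel directly.
 */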
std::unique_ptr<IFunction> create_function_all_types(arm_gemm::KernelDescription gemm_kernel_info,
                                                     const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
                                                     std::shared_ptr<IMemoryManager> memory_manager)
{
    // Note: It's safe not to check for FP16 support here because it was already checked in NEGEMMAssemblyDispatch::configure()
    switch(gemm_kernel_info.method)
    {
        case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
        {
            if(!pretranspose_hint)
            {
                return nullptr;
            }
            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
            function->configure(a, b, d, alpha, beta, pretranspose_hint);
            return std::move(function);
        }
#if defined(__aarch64__)
        case arm_gemm::GemmMethod::GEMM_NATIVE:
        {
            if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
            {
                auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
                kernel->configure(a, b, d, alpha, beta);
                auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
                function->configure(std::move(kernel));
                return std::move(function);
            }
            return nullptr;
        }
#endif // defined(__aarch64__)
        default:
            return nullptr;
    }
}

/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
    /** Initialise the function's input and output.
     *
     * @param[in]  a            Input tensor containing the Matrix A.
     * @param[in]  b            Input tensor containing the Matrix B.
     * @param[out] d            Output tensor to store the result of matrix multiplication.
     * @param[in]  args         Matrix multiplication information.
     * @param[in]  memory_group Memory group to be used by the function.
     */
    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);

    // Inherited methods overridden:
    void run() override;
    void prepare() override;
    bool is_configured() const override;

private:
    /** Allocate a workspace tensor.
     *
     * @param[in] workspace_size Size to allocate.
     * @param[in] memory_group   Tensor memory group.
     * @param[in] alignment      Workspace memory alignment.
     */
    void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);

    /** Assembly Gemm kernel */
    std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
    /** Optimised NEON kernel */
    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
    /** Input A */
    const ITensor *_a{ nullptr };
    /** Input B */
    const ITensor *_b{ nullptr };
    /** Output */
    ITensor *_d{ nullptr };
    /** GEMM workspace */
    Tensor _workspace{};
    /** Pre-transpose tensor */
    Tensor _pretranspose{};
    /** Prepared flag */
    bool _is_prepared{ false };
};

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
{
    arm_gemm::GemmConfig              gemm_cfg;
    const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
    if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
    {
        gemm_cfg.filter = gemm_kernel_info.name;
        args._cfg       = &gemm_cfg;
    }
    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args);
    if(_gemm_kernel_asm == nullptr)
    {
        // Configuration not supported: leave the function unconfigured
        return;
    }

    // arm_compute wrapper for the Gemm object (see above)
    std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
    if(workspace_size > 0)
    {
        // Allocate workspace
        const unsigned int alignment = 4096;
        allocate_workspace(workspace_size, memory_group, alignment);
    }

    // If the block below is removed, ConvLayer deadlocks when threads > 1 and
    // the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        const int window_size = _gemm_kernel_asm->get_window_size();
        if(window_size < args._maxthreads)
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _a                = a;
    _b                = b;
    _d                = d;
    // Check for pre-transposed support
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment           = 128;
        const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
    }
}

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::prepare()
{
    if(!_is_prepared)
    {
        // Pretranspose B if required
        if(_gemm_kernel_asm->B_pretranspose_required())
        {
            _pretranspose.allocator()->allocate();
            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
            const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
            const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
            const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);

            _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
            _b->mark_as_unused();
        }

        _is_prepared = true;
    }
}

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
{
    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
    _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
    memory_group.manage(&_workspace);
    _workspace.allocator()->allocate();
}

template <typename TypeInput, typename TypeOutput>
bool Fallback<TypeInput, TypeOutput>::is_configured() const
{
    return _optimised_kernel != nullptr;
}

template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::run()
{
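    // Note: arm_gemm expects leading dimensions and batch/multi strides in elements,
    // while ITensorInfo reports strides in bytes, hence the sizeof() divisions below.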
221 const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
Georgios Pinitas40ed6d82018-07-31 17:22:11 +0100222 int ldb = 0;
Anthony Barbier71d9b572018-07-06 17:05:59 +0100223 const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
224
225 // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
226 // the relevant multiple of the row stride.
227 const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC;
228 const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
229
230 const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
231 const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
232
233 const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
Georgios Pinitas40ed6d82018-07-31 17:22:11 +0100234 int multi_stride_b = 0;
Anthony Barbier71d9b572018-07-06 17:05:59 +0100235 const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
236
Georgios Pinitas40ed6d82018-07-31 17:22:11 +0100237 const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
238 const TypeInput *in1_ptr = nullptr;
239 auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
240
    // Check if B is pre-transposed and de-reference it if not
    if(!_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
        multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
        in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
    }

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    if(_workspace.buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
        const unsigned int window_size = _gemm_kernel_asm->get_window_size();
        unsigned int       num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            num_threads = window_size;
            _gemm_kernel_asm->set_nthreads(num_threads);
        }
    }

    // Prepare assembly kernel
    prepare();

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);

    // Schedule assembly kernel
    NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}

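/** Try to create an ACL function for the given tensors; if none is available,
 * configure an arm_gemm-based Fallback instead. Exactly one of @p acl_function
 * and @p arm_gemm is set on return.
 */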
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
                                 ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
{
    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d);
    const CPUInfo               &ci          = NEScheduler::get().cpu_info();
    unsigned int                 num_threads = NEScheduler::get().num_threads();

    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);

    // Try to create an ACL function:
    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager));

    // If we still don't have an ACL function:
    if(acl_function == nullptr)
    {
        // Fall back to an arm_gemm function since ACL doesn't support this method.
        auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
        fallback->configure(a, b, d, args, memory_group);
        arm_gemm = std::move(fallback);
    }
}

} // namespace

NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
{
}

Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_UNUSED(pretranspose_hint);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
#ifndef __aarch64__
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 || a->data_type() == DataType::S8 || a->data_type() == DataType::QASYMM8, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::U8, DataType::QASYMM8, DataType::S8, DataType::F16);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::U32, "Only U32/S32 output supported for QASYMM8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
    return Status{};
}

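// Quick reference of the input -> output type pairs accepted by validate() above:
//   F32 -> F32, F16 -> F16, U8 -> U32, QASYMM8 -> U32 or S32, S8 -> S32.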
void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a);
    ARM_COMPUTE_ERROR_ON_NULLPTR(b);
    ARM_COMPUTE_ERROR_ON_NULLPTR(d);

    // If the combination of data types is not supported, silently return: it is the caller's responsibility to check whether configure() succeeded via is_configured()
    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
    {
        return;
    }

    switch(a->info()->data_type())
    {
        case DataType::F32:
            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
        case DataType::S8:
            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            break;
    }
}

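// A minimal usage sketch (hypothetical caller; assumes a, b and d are tensors
// that have already been initialised and allocated, and memory_manager may be
// nullptr). Since configure() fails silently on unsupported configurations,
// is_configured() must be checked before calling run():
//
//   NEGEMMAssemblyDispatch gemm(memory_manager);
//   gemm.configure(&a, &b, &d, 1.0f, 0.0f, /* pretranspose_hint = */ true);
//   if(gemm.is_configured())
//   {
//       gemm.prepare(); // optional here: the fallback's run() also triggers preparation
//       gemm.run();
//   }
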
void NEGEMMAssemblyDispatch::prepare()
{
    if(_function != nullptr)
    {
        _function->prepare();
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
        _arm_gemm->prepare();
    }
}

bool NEGEMMAssemblyDispatch::is_configured() const
{
    return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr;
}

void NEGEMMAssemblyDispatch::run()
{
    _memory_group.acquire();
    if(_function != nullptr)
    {
        _function->run();
    }
    else
    {
        ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
        _arm_gemm->run();
    }
    _memory_group.release();
}
} // namespace arm_compute