blob: 39fee1bfa5abcc7086ad6460e032790f32824736 [file] [log] [blame]
Anthony Barbier71d9b572018-07-06 17:05:59 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
25
Anthony Barbiereaefd002018-07-20 17:49:35 +010026#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier3d677cc2018-07-23 16:42:59 +010027#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
28#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
29#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
Anthony Barbierc8e84b52018-07-17 16:48:42 +010030#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
Anthony Barbier71d9b572018-07-06 17:05:59 +010031#include "arm_compute/runtime/NEON/NEScheduler.h"
Anthony Barbierc8e84b52018-07-17 16:48:42 +010032#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
Anthony Barbier3d677cc2018-07-23 16:42:59 +010033#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
Anthony Barbier71d9b572018-07-06 17:05:59 +010034
Anthony Barbiereaefd002018-07-20 17:49:35 +010035#include <arm_neon.h>
36
Anthony Barbierc8e84b52018-07-17 16:48:42 +010037namespace arm_compute
38{
Anthony Barbiereaefd002018-07-20 17:49:35 +010039namespace
Anthony Barbier71d9b572018-07-06 17:05:59 +010040{
Anthony Barbier3d677cc2018-07-23 16:42:59 +010041std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
42 std::shared_ptr<IMemoryManager> memory_manager)
43
Anthony Barbiereaefd002018-07-20 17:49:35 +010044{
Anthony Barbier3d677cc2018-07-23 16:42:59 +010045 //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
Anthony Barbierc8e84b52018-07-17 16:48:42 +010046 switch(method)
47 {
Anthony Barbier3d677cc2018-07-23 16:42:59 +010048 case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
49 {
50 if(!pretranspose_hint)
51 {
52 return nullptr;
53 }
54 auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
55 function->configure(a, b, d, alpha, beta, pretranspose_hint);
56 return std::move(function);
57 }
58 default:
59 return nullptr;
60 }
61}
62
/** Factory for a type-specialised ACL function.
 *
 * The unspecialised template supports no method at all and always returns
 * nullptr; the real mappings live in the per-type specialisations below
 * (aarch64 only). All parameters are therefore deliberately unused here.
 */
template <typename TypeInput, typename TypeOutput>
std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
                                           std::shared_ptr<IMemoryManager> memory_manager)
{
    ARM_COMPUTE_UNUSED(method);
    ARM_COMPUTE_UNUSED(a);
    ARM_COMPUTE_UNUSED(b);
    ARM_COMPUTE_UNUSED(d);
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_UNUSED(pretranspose_hint);
    ARM_COMPUTE_UNUSED(memory_manager);
    return nullptr;
}
77
Anthony Barbierc8e84b52018-07-17 16:48:42 +010078#ifdef __aarch64__
Anthony Barbier3d677cc2018-07-23 16:42:59 +010079template <>
80std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
81 std::shared_ptr<IMemoryManager> memory_manager)
82{
83 switch(method)
84 {
85 case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
86 {
87 if(!pretranspose_hint)
88 {
89 return nullptr;
90 }
91 auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
92 function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
93 return std::move(function);
94 }
95 default:
96 return nullptr;
97 }
98 return nullptr;
99}
100
101template <>
102std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
103 std::shared_ptr<IMemoryManager> memory_manager)
104{
105 switch(method)
106 {
107 case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
108 {
109 if(!pretranspose_hint)
110 {
111 return nullptr;
112 }
113 auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
114 function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
115 return std::move(function);
116 }
117 default:
118 return nullptr;
119 }
120 return nullptr;
121}
122
123template <>
124std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
125 std::shared_ptr<IMemoryManager> memory_manager)
126{
127 ARM_COMPUTE_UNUSED(pretranspose_hint);
128 ARM_COMPUTE_UNUSED(memory_manager);
129 switch(method)
130 {
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100131 case arm_gemm::GemmMethod::GEMM_NATIVE:
132 {
133 auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
134 kernel->configure(a, b, d, alpha, beta);
135 auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
136 function->configure(std::move(kernel));
Anthony Barbiereaefd002018-07-20 17:49:35 +0100137 return std::move(function);
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100138 }
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100139 default:
Anthony Barbiereaefd002018-07-20 17:49:35 +0100140 return nullptr;
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100141 }
142}
Anthony Barbier3d677cc2018-07-23 16:42:59 +0100143#endif /* __aarch64__ */
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100144
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
    /** Create and configure the arm_gemm kernel for the given tensors.
     *
     * Leaves the object unconfigured (is_configured() == false) if arm_gemm
     * cannot provide a kernel for @p args.
     *
     * @param[in]  a            Input A tensor.
     * @param[in]  b            Input B tensor.
     * @param[out] d            Output tensor.
     * @param[in]  args         arm_gemm arguments describing the GEMM.
     * @param[in]  memory_group Memory group to register the workspace with.
     */
    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
    // Run the assembly kernel (prepares it first if needed).
    void run() override;
    // One-off work: pretranspose B if the kernel requires it.
    void prepare() override;
    // True iff configure() succeeded in creating a kernel.
    bool is_configured() const override;

private:
    /** Allocate a workspace tensor.
     *
     * @param[in] workspace_size Size to allocate.
     * @param[in] memory_group   Tensor memory group.
     * @param[in] alignment      Workspace memory alignment.
     */
    void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);

    /** Assembly Gemm kernel */
    std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
    /** Optimised NEON kernel */
    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
    /** Input A */
    const ITensor *_a
    {
        nullptr
    };
    /** Input B */
    const ITensor *_b
    {
        nullptr
    };
    /** Output */
    ITensor *_d{ nullptr };
    /** GEMM workspace */
    Tensor _workspace{};
    /** Pre-transpose tensor */
    Tensor _pretranspose{};
    /** Prepared flag */
    bool _is_prepared{ false };
};
Anthony Barbier71d9b572018-07-06 17:05:59 +0100187
template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
{
    // Ask arm_gemm for a kernel implementation matching the requested arguments.
    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
    if(_gemm_kernel_asm == nullptr)
    {
        //configuration not supported: Leave function unconfigured:
        return;
    }

    // arm_compute wrapper for the Gemm object (see above)
    std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
    if(workspace_size > 0)
    {
        // Allocate workspace (4096-byte aligned, managed by the memory group)
        const unsigned int alignment = 4096;
        allocate_workspace(workspace_size, memory_group, alignment);
    }

    //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
    //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        // Cap the thread count to the kernel's available parallelism.
        const int window_size = _gemm_kernel_asm->get_window_size();
        if(window_size < args._maxthreads)
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _a                = a;
    _b                = b;
    _d                = d;
    // Check for pre-transposed support: if required, allocate the buffer now;
    // the actual pretranspose of B happens lazily in prepare().
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment           = 128;
        const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
        _pretranspose.allocator()->allocate();
        ARM_COMPUTE_ERROR_ON_NULLPTR(_pretranspose.buffer());
    }
}
235
236template <typename TypeInput, typename TypeOutput>
Anthony Barbiereaefd002018-07-20 17:49:35 +0100237void Fallback<TypeInput, TypeOutput>::prepare()
Anthony Barbier71d9b572018-07-06 17:05:59 +0100238{
239 if(!_is_prepared)
240 {
241 // Pretranspose B if required
242 if(_gemm_kernel_asm->B_pretranspose_required())
243 {
Anthony Barbier3d677cc2018-07-23 16:42:59 +0100244 ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
Anthony Barbier71d9b572018-07-06 17:05:59 +0100245 const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
Georgios Pinitaseb84d6b2018-07-27 18:28:10 +0100246 const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
Anthony Barbier71d9b572018-07-06 17:05:59 +0100247 const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
248
Anthony Barbier71d9b572018-07-06 17:05:59 +0100249 _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
250 _b->mark_as_unused();
251 }
252
253 _is_prepared = true;
254 }
255}
256
257template <typename TypeInput, typename TypeOutput>
Anthony Barbier20394d52018-08-02 11:29:09 +0100258void Fallback<TypeInput, TypeOutput>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
Anthony Barbier71d9b572018-07-06 17:05:59 +0100259{
260 ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
261 _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
Anthony Barbier20394d52018-08-02 11:29:09 +0100262 memory_group.manage(&_workspace);
Anthony Barbier71d9b572018-07-06 17:05:59 +0100263 _workspace.allocator()->allocate();
264}
265
266template <typename TypeInput, typename TypeOutput>
Anthony Barbiereaefd002018-07-20 17:49:35 +0100267bool Fallback<TypeInput, TypeOutput>::is_configured() const
Anthony Barbier71d9b572018-07-06 17:05:59 +0100268{
269 return _optimised_kernel != nullptr;
270}
271
template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::run()
{
    // Leading dimensions (row strides) in elements, derived from byte strides.
    const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    int       ldb = 0;
    const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);

    // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
    // the relevant multiple of the row stride.
    const bool is_nhwc           = _a->info()->data_layout() == DataLayout::NHWC;
    const int  stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();

    const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
    const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);

    const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    int       multi_stride_b = 0;
    const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);

    const auto       in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
    const TypeInput *in1_ptr = nullptr;
    auto             out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());

    // Check if B is pre-tranposed and de-reference if not: when B is
    // pretransposed the kernel reads from _pretranspose instead and the raw
    // B pointer/strides must stay null/zero.
    if(!_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
        multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
        in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
    }

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    if(_workspace.buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
        const unsigned int window_size = _gemm_kernel_asm->get_window_size();
        unsigned int       num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            // Don't spawn more threads than the kernel can use.
            num_threads = window_size;
            _gemm_kernel_asm->set_nthreads(num_threads);
        }
    }

    // Prepare assembly kernel (pretransposes B on first run if required) —
    // must happen before set_arrays/schedule below.
    prepare();

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);

    // Schedule assembly kernel
    NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
325
Anthony Barbiereaefd002018-07-20 17:49:35 +0100326template <typename TypeInput, typename TypeOutput>
327void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
Anthony Barbier3d677cc2018-07-23 16:42:59 +0100328 ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
Anthony Barbiereaefd002018-07-20 17:49:35 +0100329{
330 INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
331 const CPUInfo &ci = NEScheduler::get().cpu_info();
332 unsigned int num_threads = NEScheduler::get().num_threads();
333
334 arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
335
336 //Try to create an ACL function:
Anthony Barbier3d677cc2018-07-23 16:42:59 +0100337 acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
338 // If the type agnostic factory failed to create an ACL function, try the specialised one:
339 if(acl_function == nullptr)
340 {
341 acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
342 }
343 //If we still don't have an ACL function:
Anthony Barbiereaefd002018-07-20 17:49:35 +0100344 if(acl_function == nullptr)
345 {
346 //Fallback onto arm_gemm function if ACL doesn't support this method.
347 auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
348 fallback->configure(a, b, d, args, memory_group);
349 arm_gemm = std::move(fallback);
350 }
351}
352
353} //namespace
354
// Keep a copy of the memory manager: it is forwarded to whichever ACL function
// gets created in configure(), in addition to backing our own memory group.
NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
{
}
359
// Static validation of the input/output type combination. alpha/beta and the
// pretranspose hint are not checked here — unsupported values simply cause
// configure() to leave the dispatcher unconfigured.
Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_UNUSED(pretranspose_hint);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
#ifndef __aarch64__
    // 8-bit assembly kernels only exist for aarch64 targets.
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 || a->data_type() == DataType::S8 || a->data_type() == DataType::QASYMM8, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::U8, DataType::QASYMM8, DataType::S8, DataType::F16);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    // Each input type admits a fixed set of output types:
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::U32, "Only U32/S32 output supported for QASYMM8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
    return Status{};
}
379
// Dispatch on the input data type and create either an ACL function or an
// arm_gemm fallback. On unsupported configurations the object is simply left
// unconfigured (see is_configured()).
void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a);
    ARM_COMPUTE_ERROR_ON_NULLPTR(b);
    ARM_COMPUTE_ERROR_ON_NULLPTR(d);

    //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
    {
        return;
    }

    switch(a->info()->data_type())
    {
        case DataType::F32:
            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
#ifdef __aarch64__
        // U8 and QASYMM8 share the same unsigned 8-bit kernels.
        case DataType::U8:
        case DataType::QASYMM8:
            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
        case DataType::S8:
            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            break;
    }
}
415
416void NEGEMMAssemblyDispatch::prepare()
417{
418 if(_function != nullptr)
419 {
420 _function->prepare();
421 }
422 else
423 {
424 ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
425 _arm_gemm->prepare();
426 }
427}
428
429bool NEGEMMAssemblyDispatch::is_configured() const
430{
431 return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr;
432}
433
434void NEGEMMAssemblyDispatch::run()
435{
436 _memory_group.acquire();
437 if(_function != nullptr)
438 {
439 _function->run();
440 }
441 else
442 {
443 ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
444 _arm_gemm->run();
445 }
446 _memory_group.release();
447}
Anthony Barbier71d9b572018-07-06 17:05:59 +0100448} //namespace arm_compute