/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"

#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/kernels/assembly/Helpers.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"

#include <atomic>
#include <condition_variable>
#include <mutex>

namespace arm_compute
{
#ifndef NO_MULTI_THREADING
class BufferManagerMultipleThreads final : public IBufferManager
{
public:
    /** Number of buffers to ping-pong between */
    static constexpr unsigned int NUM_BUFFERS = 3;

    explicit BufferManagerMultipleThreads(unsigned int max_num_users)
        : _max_num_users(max_num_users)
    {
    }
    unsigned int num_buffers() const override
    {
        return NUM_BUFFERS;
    }
    /* - Lock the requested index if it's free and return true if it needs reshaping.
     * - Return false without acquiring the lock if the buffer at the index is already reshaped / being reshaped.
     * - Block if the corresponding buffer for the given index is still being used by a different index.
     */
    bool lock_to_reshape_if_needed(unsigned int index) override
    {
        Buffer &buf = get_buffer_from_index(index);
        while(true)
        {
            if(buf.index == index && buf.state != State::FREE)
            {
                // Another thread is already reshaping / has reshaped this block: nothing to do
                return false;
            }
            else
            {
                std::unique_lock<std::mutex> lock(buf.mutex);
                // If the buffer is free then lock it for reshaping:
                if(buf.state == State::FREE)
                {
                    buf.index = index;
                    buf.state = State::BEING_RESHAPED;
                    return true;
                }
                // Check again just in case it changed while we were acquiring the lock:
                if(buf.index == index)
                {
                    // Another thread is already reshaping this block, nothing to do
                    return false;
                }
                // buf.index != index: the buffer is still being used by another block, need to wait
                buf.sem.wait(lock);
            }
        }
    }
    /* Mark the buffer at the given index as reshaped and release the lock acquired via lock_to_reshape_if_needed() */
    void mark_as_reshaped(unsigned int index) override
    {
        Buffer &buf = get_buffer_from_index(index);
        {
            std::lock_guard<std::mutex> lock(buf.mutex);
            buf.users = _max_num_users;
            buf.state = State::IN_USE;
        }
        buf.sem.notify_all();
    }

    /* Block until the buffer at the given index is reshaped */
    void wait_for_reshaping(unsigned int index) override
    {
        Buffer &buf = get_buffer_from_index(index);
        ARM_COMPUTE_ERROR_ON(buf.index != index); // Should have blocked in lock_to_reshape_if_needed()
        // Check if it's already ready to use:
        if(buf.state == State::IN_USE)
        {
            return;
        }
        std::unique_lock<std::mutex> lock(buf.mutex);
        // Double-check it didn't change while we were acquiring the lock:
        if(buf.state == State::IN_USE)
        {
            return;
        }
        buf.sem.wait(lock);
    }
    /* Mark the buffer at the given index as no longer used by this thread.
     * Once all the threads have called this method the buffer is marked as free again.
     */
    void mark_as_unused(unsigned int index) override
    {
        Buffer &buf = get_buffer_from_index(index);
        ARM_COMPUTE_ERROR_ON(buf.index != index); // Should have blocked in lock_to_reshape_if_needed()
        if(--buf.users == 0)
        {
            std::unique_lock<std::mutex> lock(buf.mutex);
            buf.state = State::FREE;
            lock.unlock();
            buf.sem.notify_all();
        }
    }

private:
    enum class State
    {
        FREE,
        BEING_RESHAPED,
        IN_USE
    };
    struct Buffer
    {
        unsigned int            index{};
        std::atomic_uint        users{};
        State                   state{ State::FREE };
        std::mutex              mutex{};
        std::condition_variable sem{};
    } _buffers[NUM_BUFFERS];
    Buffer &get_buffer_from_index(unsigned int index)
    {
        return _buffers[index % NUM_BUFFERS];
    }
    unsigned int _max_num_users;
};
#endif /* NO_MULTI_THREADING */
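
// A sketch of the protocol the workloads in prepare() follow when sharing
// reshaped B blocks through an IBufferManager (illustrative pseudo-usage of the
// interface above, not part of the public API):
//
//   if(buffer_manager->lock_to_reshape_if_needed(index))
//   {
//       // We won the lock: reshape the block ourselves
//       prepare_b->transform(b_workloads[index], info);
//       buffer_manager->mark_as_reshaped(index);
//   }
//   buffer_manager->wait_for_reshaping(index); // Block until the buffer is ready
//   // ... consume the reshaped buffer (e.g. matrix multiply) ...
//   buffer_manager->mark_as_unused(index);     // Last user frees the slot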

class BufferManagerSingleThread : public IBufferManager
{
public:
    unsigned int num_buffers() const override
    {
        return 1;
    }
    bool lock_to_reshape_if_needed(unsigned int index) override
    {
        return true;
    }
    void mark_as_reshaped(unsigned int index) override
    {
    }
    void wait_for_reshaping(unsigned int index) override
    {
    }
    void mark_as_unused(unsigned int index) override
    {
    }
};

NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager)
    : _memory_group(std::move(memory_manager))
{
}

void NEGEMMInterleavedWrapper::run()
{
    prepare();

    _memory_group.acquire();
    NEScheduler::get().run_tagged_workloads(_workloads, _tag.c_str());
    _memory_group.release();
}

void NEGEMMInterleavedWrapper::prepare()
{
    if(!_is_prepared)
    {
        if(_pretranspose_b)
        {
            _transformed_b.allocator()->allocate();
            NEScheduler::get().schedule(_prepare_b.get(), Window::DimX);
            _b->mark_as_unused();
        }
        else
        {
            _prepare_b->create_workloads(_b_workloads);
        }
        _transform_a->create_workloads(_a_workloads);
        _matrix_multiply->create_workloads(_mm_workloads);

        // Maximum number of workloads to create: one per thread
        const unsigned int num_threads    = NEScheduler::get().num_threads();
        const unsigned int max_iterations = num_threads;
        // Maximum number of iterations the parameters allow:
        const unsigned int num_iterations = _batch_window.num_iterations_total();
        // Keep the smaller of the two:
        const unsigned int num_windows  = std::min(num_iterations, max_iterations);
        const TensorShape  window_shape = _batch_window.shape();
        const unsigned int num_x_blocks = _block_walker.num_iterations(Window::DimX);

        // Create a 1D window to dynamically split the batch window:
        Window win_1D;
        win_1D.set(0, Window::Dimension(0, num_iterations));

        // Create one workload for each sub-window:
        for(unsigned int w = 0; w < num_windows; w++)
        {
            Window            win          = win_1D.split_window(0, w, num_windows);
            const Coordinates start_offset = index2coords(window_shape, win.x().start());
            const Coordinates end_offset   = index2coords(window_shape, win.x().end() - 1);

            if(_pretranspose_b)
            {
                auto workload = [start_offset, end_offset, num_x_blocks, this](const ThreadInfo & info)
                {
                    // For each block of rows in "M":
                    auto workload_mm = this->_mm_workloads.begin();
                    for(auto workload_a = this->_a_workloads.begin(); workload_a != this->_a_workloads.end(); workload_a++)
                    {
                        // Transform one k_block from A:
                        this->_transform_a->transform(*workload_a, info, this->_batch_window, start_offset, end_offset);
                        // Then perform the matrix multiplication for each x block along N:
                        for(unsigned int i = 0; i < num_x_blocks; i++)
                        {
                            ARM_COMPUTE_ERROR_ON(workload_mm == this->_mm_workloads.end());
                            this->_matrix_multiply->transform(*workload_mm++, info, this->_batch_window, start_offset, end_offset);
                        }
                    }
                };
                _workloads.push_back(workload);
            }
            else
            {
                auto workload = [num_threads, start_offset, end_offset, num_x_blocks, this](const ThreadInfo & info)
                {
                    // For each block of rows in "M":
                    auto         workload_mm = this->_mm_workloads.begin();
                    unsigned int workload_b  = 0;
                    // If there is only one thread then only reshape the B blocks as they are needed:
                    unsigned int workload_b_next = num_threads == 1 ? this->_b_workloads.size() : 1;

                    for(auto workload_a = this->_a_workloads.begin(); workload_a != this->_a_workloads.end(); workload_a++)
                    {
                        // Transform one k_block from A:
                        this->_transform_a->transform(*workload_a, info, this->_batch_window, start_offset, end_offset);
                        // Then perform the matrix multiplication for each x block along N:
                        for(unsigned int i = 0; i < num_x_blocks; i++)
                        {
                            ARM_COMPUTE_ERROR_ON(workload_mm == this->_mm_workloads.end());
                            if(workload_b_next < this->_b_workloads.size())
                            {
                                // Lock on BufferManager: do we need to reshape the next B block ourselves?
                                if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b_next))
                                {
                                    this->_prepare_b->transform(this->_b_workloads[workload_b_next], info);
                                    this->_buffer_manager->mark_as_reshaped(workload_b_next);
                                }
                                workload_b_next++;
                            }
                            ARM_COMPUTE_ERROR_ON(workload_b >= this->_b_workloads.size());
                            // Run the reshape if needed, otherwise wait for it:
                            if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b))
                            {
                                this->_prepare_b->transform(this->_b_workloads[workload_b], info);
                                this->_buffer_manager->mark_as_reshaped(workload_b);
                            }
                            this->_buffer_manager->wait_for_reshaping(workload_b);
                            this->_matrix_multiply->transform(*workload_mm++, info, this->_batch_window, start_offset, end_offset);
                            this->_buffer_manager->mark_as_unused(workload_b);
                            workload_b++;
                        }
                    }
                };
                _workloads.push_back(workload);
            }
        }
        if(!_pretranspose_b && num_windows > 1 && num_windows % num_threads != 0)
        {
            // Make sure the number of workloads is a multiple of the number of threads to avoid deadlocks:
            for(unsigned int leftover = num_windows % num_threads; leftover != num_threads; leftover++)
            {
                auto workload = [this](const ThreadInfo & info)
                {
                    unsigned int workload_b = 0;
                    // Padding workload: it only takes part in the reshaping of B to keep the buffer manager's user counts balanced:
                    unsigned int workload_b_next = 1;

                    for(unsigned int iteration = 0; iteration < this->_mm_workloads.size(); iteration++)
                    {
                        if(workload_b_next < this->_b_workloads.size())
                        {
                            // Lock on BufferManager: do we need to reshape the next B block ourselves?
                            if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b_next))
                            {
                                this->_prepare_b->transform(this->_b_workloads[workload_b_next], info);
                                this->_buffer_manager->mark_as_reshaped(workload_b_next);
                            }
                            workload_b_next++;
                        }
                        ARM_COMPUTE_ERROR_ON(workload_b >= this->_b_workloads.size());
                        // Run the reshape if needed, otherwise just wait for it:
                        if(this->_buffer_manager->lock_to_reshape_if_needed(workload_b))
                        {
                            this->_prepare_b->transform(this->_b_workloads[workload_b], info);
                            this->_buffer_manager->mark_as_reshaped(workload_b);
                        }
                        this->_buffer_manager->wait_for_reshaping(workload_b);
                        this->_buffer_manager->mark_as_unused(workload_b);
                        workload_b++;
                    }
                };
                _workloads.push_back(workload);
            }
        }

        _is_prepared = true;
    }
}

namespace
{
// Factory to instantiate NEGEMMInterleavedPrepareBWrapperKernel:
template <typename InputType, bool use_dot = false>
std::unique_ptr<NEGEMMInterleavedPrepareBWrapperKernel> instantiate_prepareB(const ITensor *b, ITensor *transformed_b, const INEGEMMWrapperKernel::Params &params)
{
    auto prepare_b = support::cpp14::make_unique<NEGEMMInterleavedPrepareBWrapperKernelTemplate<InputType, use_dot>>();
    prepare_b->configure(b, transformed_b, false, NEScheduler::get().cpu_info(), params);
    return std::move(prepare_b);
}

// Factory to instantiate NEGEMMInterleavedTransformAWrapperTemplate:
template <typename InputType, bool use_dot = false>
std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
{
    auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<InputType, use_dot>>();
    transform_a->configure(a, transformed_a, false, block_walker, params);
    return std::move(transform_a);
}

// Factory to instantiate NEGEMMInterleavedMatrixMultiplyWrapperTemplate:
template <typename InputType, typename OutputType, bool use_dot = false>
std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker,
                                                                                    const BlockSizes &block_sizes, const INEGEMMWrapperKernel::Params &params, bool pretranspose_b, float alpha, float beta)
{
    auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<InputType, OutputType, use_dot>>();
    matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, NEScheduler::get().num_threads());
    return std::move(matrix_multiply);
}
} // namespace

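// A minimal usage sketch for this function (a hedged example: it assumes a, b
// and c are already allocated tensors of compatible shapes and a supported data
// type; the argument values are illustrative only):
//
//   NEGEMMInterleavedWrapper gemm(nullptr); // no memory manager
//   gemm.configure(&a, &b, &c, 1.f, 0.f, /* pretranspose_b */ true, /* use_dot */ false);
//   gemm.run();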
void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b, bool use_dot)
{
    _params         = INEGEMMWrapperKernel::extract_parameters(a, b, c);
    _a              = a;
    _b              = b;
    _c              = c;
    _pretranspose_b = pretranspose_b;

    DataType input_type = a->info()->data_type();

    // Forcing 128-byte alignment (required by 32-bit kernels)
    const unsigned int alignment = 128;
    _transformed_b.allocator()->init(TensorInfo{}, alignment);
    _tmp_c.allocator()->init(TensorInfo{}, alignment);
    _tag = "NEGEMMInterleaved_";
    _tag += get_strategy_name(input_type, use_dot);

    if(!_pretranspose_b)
    {
        _block_sizes = calculate_block_sizes_from_data_type(NEScheduler::get().cpu_info(), _params.M, _params.N, _params.K, input_type, use_dot);
        _batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
        _batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
        // If the execution is single-threaded or there is only one window then the buffer manager only needs one buffer;
        // otherwise use NUM_BUFFERS buffers and ping-pong between them:
        const unsigned int num_iterations = _batch_window.num_iterations_total();
        if(NEScheduler::get().num_threads() == 1 || num_iterations == 1)
        {
            _buffer_manager = support::cpp14::make_unique<BufferManagerSingleThread>();
        }
        else
        {
#ifdef NO_MULTI_THREADING
            ARM_COMPUTE_ERROR("Can't have more than 1 buffer without multiple threads");
#else  /* NO_MULTI_THREADING */
            _buffer_manager = support::cpp14::make_unique<BufferManagerMultipleThreads>(NEScheduler::get().num_threads());
#endif /* NO_MULTI_THREADING */
        }
        // If B is transposed at every iteration then transformed_B can be managed:
        _memory_group.manage(&_transformed_b);
        auto_init_if_empty(*_transformed_b.info(), _b->info()->clone()->set_tensor_shape(TensorShape(_block_sizes.x_block * _block_sizes.k_block, _buffer_manager->num_buffers())));
    }
    else
    {
        _tag += "_preB";
    }
    switch(input_type)
    {
        case DataType::F32:
            _prepare_b = instantiate_prepareB<float>(_b, &_transformed_b, _params);
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            if(use_dot)
            {
                _prepare_b = instantiate_prepareB<uint8_t, true>(_b, &_transformed_b, _params);
            }
            else
            {
                _prepare_b = instantiate_prepareB<uint8_t, false>(_b, &_transformed_b, _params);
            }
            break;
        case DataType::S8:
            if(use_dot)
            {
                _prepare_b = instantiate_prepareB<int8_t, true>(_b, &_transformed_b, _params);
            }
            else
            {
                _prepare_b = instantiate_prepareB<int8_t, false>(_b, &_transformed_b, _params);
            }
            break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _prepare_b = instantiate_prepareB<__fp16>(_b, &_transformed_b, _params);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            ARM_COMPUTE_ERROR("DataType not supported");
            break;
    }
    ARM_COMPUTE_ERROR_ON(_prepare_b == nullptr);

    if(_pretranspose_b)
    {
        _block_sizes = _prepare_b->block_sizes();
        _batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
        _batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
    }

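    // The block walker drives the outer blocked-GEMM loops: DimX steps through N
    // in x_block chunks, DimY steps through K in k_block chunks, and DimZ iterates
    // over the independent multis.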
    _block_walker.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_params.N, _block_sizes.x_block), _block_sizes.x_block));
    _block_walker.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_params.K, _block_sizes.k_block), _block_sizes.k_block));
    _block_walker.set(Window::DimZ, Window::Dimension(0, _params.multis));

    _transformed_a.allocator()->init(TensorInfo(TensorShape{ _block_sizes.k_block, _block_sizes.m_round, _params.batches }, 1, input_type), alignment);
    _memory_group.manage(&_transformed_a);
    _memory_group.manage(&_tmp_c);

    switch(input_type)
    {
        case DataType::F32:
            _transform_a     = instantiate_transformA<float>(_a, &_transformed_a, _block_walker, _params);
            _matrix_multiply = instantiate_matrix_multiply<float, float>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            if(use_dot)
            {
                _transform_a     = instantiate_transformA<uint8_t, true>(_a, &_transformed_a, _block_walker, _params);
                _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
            }
            else
            {
                _transform_a     = instantiate_transformA<uint8_t, false>(_a, &_transformed_a, _block_walker, _params);
                _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
            }
            break;
        case DataType::S8:
            if(use_dot)
            {
                _transform_a     = instantiate_transformA<int8_t, true>(_a, &_transformed_a, _block_walker, _params);
                _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
            }
            else
            {
                _transform_a     = instantiate_transformA<int8_t, false>(_a, &_transformed_a, _block_walker, _params);
                _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
            }
            break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _transform_a     = instantiate_transformA<__fp16>(_a, &_transformed_a, _block_walker, _params);
            _matrix_multiply = instantiate_matrix_multiply<__fp16, __fp16>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            ARM_COMPUTE_ERROR("DataType not supported");
            break;
    }
    ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
    ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);
    _transformed_a.allocator()->allocate();
    _tmp_c.allocator()->allocate();
    if(!_pretranspose_b)
    {
        _transformed_b.allocator()->allocate();
    }
}
} // namespace arm_compute