/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef __ARM_ASSEMBLY_HELPER_H__
#define __ARM_ASSEMBLY_HELPER_H__

#include "arm_compute/core/ITensor.h"
#include "support/ToolchainSupport.h"

#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/Log.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapper.h"
#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"

namespace arm_compute
{
/** Assembly kernel glue */
template <typename TypeInput, typename TypeOutput>
class AssemblyKernelGlue final
{
public:
    /** Operator type */
    using TypeOperator = TypeInput;
    /** Result type */
    using TypeResult = TypeOutput;
    /** Default constructor. */
    AssemblyKernelGlue()
        : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr)
    {
    }
    /** Assembly Gemm */
    using AssemblyGemm = arm_gemm::GemmCommon<TypeInput, TypeOutput>;

    /** Prevent instances of this class from being copied */
    const AssemblyKernelGlue<TypeInput, TypeOutput> &operator=(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete;
    /** Prevent instances of this class from being copy constructed */
    AssemblyKernelGlue(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete;

    /** Assembly Gemm kernel */
    std::unique_ptr<AssemblyGemm> _gemm_kernel_asm;
    /** Optimised NEON kernel */
    std::unique_ptr<INEKernel> _optimised_kernel;
    /** Input A */
    const ITensor *_a;
    /** Input B */
    const ITensor *_b;
    /** Output */
    ITensor *_d;
    /** GEMM workspace */
    ITensor *_workspace;
    /** Pre-transpose tensor */
    ITensor *_pretranspose;

    /** Configures the array pointers and strides in the assembly kernel and executes the assembly kernel.
     * The call to set_arrays is needed to deal with input sizes containing batches (dims > 2).
     */
    inline void run()
    {
        const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
        const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
        const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);

        // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
        // the relevant multiple of the row stride.
        const bool is_nhwc           = _a->info()->data_layout() == DataLayout::NHWC;
        const int  stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();

        const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
        const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);

        const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
        const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
        const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);

        const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
        const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
        auto       out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());

        // Set workspace if needed
        if(_workspace != nullptr)
        {
            _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace->buffer()));
        }

        // Set gemm parameters
        _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);

        // Pretranspose B if required
        if(_gemm_kernel_asm->B_pretranspose_required())
        {
            // Validate the pre-transpose tensor before dereferencing its buffer
            ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr);

            // Forcing 128-byte alignment (required by 32-bit kernels)
            const unsigned int alignment   = 128;
            void              *raw_ptr     = reinterpret_cast<void *>(_pretranspose->buffer());
            size_t             space       = _pretranspose->info()->total_size();
            void              *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space);
            _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, in1_ptr, ldb, multi_stride_b);
            _b->mark_as_unused();
        }

        // Schedule assembly kernel
        NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
    }
};

/** Float 32 assembly kernel glue */
using AssemblyKernelGlueF32 = AssemblyKernelGlue<float, float>;
/** Uint 8 to Uint 32 kernel glue */
using AssemblyKernelGlueU8U32 = AssemblyKernelGlue<uint8_t, uint32_t>;
/** Int 8 to Int 32 kernel glue */
using AssemblyKernelGlueS8S32 = AssemblyKernelGlue<int8_t, int32_t>;

/** Allocate a workspace tensor.
 *
 * @param[in]  workspace_size Size to allocate.
 * @param[out] workspace      Tensor to allocate.
 * @param[in]  memory_group   Tensor memory group, or nullptr if the workspace should not be managed.
 * @param[in]  alignment      Workspace memory alignment.
 * @param[in]  num_threads    Number of threads that will share the workspace.
 */
inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup *memory_group, size_t alignment, unsigned int num_threads)
{
    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
    // Pad each per-thread slice by (alignment - 1) bytes so that it can be aligned to the requested boundary
    workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment - 1) * num_threads }, 1, DataType::S8));
    if(memory_group != nullptr)
    {
        memory_group->manage(&workspace);
    }
    workspace.allocator()->allocate();
}
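
// A minimal usage sketch for allocate_workspace (illustrative only; it mirrors the two call
// sites in setup_assembly_kernel below, and asm_gemm/memory_group/num_threads are assumed
// caller context, not names defined here):
//
//     Tensor workspace{};
//     allocate_workspace(asm_gemm->get_working_size(), workspace, &memory_group, 4096, num_threads); // managed, one padded slice per thread
//
//     Tensor B_pretranspose{};
//     allocate_workspace(asm_gemm->get_B_pretransposed_array_size(), B_pretranspose, nullptr, 128, 1); // unmanaged, single buffer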

/** Setup the assembly kernel glue and create the wrapper kernel.
 *
 * @param[in]  a                 Input tensor A.
 * @param[in]  b                 Input tensor B.
 * @param[out] d                 Output tensor.
 * @param[in]  alpha             Alpha value.
 * @param[in]  beta              Beta value.
 * @param[in]  pretranspose_hint Pre-transpose hint in case matrix b should be pre-transposed.
 * @param[out] workspace         Workspace tensor.
 * @param[out] B_pretranspose    Tensor to hold the pre-transposed B.
 * @param[in]  memory_group      Tensor memory group.
 * @param[out] asm_glue          Assembly glue kernel.
 *
 * @return True if the assembly kernel glue was set up successfully, false otherwise.
 */
Pablo Telloeb82fd22018-02-23 13:43:50 +0000173template <typename T>
Georgios Pinitas932b5612018-05-03 13:44:35 +0100174inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
175 Tensor &workspace, Tensor &B_pretranspose, MemoryGroup &memory_group, T &asm_glue)
Pablo Telloeb82fd22018-02-23 13:43:50 +0000176{
Pablo Tello7fad9b12018-03-14 17:55:27 +0000177 const CPUInfo &ci = NEScheduler::get().cpu_info();
178 const int M = d->info()->tensor_shape().y();
179 const int N = d->info()->tensor_shape().x();
180 const int K = a->info()->tensor_shape().x();
Michalis Spyroue2503892018-04-23 15:17:31 +0100181 const int batches = d->info()->tensor_shape().total_size_upper(2);
Michalis Spyroue7e96e02018-04-13 13:44:10 +0100182 const int multis = b->info()->tensor_shape().z();
Pablo Tello7fad9b12018-03-14 17:55:27 +0000183 unsigned int num_threads = NEScheduler::get().num_threads();
Michalis Spyroue7e96e02018-04-13 13:44:10 +0100184
    // unique_ptr to a Gemm object
    std::unique_ptr<typename T::AssemblyGemm>
    asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint));
    // arm_compute wrapper for the Gemm object (see above)
    std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>
    acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>();
    if(acl_gemm_wrapper != nullptr && asm_gemm != nullptr)
    {
        acl_gemm_wrapper->configure(asm_gemm.get());
        const size_t workspace_size = asm_gemm->get_working_size();
        if(workspace_size != 0)
        {
            // Allocate workspace
            const unsigned int alignment = 4096;
            allocate_workspace(workspace_size, workspace, &memory_group, alignment, num_threads);
            asm_glue._workspace = &workspace;
        }

        // Clamp the number of threads to the kernel's window size. Note: if this block is removed,
        // ConvLayer deadlocks when threads > 1 and the shapes are
        // In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001.
        {
            const unsigned int window_size = asm_gemm->get_window_size();
            if(window_size < num_threads)
            {
                num_threads = window_size;
                asm_gemm->set_nthreads(num_threads);
            }
        }

        // Check for pre-transposed support
        if(asm_gemm->B_pretranspose_required())
        {
            // Forcing 128-byte alignment (required by 32-bit kernels)
            const unsigned int alignment           = 128;
            const size_t       B_pretranspose_size = asm_gemm->get_B_pretransposed_array_size();
            allocate_workspace(B_pretranspose_size, B_pretranspose, nullptr, alignment, 1);
            ARM_COMPUTE_ERROR_ON_NULLPTR(B_pretranspose.buffer());
            asm_glue._pretranspose = &B_pretranspose;
        }

        asm_glue._gemm_kernel_asm  = std::move(asm_gemm);
        asm_glue._optimised_kernel = std::move(acl_gemm_wrapper);
        // The array pointers and strides are set later, in the run() method
        asm_glue._a = a;
        asm_glue._b = b;
        asm_glue._d = d;
        return true;
    }
    return false;
}
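
// Minimal end-to-end usage sketch (illustrative only, not part of the library API). It shows how
// a runtime function is expected to wire the glue together; the function name
// run_f32_gemm_example is ours, and in the real runtime the workspace, B_pretranspose and
// asm_glue objects live as members of the function object so they outlive a single call.
inline bool run_f32_gemm_example(const ITensor *a, const ITensor *b, ITensor *d, MemoryGroup &memory_group)
{
    Tensor                workspace{};
    Tensor                B_pretranspose{};
    AssemblyKernelGlueF32 asm_glue{};

    // Try to set up the optimised assembly path (alpha = 1, beta = 0, no pretranspose hint)
    if(!setup_assembly_kernel(a, b, d, 1.f, 0.f, false, workspace, B_pretranspose, memory_group, asm_glue))
    {
        return false; // Caller should fall back to a reference NEON kernel
    }

    // Configure the array pointers/strides and execute the assembly kernel
    asm_glue.run();
    return true;
}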
} // namespace arm_compute
#endif /* __ARM_ASSEMBLY_HELPER_H__ */