blob: e2d27cf94129db7a66cad13bc801bea062381f42 [file] [log] [blame]
Pablo Telloeb82fd22018-02-23 13:43:50 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_ASSEMBLY_HELPER_H__
25#define __ARM_ASSEMBLY_HELPER_H__
26
27#include "arm_compute/core/ITensor.h"
28#include "support/ToolchainSupport.h"
29
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/Log.h"
33#include "arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapper.h"
34#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
35#include "arm_compute/core/TensorInfo.h"
36#include "arm_compute/core/Types.h"
37#include "arm_compute/core/Validate.h"
38#include "arm_compute/core/Window.h"
39#include "arm_compute/runtime/NEON/NEScheduler.h"
40
41namespace arm_compute
42{
Alex Gildayc357c472018-03-21 13:54:09 +000043/** Assembly kernel glue */
Pablo Telloeb82fd22018-02-23 13:43:50 +000044template <typename TypeInput, typename TypeOutput>
45class AssemblyKernelGlue final
46{
47public:
Alex Gildayc357c472018-03-21 13:54:09 +000048 /** Operator type */
Pablo Telloeb82fd22018-02-23 13:43:50 +000049 using TypeOperator = TypeInput;
Alex Gildayc357c472018-03-21 13:54:09 +000050 /** Result type */
51 using TypeResult = TypeOutput;
52 /** Default constructor. */
Pablo Telloeb82fd22018-02-23 13:43:50 +000053 AssemblyKernelGlue()
54 : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr)
55 {
56 }
Alex Gildayc357c472018-03-21 13:54:09 +000057 /** Assembly Gemm */
Pablo Telloeb82fd22018-02-23 13:43:50 +000058 using AssemblyGemm = arm_gemm::GemmCommon<TypeInput, TypeOutput>;
59
Alex Gildayc357c472018-03-21 13:54:09 +000060 /** Prevent instances of this class from being copy constructed */
Pablo Telloeb82fd22018-02-23 13:43:50 +000061 const AssemblyKernelGlue<TypeInput, TypeOutput> &operator=(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete;
Alex Gildayc357c472018-03-21 13:54:09 +000062 /** Prevent instances of this class from being copied */
Pablo Telloeb82fd22018-02-23 13:43:50 +000063 AssemblyKernelGlue(const AssemblyKernelGlue<TypeInput, TypeOutput> &) = delete;
64
Alex Gildayc357c472018-03-21 13:54:09 +000065 /** Assembly Gemm kernel */
Pablo Telloeb82fd22018-02-23 13:43:50 +000066 std::unique_ptr<AssemblyGemm> _gemm_kernel_asm;
Alex Gildayc357c472018-03-21 13:54:09 +000067 /** Optimised NEON kernel */
68 std::unique_ptr<INEKernel> _optimised_kernel;
69 /** Input A */
70 const ITensor *_a;
71 /** Input B */
72 const ITensor *_b;
73 /** Output */
74 ITensor *_d;
Pablo Telloeb82fd22018-02-23 13:43:50 +000075
76 /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel.
77 * The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2)
78 */
79 inline void run()
80 {
81 const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
82 const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
83 const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
84
85 // Configure kernel window
86 Window window = calculate_max_window(*_d->info());
87 const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
88
89 // Only iterate over batches
90 Window win(window);
91 win.set(0, Window::Dimension(0, 1, 1));
92 win.set(1, Window::Dimension(0, 1, 1));
93 Iterator in0(_a, window);
94 Iterator out(_d, window);
95 execute_window_loop(win, [&](const Coordinates &)
96 {
97 const auto in0_ptr = reinterpret_cast<const TypeInput *>(in0.ptr());
98 auto out_ptr = reinterpret_cast<TypeOutput *>(out.ptr());
99 _gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd);
100 NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
101 },
102 in0, out);
103 }
104};
105
Alex Gildayc357c472018-03-21 13:54:09 +0000106/** Float 32 assembly kernel glue */
107using AssemblyKernelGlueF32 = AssemblyKernelGlue<float, float>;
108/** Uint 8 to Uint 32 kernel glue */
Pablo Telloeb82fd22018-02-23 13:43:50 +0000109using AssemblyKernelGlueU8U32 = AssemblyKernelGlue<uint8_t, uint32_t>;
Alex Gildayc357c472018-03-21 13:54:09 +0000110/** Int 8 to Int 32 kernel glue */
Pablo Telloeb82fd22018-02-23 13:43:50 +0000111using AssemblyKernelGlueS8S32 = AssemblyKernelGlue<int8_t, int32_t>;
112
Alex Gildayc357c472018-03-21 13:54:09 +0000113/** Allocate a workspace tensor.
114 *
115 * @param[in] workspace_size Size to allocate.
116 * @param[out] workspace Tensor to allocate.
117 * @param[in] memory_group Tensor memory group.
118 * @param[in] alignment Workspace memory alignment.
119 * @param[in] num_threads Number of workspace threads.
120 */
Pablo Telloeb82fd22018-02-23 13:43:50 +0000121inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup &memory_group, size_t alignment, unsigned int num_threads)
122{
123 ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
124 workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment - 1) * num_threads }, 1, DataType::S8));
125 workspace.allocator()->allocate();
126}
127
Alex Gildayc357c472018-03-21 13:54:09 +0000128/** Create a wrapper kernel.
129 *
130 * @param[in] a Input tensor A.
131 * @param[in] b Input tensor B.
132 * @param[in] c (Optional) Input tensor C.
133 * @param[out] d Output tensor.
134 * @param[in] alpha Alpha value.
135 * @param[in] beta Beta value.
136 *
137 * @return the wrapper kernel.
138 */
Pablo Telloeb82fd22018-02-23 13:43:50 +0000139template <typename T>
140std::unique_ptr<NEGEMMAssemblyWrapper<T>> create_wrapper_kernel(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta)
141{
142 // rework this function, why are we checking data type and other things here ? should we create another function can_run_optimised_kernel() ?
143#if defined(__arm__)
144 if(NEScheduler::get().cpu_info().CPU == CPUTarget::ARMV7 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f))
145 {
146 return support::cpp14::make_unique<NEGEMMAssemblyWrapper<T>>();
147 }
148#elif defined(__aarch64__)
149 if(NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f))
150 {
151 return support::cpp14::make_unique<NEGEMMAssemblyWrapper<T>>();
152 }
153 else if(a->info()->data_type() == DataType::F16 && (c == nullptr || beta == 0.f))
154 {
155#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
156 return support::cpp14::make_unique<NEGEMMAssemblyWrapper<T>>();
157#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
158 ARM_COMPUTE_ERROR("Recompile the library with arch=arm64-v8.2-a to enable support for FP16.");
159#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
160 }
161#endif /* defined(__arm__) || defined(__aarch64__) */
162 return nullptr;
163}
164
Alex Gildayc357c472018-03-21 13:54:09 +0000165/** Setup assembly kernel.
166 *
167 * @param[in] a Input tensor A.
168 * @param[in] b Input tensor B.
169 * @param[in] c (Optional) Input tensor C.
170 * @param[in] d Output tensor.
171 * @param[in] alpha Alpha value.
172 * @param[in] beta Beta value.
173 * @param[out] workspace Workspace tensor
174 * @param[in] memory_group Tensor memory group.
175 * @param[out] asm_glue Assembly glue kernel.
176 *
177 * @return True if the assembly kernel is setup correctly.
178 */
Pablo Telloeb82fd22018-02-23 13:43:50 +0000179template <typename T>
180inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta,
181 Tensor &workspace, MemoryGroup &memory_group, T &asm_glue)
182{
183 const ::CPUInfo *ci = get_CPUInfo();
184 const int M = d->info()->tensor_shape().y();
185 const int N = d->info()->tensor_shape().x();
186 const int K = a->info()->tensor_shape().x();
187 unsigned int num_threads = NEScheduler::get().num_threads();
188 // unique_ptr to a Gemm object
189 std::unique_ptr<typename T::AssemblyGemm> asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(*ci, M, N, K, false, false, alpha, beta, num_threads,
190 false));
191
192 // arm_compute wrapper for the Gemm object (see above)
193 std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>> acl_gemm_wrapper = create_wrapper_kernel<typename T::AssemblyGemm>(a, b, c, d, alpha, beta);
194 if(acl_gemm_wrapper != nullptr && asm_gemm != nullptr)
195 {
196 acl_gemm_wrapper->configure(asm_gemm.get());
197 const size_t workspace_size = asm_gemm->get_working_size();
198 if(workspace_size)
199 {
200 // Allocate workspace
201 allocate_workspace(workspace_size, workspace, memory_group, 4096, num_threads);
202 asm_gemm->set_working_space(reinterpret_cast<typename T::TypeResult *>(workspace.buffer()));
203 }
204 const unsigned int window_size = asm_gemm->get_window_size();
205 if(window_size < num_threads)
206 {
207 num_threads = window_size;
208 asm_gemm->set_nthreads(num_threads);
209 }
210 asm_glue._gemm_kernel_asm = std::move(asm_gemm);
211 asm_glue._optimised_kernel = std::move(acl_gemm_wrapper);
212 // We need to setup the ptrs in the run() method
213 asm_glue._a = a;
214 asm_glue._b = b;
215 asm_glue._d = d;
216 return true;
217 }
218 return false;
219}
220}
221#endif /* __ARM_ASSEMBLY_HELPER_H__ */