/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"

#include "NEGEMMInterleavedStrategies.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

namespace arm_compute
{
namespace
{
// Call the lambda function for each workload generated by the passed window.
template <typename To, bool use_dot, bool use_buffer_manager, typename Lambda>
void for_each_element_in_window(const Window &window, const ITensor *b, ITensor *transformed_b, unsigned int N, unsigned int K, Lambda &&lambda)
{
    using strategy = typename Kernel<To, use_dot>::strategy;
    unsigned int wl_index    = 0;
    unsigned int num_buffers = 0, reshaped_block_size = 0;

    if(use_buffer_manager)
    {
        num_buffers         = transformed_b->info()->tensor_shape()[1];
        reshaped_block_size = transformed_b->info()->strides_in_bytes().y();
    }

    unsigned int offset_transformed_b = transformed_b->info()->offset_first_element_in_bytes();
    execute_window_loop(window, [&](const Coordinates & coordinates)
    {
        const unsigned int x0    = coordinates.x();
        const unsigned int k0    = coordinates.y();
        const unsigned int multi = coordinates.z();

        const unsigned int offset_b = b->info()->offset_element_in_bytes(Coordinates(0, 0, multi));
        const unsigned int xmax     = std::min(x0 + window.x().step(), N);
        const unsigned int kmax     = std::min(k0 + window.y().step(), K);

        /* Figure out the size of each block. */
        unsigned int x_size = (xmax - x0);
        unsigned int k_size = (kmax - k0);

        /* Round the sizes up to a multiple of the strategy's processing size. */
        x_size = ceil_to_multiple(x_size, strategy::out_width());
        k_size = ceil_to_multiple(k_size, strategy::k_unroll());

        lambda(PrepareBWorkload(offset_b, offset_transformed_b, x0, xmax, k0, kmax));

        // Each workload represents one block:
        if(use_buffer_manager)
        {
            // Rotate through the BufferManager's buffers:
            wl_index++;
            offset_transformed_b = (wl_index % num_buffers) * reshaped_block_size;
        }
        else
        {
            // Advance to where the next reshaped block will be stored:
            offset_transformed_b += (x_size * k_size * sizeof(To));
        }
    });
}
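
// Illustrative walk-through (assumed values, not taken from this file): with
// num_buffers = 2 and reshaped_block_size = R, the buffer-manager variant above
// cycles offset_transformed_b through R, 0, R, 0, ... after each workload, so a
// fixed pool of buffers can be reused while consumers drain them. The sequential
// variant instead advances the offset by the padded block size
// x_size * k_size * sizeof(To), laying the reshaped blocks out back to back.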

// Calculate the size of transformed_b:
template <typename To, bool use_dot>
unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, const BlockSizes &bs, unsigned int multis)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

    // How many full blocks do N and K contain?
    size_t num_full_k = K / bs.k_block;
    size_t num_full_x = N / bs.x_block;

    ARM_COMPUTE_ERROR_ON(bs.x_block % strategy::out_width() != 0);
    ARM_COMPUTE_ERROR_ON(bs.k_block % strategy::k_unroll() != 0);

    size_t normal_x_size = bs.x_block;
    size_t normal_k_size = bs.k_block;

    // Round up the leftovers to be a multiple of the strategy processing size:
    size_t left_over_x_size = ceil_to_multiple(N % bs.x_block, strategy::out_width());
    size_t left_over_k_size = ceil_to_multiple(K % bs.k_block, strategy::k_unroll());

    // Calculate the total size of the buffer:
    size_t total = num_full_k * normal_k_size * (num_full_x * normal_x_size + left_over_x_size);
    total += left_over_k_size * (left_over_x_size + num_full_x * normal_x_size);

    total *= multis;

    return total;
}
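
// Worked example (illustrative values, not taken from this file): for N = 70,
// K = 37, multis = 1, bs.x_block = 48, bs.k_block = 32, out_width() = 12 and
// k_unroll() = 4:
//   num_full_x = 1, num_full_k = 1
//   left_over_x_size = ceil_to_multiple(70 % 48, 12) = 24
//   left_over_k_size = ceil_to_multiple(37 % 32, 4)  = 8
//   total = 1 * 32 * (1 * 48 + 24) + 8 * (24 + 1 * 48) = 2880 elements
// i.e. the 70x37 matrix is padded up to the equivalent of a 72x40 one.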

} // namespace

template <typename To, bool use_dot>
BlockSizes NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::block_sizes() const
{
    return _block_sizes;
}

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::configure(const ITensor *b, ITensor *transformed_b, bool transpose_b, const CPUInfo &ci, const INEGEMMWrapperKernel::Params &params)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

    const unsigned int multis = b->info()->tensor_shape().z();
    _Nsize                    = b->info()->tensor_shape().x();
    _Ksize                    = b->info()->tensor_shape().y();
    _b                        = b;
    _transformed_b            = transformed_b;
    _transpose_b              = transpose_b;

    _block_sizes = calculate_block_sizes<strategy>(ci, params.M, params.N, params.K);

    auto_init_if_empty(*transformed_b->info(), b->info()->clone()->set_tensor_shape(TensorShape{ get_B_pretransposed_array_size<To, use_dot>(_Nsize, _Ksize, _block_sizes, multis) }));

    // One window step per block of B:
    Window window;
    window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_Nsize, _block_sizes.x_block), _block_sizes.x_block));
    window.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_Ksize, _block_sizes.k_block), _block_sizes.k_block));
    window.set(Window::DimZ, Window::Dimension(0, multis));

    INEKernel::configure(window);
}
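
// Usage sketch (hypothetical caller code; `b`, `transformed_b`, `ci` and
// `params` are assumed to be created and initialised elsewhere):
//
//   NEGEMMInterleavedPrepareBWrapperKernelTemplate<float> prepare_b;
//   prepare_b.configure(&b, &transformed_b, /* transpose_b = */ false, ci, params);
//   std::vector<PrepareBWorkload> workloads;
//   prepare_b.create_workloads(workloads); // one workload per block of B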

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::transform(const PrepareBWorkload &wl, const ThreadInfo &info)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

    strategy strat(info.cpu_info);
    strat.transforms.PrepareB(reinterpret_cast<To *>(_transformed_b->buffer() + wl._offset_transformed_b),
                              reinterpret_cast<To *>(_b->buffer() + wl._offset_b),
                              _b->info()->strides_in_bytes().y() / sizeof(To), // Leading dimension of B, in elements
                              wl._x0, wl._xmax, wl._k0, wl._kmax, _transpose_b);
}

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::create_workloads(std::vector<PrepareBWorkload> &workloads)
{
    for_each_element_in_window<To, use_dot, true>(window(), _b, _transformed_b, _Nsize, _Ksize, [&workloads](PrepareBWorkload && wl)
    {
        workloads.push_back(std::move(wl));
    });
}
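
// Note: create_workloads() above instantiates the buffer-manager variant of the
// loop (use_buffer_manager = true), rotating offsets through the BufferManager's
// buffers, while run() below uses the in-place variant (use_buffer_manager =
// false) and writes every reshaped block at its final offset in _transformed_b.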

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(window, INEKernel::window());
    for_each_element_in_window<To, use_dot, false>(window, _b, _transformed_b, _Nsize, _Ksize, [&](PrepareBWorkload && wl)
    {
        this->transform(wl, info);
    });
}

template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<float>;
#ifdef __aarch64__
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<uint8_t>;
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<int8_t>;
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<uint8_t, true>;
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<int8_t, true>;
#endif /* __aarch64__ */

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<float16_t>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
} // namespace arm_compute
189} // namespace arm_compute