/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

24
25#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
26
27#include "NEGEMMInterleavedStrategies.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
30#include "arm_compute/core/Utils.h"
31#include "arm_compute/core/Validate.h"
32
namespace arm_compute
{
namespace
{
// Call the lambda function for each workload generated by the passed window.
template <typename To, bool use_dot, typename Lambda>
void for_each_element_in_window(const Window &window, const ITensor *b, ITensor *transformed_b, unsigned int N, unsigned int K, Lambda &&lambda)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

    unsigned int offset_transformed_b = transformed_b->info()->offset_first_element_in_bytes();
    execute_window_loop(window, [&](const Coordinates & coordinates)
    {
        const unsigned int x0    = coordinates.x();
        const unsigned int k0    = coordinates.y();
        const unsigned int multi = coordinates.z();

        const unsigned int offset_b = b->info()->offset_element_in_bytes(Coordinates(0, 0, multi));
        const unsigned int xmax     = std::min(x0 + window.x().step(), N);
        const unsigned int kmax     = std::min(k0 + window.y().step(), K);

        /* Figure out the size of each block. */
        unsigned int x_size = (xmax - x0);
        unsigned int k_size = (kmax - k0);

        /* Round sizes up as needed. */
        x_size = ceil_to_multiple(x_size, strategy::out_width());
        k_size = ceil_to_multiple(k_size, strategy::k_unroll());
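        // Note: the packing routine works in whole out_width() x k_unroll()
        // tiles (padding partial edges), which is why the space a block
        // consumes in transformed_b is rounded up here before advancing the
        // output offset below.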

        lambda(PrepareBWorkload(offset_b, offset_transformed_b, x0, xmax, k0, kmax));

        // Each workload represents one block: advance the output offset by the
        // padded size of the block just emitted, so consecutive blocks are laid
        // out contiguously in transformed_b.
        offset_transformed_b += (x_size * k_size * sizeof(To));
    });
}

// Calculate the size of transformed_b:
template <typename To, bool use_dot>
unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, const BlockSizes &bs)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

    // How many full blocks do N and K contain?
    size_t num_full_k = K / bs.k_block;
    size_t num_full_x = N / bs.x_block;

    ARM_COMPUTE_ERROR_ON(bs.x_block % strategy::out_width() != 0);
    ARM_COMPUTE_ERROR_ON(bs.k_block % strategy::k_unroll() != 0);

    size_t normal_x_size = bs.x_block;
    size_t normal_k_size = bs.k_block;

    // Round up the leftovers to be a multiple of the strategy processing size:
    size_t left_over_x_size = ceil_to_multiple(N % bs.x_block, strategy::out_width());
    size_t left_over_k_size = ceil_to_multiple(K % bs.k_block, strategy::k_unroll());

    // Calculate the total size of the buffer, in elements:
    size_t total = num_full_k * normal_k_size * (num_full_x * normal_x_size + left_over_x_size);
    total += left_over_k_size * (left_over_x_size + num_full_x * normal_x_size);
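    // Illustrative example (hypothetical sizes, not taken from any real
    // strategy): N = 100, K = 70, x_block = 48, k_block = 32, out_width() = 12,
    // k_unroll() = 4 gives num_full_x = 2, left_over_x_size = 12, num_full_k = 2,
    // left_over_k_size = 8, hence total = 64 * 108 + 8 * 108 = 7776 elements.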
    return total;
}

} // namespace

template <typename To, bool use_dot>
BlockSizes NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::block_sizes() const
{
    return _block_sizes;
}

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::configure(const ITensor *b, ITensor *transformed_b, bool transpose_b, const CPUInfo &ci, const INEGEMMWrapperKernel::Params &params)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

    const unsigned int multis = b->info()->tensor_shape().z();
    _Nsize                    = b->info()->tensor_shape().x();
    _Ksize                    = b->info()->tensor_shape().y();
    _b                        = b;
    _transformed_b            = transformed_b;
    _transpose_b              = transpose_b;

    _block_sizes = calculate_block_sizes<strategy>(ci, params.M, params.N, params.K);

    auto_init_if_empty(*transformed_b->info(), b->info()->clone()->set_tensor_shape(TensorShape{ get_B_pretransposed_array_size<To, use_dot>(_Nsize, _Ksize, _block_sizes) }));

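    // The window is laid out in blocks rather than elements: one iteration
    // along X and Y per (x_block x k_block) block of B (the rounded-up end
    // values give partial edge blocks an iteration too), and one iteration
    // along Z per multi.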
    Window window;
    window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_Nsize, _block_sizes.x_block), _block_sizes.x_block));
    window.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_Ksize, _block_sizes.k_block), _block_sizes.k_block));
    window.set(Window::DimZ, Window::Dimension(0, multis));

    INEKernel::configure(window);
}

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::transform(const PrepareBWorkload &wl, const ThreadInfo &info)
{
    using strategy = typename Kernel<To, use_dot>::strategy;

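    // Pack (and optionally transpose) one block of B: read rows [k0, kmax) and
    // columns [x0, xmax) from the source tensor, whose leading dimension is
    // passed in elements, and write the packed result at this workload's
    // output offset.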
    strategy strat(info.cpu_info);
    strat.transforms.PrepareB(reinterpret_cast<To *>(_transformed_b->buffer() + wl._offset_transformed_b),
                              reinterpret_cast<To *>(_b->buffer() + wl._offset_b),
                              _b->info()->strides_in_bytes().y() / sizeof(To),
                              wl._x0, wl._xmax, wl._k0, wl._kmax, _transpose_b);
}

template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::create_workloads(std::vector<PrepareBWorkload> &workloads)
{
    for_each_element_in_window<To, use_dot>(window(), _b, _transformed_b, _Nsize, _Ksize, [&workloads](PrepareBWorkload && wl)
    {
        workloads.push_back(std::move(wl));
    });
}

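// Unlike create_workloads(), which only records the workloads for later
// scheduling, run() transforms each block in the caller's window immediately
// on the current thread.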
template <typename To, bool use_dot>
void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(window, INEKernel::window());
    for_each_element_in_window<To, use_dot>(window, _b, _transformed_b, _Nsize, _Ksize, [&](PrepareBWorkload && wl)
    {
        this->transform(wl, info);
    });
}

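// Explicit instantiations for the supported element types; the use_dot = true
// variants (dot-product strategies) are only provided for AArch64 targets.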
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<float>;
#ifdef __aarch64__
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<uint8_t>;
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<int8_t>;
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<uint8_t, true>;
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<int8_t, true>;
#endif /* __aarch64__ */

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<float16_t>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
} // namespace arm_compute