blob: f9e0066f91be19b0ce5403dda0750e1f6efe7518 [file] [log] [blame]
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010025#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
26#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010027
#include "ckw/Kernel.h"
#include "ckw/TensorInfo.h"
#include "ckw/TensorOperand.h"
#include "ckw/TileInfo.h"
#include "ckw/TileOperand.h"
#include "ckw/types/ConvertPolicy.h"
#include "ckw/types/Functions.h"
#include "ckw/types/Operators.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
38
39namespace ckw
40{
41
namespace prototype
{
// Forward declarations of the prototype implementation types. This header only names them
// as std::unique_ptr element types, so complete definitions are not required here.
class IGpuKernelWriter;
struct GpuKernelWriterAttribute;
} // namespace prototype
48
49/** Kernel writer. */
50class KernelWriter
51{
52public:
53 // =============================================================================================
54 // Constructors and destructor
55 // =============================================================================================
56
57 /** Initialize a new instance of kernel writer.
58 *
59 * @param[in] kernel The kernel to be written to.
60 */
61 explicit KernelWriter(Kernel &kernel);
62
63 /** Destructor */
64 ~KernelWriter();
65
66 /** No copy constructor. */
67 KernelWriter(const KernelWriter &) = delete;
68
69 /** No copy assignment. */
70 KernelWriter &operator=(const KernelWriter &) = delete;
71
72 // =============================================================================================
73 // Scope management
74 // =============================================================================================
75
76 /** Get the current ID space. */
77 int32_t id_space() const;
78
79 /** Set the current ID space. */
80 KernelWriter &id_space(int32_t id_space);
81
82 /** Switch to and return a new ID space. */
83 int32_t next_id_space();
84
85 // =============================================================================================
86 // Tensor and tile declaration
87 // =============================================================================================
88
Nikolaj Jensen5ff48022023-06-27 14:13:24 +010089 /** Declare a tensor argument.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010090 *
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010091 * @param[in] name The name of the tensor.
92 * @param[in] info The tensor info.
93 * @param[in] storage_type The tensor storage type.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010094 *
95 * @return The @ref TensorOperand object.
96 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010097 TensorOperand &declare_tensor_argument(const std::string &name,
98 const TensorInfo &info,
99 TensorStorageType storage_type = TensorStorageType::BufferUint8Ptr);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100100
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100101 /** Declare a compile-time constant scalar argument.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100102 *
103 * @param[in] name The name of the tile.
104 * @param[in] value The value of the tile.
105 *
106 * @return The @ref TileOperand object.
107 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100108 TileOperand &declare_tile_argument(const std::string &name, int32_t value);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100109
110 /** Declare a new tile.
111 *
112 * The name of the tile must be unique in the current ID space.
113 *
114 * @param[in] name The name of the tile.
115 * @param[in] ... The necessary arguments to create a new @ref TileOperand.
116 *
117 * @return The @ref TileOperand object.
118 */
119 template <typename... TArgs>
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100120 TileOperand &declare_tile(const std::string &name, TArgs &&...args)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100121 {
122 const auto var_name = generate_variable_name(name);
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100123 auto operand = std::make_unique<TileOperand>(var_name, ::std::forward<TArgs>(args)...);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100124
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100125 return declare_tile_operand(std::move(operand));
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100126 }
127
128 // =============================================================================================
129 // Load and store
130 // =============================================================================================
131
132 /** Load the data from the tensor memory to the tile using the sampling information.
133 *
Jakub Sujake1c96e72023-07-31 13:36:58 +0100134 * @param[out] tile The tile to be loaded.
135 * @param[in] tensor The tensor to be read.
136 * @param[in] sampler The tensor sampling information.
137 * @param[in] dilation_y Dilation in the Y dimension.
138 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100139 void op_load(TileOperand &tile,
140 const TensorOperand &tensor,
141 const TensorTileSampler &sampler,
142 const TileOperand &dilation_y = TileOperand("dil_y", 1));
Jakub Sujake1c96e72023-07-31 13:36:58 +0100143
144 /** Load the data from the tensor memory to the tile using the indirect buffer approach and respective of the sampling information.
145 *
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100146 * @param[out] tile The tile to be loaded.
147 * @param[in] tensor The tensor to be read.
148 * @param[in] sampler The tensor sampling information.
149 */
Jakub Sujake1c96e72023-07-31 13:36:58 +0100150 void op_load_indirect(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler);
151
152 /** Construct an indirection buffer in @p tile containing the precalculated addresses of elements in the source tensor.
153 *
154 * @param[out] tile The tile to be loaded.
155 * @param[in] tensor The tensor the be read.
156 * @param[in] sampler The tensor sampling information.
157 * @param[in] x The X coordinate.
158 * @param[in] y The Y coordinate.
159 * @param[in] x_off Offset in the X dimension.
160 * @param[in] y_off Offset in the Y dimension.
161 */
162 void util_get_indirect_buffer(TileOperand &tile,
163 const TensorOperand &tensor,
164 const TensorTileSampler &sampler,
165 const TileOperand &x,
166 const TileOperand &y,
167 const TileOperand &x_off,
168 const TileOperand &y_off);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100169
170 /** Store the tile to the tensor using the specified sampling information.
171 *
172 * @param[out] dst The tensor that the tile is written to.
173 * @param[in] src The tile to be stored.
174 * @param[in] sampler The tensor sampling information.
175 */
176 void op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler);
177
178 // =============================================================================================
179 // Data processing
180 // =============================================================================================
181
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100182 /** Write assignment: `<dst> = <src>;`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100183 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100184 * @param[out] dst The destination tile.
185 * @param[in] src The source tile.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100186 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100187 void op_assign(const TileOperand &dst, const TileOperand &src);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100188
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100189 /** Write the cast: `<dst> = convert_<dst.type><_sat>(<src>);`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100190 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100191 * @param[out] dst The destination tile.
192 * @param[in] src The source tile.
193 * @param[in] policy The policy governing the behavior of the cast.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100194 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100195 void op_cast_expression(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100196
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100197 /** Write the unary expression: `<dst> = <op> <src>`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100198 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100199 * @param[out] dst The destination tile.
200 * @param[in] op The unary operator.
201 * @param[in] src The source tile.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100202 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100203 void op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src);
204
205 /** Write binary expression: `<dst> = <lhs> <op> <rhs>;`.
206 *
207 * @param[out] dst The destination tile.
208 * @param[in] lhs The LHS tile.
209 * @param[in] op The binary operator.
210 * @param[in] rhs The RHS tile.
211 */
212 void op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs);
213
214 /** Write function applied to scalar value: `<dst> = <func>(<src>);`.
215 *
216 * @param[out] dst The destination tile.
217 * @param[in] func The function to be applied to the source tile.
218 * @param[in] src The source tile.
219 */
220 void op_unary_elementwise_function(const TileOperand &dst, UnaryFunction func, const TileOperand &src);
221
222 /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>);`.
223 *
224 * @param[out] dst The destination tile.
225 * @param[in] func The function to be applied to the source tiles.
226 * @param[in] first The first argument tile.
227 * @param[in] second The second argument tile.
228 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100229 void op_binary_elementwise_function(const TileOperand &dst,
230 BinaryFunction func,
231 const TileOperand &first,
232 const TileOperand &second);
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100233
234 /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>, <third>);`.
235 *
236 * @param[out] dst The destination tile.
237 * @param[in] func The function to be applied to the source tiles.
238 * @param[in] first The first argument tile.
239 * @param[in] second The second argument tile.
240 * @param[in] third The third argument tile.
241 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100242 void op_ternary_elementwise_function(const TileOperand &dst,
243 TernaryFunction func,
244 const TileOperand &first,
245 const TileOperand &second,
246 const TileOperand &third);
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100247
248 /** Write if-statement: `if(<lhs> <op> <rhs>) { <body> }`.
249 *
250 * @param[in] lhs The LHS tile of the condition.
251 * @param[in] op The relational binary operator.
252 * @param[in] rhs The RHS tile of the condition.
253 * @param[in] body The body of the if-statement.
254 */
255 void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
256
257 /** Write else-if-statement: `else if(<lhs> <op> <rhs>) { <body> }`.
258 *
259 * @param[in] lhs The LHS tile of the condition.
260 * @param[in] op The relational binary operator.
261 * @param[in] rhs The RHS tile of the condition.
262 * @param[in] body The body of the else-if-statement.
263 */
264 void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
265
266 /** Write an else-statement: `else { <body> }`.
267 *
268 * @param[in] body The body of the else-statement.
269 */
270 void op_else(const std::function<void()> &body);
271
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100272 /** Write for-loops: `for(; <var> <cond_op> <cond_value>; <var> <update_op> <update_value>) { body }`.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100273 *
274 * @param[in] var_name The name of the variable used in condition.
275 * @param[in] cond_op The relational binary operator used in condition.
276 * @param[in] cond_value_name The value which the variable is compared against.
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100277 * @param[in] update_var_name The name of the variable which is updated.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100278 * @param[in] update_op The assignment operator used for updating the update value.
279 * @param[in, out] update_value The value which is updated at every iteration.
280 * @param[in] body The body of the for-loop.
281 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100282 void op_for_loop(const TileOperand &var_name,
283 BinaryOp cond_op,
284 const TileOperand &cond_value_name,
285 const TileOperand &update_var_name,
286 AssignmentOp update_op,
287 const TileOperand &update_value_name,
288 const std::function<void()> &body);
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100289
290 /** Write the return statement: `return;`
291 */
292 void op_return();
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100293
294 // =============================================================================================
295 // Misc
296 // =============================================================================================
297
298 /** Set `dst` the global ID of dimension `dim`.
299 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100300 * @param[out] dst The tile to be written to.
301 * @param[in] dim The global ID dimension.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100302 */
Gunes Bayir91cb7332023-07-25 17:00:33 +0100303 void op_get_global_id(const TileOperand &dst, int32_t dim);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100304
305 // =============================================================================================
306 // Code generation
307 // =============================================================================================
308
309 /** Generate the source code of the kernel. */
310 ::std::string generate_code();
311
312private:
313 /** Generate the full variable name based on the original name and the ID space.
314 *
315 * @param[in] name The name of the variable.
316 *
317 * @return The full variable name.
318 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100319 ::std::string generate_variable_name(const std::string &name) const;
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100320
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100321 /** Declare the tile operand.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100322 *
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100323 * @param[in] operand The tile operand to be declared.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100324 */
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100325 TileOperand &declare_tile_operand(std::unique_ptr<TileOperand> operand);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100326
327private:
328 Kernel *_kernel;
329 ::std::unique_ptr<prototype::GpuKernelWriterAttribute> _impl_attr;
330 ::std::unique_ptr<prototype::IGpuKernelWriter> _impl;
331
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100332 int32_t _id_space{0};
333 int32_t _max_id_space{0};
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100334};
335
336} // namespace ckw
337
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100338#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H