blob: fdb5fedc5900e41d4e8a08bdf3bacfe436020881 [file] [log] [blame]
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010025#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
26#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010027
#include "ckw/Kernel.h"
#include "ckw/TensorInfo.h"
#include "ckw/TensorOperand.h"
#include "ckw/TileInfo.h"
#include "ckw/TileOperand.h"
#include "ckw/types/ConvertPolicy.h"
#include "ckw/types/Functions.h"
#include "ckw/types/Operators.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
38
39namespace ckw
40{
41
namespace prototype
{
// Implementation types that KernelWriter only holds through std::unique_ptr
// members, so incomplete (forward) declarations are sufficient in this header.
class IGpuKernelWriter;
struct GpuKernelWriterAttribute;
} // namespace prototype
48
49/** Kernel writer. */
50class KernelWriter
51{
52public:
53 // =============================================================================================
54 // Constructors and destructor
55 // =============================================================================================
56
57 /** Initialize a new instance of kernel writer.
58 *
59 * @param[in] kernel The kernel to be written to.
60 */
61 explicit KernelWriter(Kernel &kernel);
62
63 /** Destructor */
64 ~KernelWriter();
65
66 /** No copy constructor. */
67 KernelWriter(const KernelWriter &) = delete;
68
69 /** No copy assignment. */
70 KernelWriter &operator=(const KernelWriter &) = delete;
71
72 // =============================================================================================
73 // Scope management
74 // =============================================================================================
75
76 /** Get the current ID space. */
77 int32_t id_space() const;
78
79 /** Set the current ID space. */
80 KernelWriter &id_space(int32_t id_space);
81
82 /** Switch to and return a new ID space. */
83 int32_t next_id_space();
84
85 // =============================================================================================
86 // Tensor and tile declaration
87 // =============================================================================================
88
Nikolaj Jensen5ff48022023-06-27 14:13:24 +010089 /** Declare a tensor argument.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010090 *
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010091 * @param[in] name The name of the tensor.
92 * @param[in] info The tensor info.
93 * @param[in] storage_type The tensor storage type.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010094 *
95 * @return The @ref TensorOperand object.
96 */
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010097 TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type = TensorStorageType::BufferUint8Ptr);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010098
Nikolaj Jensen5ff48022023-06-27 14:13:24 +010099 /** Declare a compile-time constant scalar argument.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100100 *
101 * @param[in] name The name of the tile.
102 * @param[in] value The value of the tile.
103 *
104 * @return The @ref TileOperand object.
105 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100106 TileOperand &declare_tile_argument(const std::string &name, int32_t value);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100107
108 /** Declare a new tile.
109 *
110 * The name of the tile must be unique in the current ID space.
111 *
112 * @param[in] name The name of the tile.
113 * @param[in] ... The necessary arguments to create a new @ref TileOperand.
114 *
115 * @return The @ref TileOperand object.
116 */
117 template <typename... TArgs>
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100118 TileOperand &declare_tile(const std::string &name, TArgs &&...args)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100119 {
120 const auto var_name = generate_variable_name(name);
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100121 auto operand = std::make_unique<TileOperand>(var_name, ::std::forward<TArgs>(args)...);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100122
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100123 return declare_tile_operand(std::move(operand));
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100124 }
125
126 // =============================================================================================
127 // Load and store
128 // =============================================================================================
129
130 /** Load the data from the tensor memory to the tile using the sampling information.
131 *
Jakub Sujake1c96e72023-07-31 13:36:58 +0100132 * @param[out] tile The tile to be loaded.
133 * @param[in] tensor The tensor to be read.
134 * @param[in] sampler The tensor sampling information.
135 * @param[in] dilation_y Dilation in the Y dimension.
136 */
137 void op_load(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler, const TileOperand &dilation_y = TileOperand("dil_y", 1));
138
139 /** Load the data from the tensor memory to the tile using the indirect buffer approach and respective of the sampling information.
140 *
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100141 * @param[out] tile The tile to be loaded.
142 * @param[in] tensor The tensor to be read.
143 * @param[in] sampler The tensor sampling information.
144 */
Jakub Sujake1c96e72023-07-31 13:36:58 +0100145 void op_load_indirect(TileOperand &tile, const TensorOperand &tensor, const TensorTileSampler &sampler);
146
147 /** Construct an indirection buffer in @p tile containing the precalculated addresses of elements in the source tensor.
148 *
149 * @param[out] tile The tile to be loaded.
150 * @param[in] tensor The tensor the be read.
151 * @param[in] sampler The tensor sampling information.
152 * @param[in] x The X coordinate.
153 * @param[in] y The Y coordinate.
154 * @param[in] x_off Offset in the X dimension.
155 * @param[in] y_off Offset in the Y dimension.
156 */
157 void util_get_indirect_buffer(TileOperand &tile,
158 const TensorOperand &tensor,
159 const TensorTileSampler &sampler,
160 const TileOperand &x,
161 const TileOperand &y,
162 const TileOperand &x_off,
163 const TileOperand &y_off);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100164
165 /** Store the tile to the tensor using the specified sampling information.
166 *
167 * @param[out] dst The tensor that the tile is written to.
168 * @param[in] src The tile to be stored.
169 * @param[in] sampler The tensor sampling information.
170 */
171 void op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler);
172
173 // =============================================================================================
174 // Data processing
175 // =============================================================================================
176
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100177 /** Write assignment: `<dst> = <src>;`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100178 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100179 * @param[out] dst The destination tile.
180 * @param[in] src The source tile.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100181 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100182 void op_assign(const TileOperand &dst, const TileOperand &src);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100183
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100184 /** Write the cast: `<dst> = convert_<dst.type><_sat>(<src>);`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100185 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100186 * @param[out] dst The destination tile.
187 * @param[in] src The source tile.
188 * @param[in] policy The policy governing the behavior of the cast.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100189 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100190 void op_cast_expression(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100191
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100192 /** Write the unary expression: `<dst> = <op> <src>`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100193 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100194 * @param[out] dst The destination tile.
195 * @param[in] op The unary operator.
196 * @param[in] src The source tile.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100197 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100198 void op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src);
199
200 /** Write binary expression: `<dst> = <lhs> <op> <rhs>;`.
201 *
202 * @param[out] dst The destination tile.
203 * @param[in] lhs The LHS tile.
204 * @param[in] op The binary operator.
205 * @param[in] rhs The RHS tile.
206 */
207 void op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs);
208
209 /** Write function applied to scalar value: `<dst> = <func>(<src>);`.
210 *
211 * @param[out] dst The destination tile.
212 * @param[in] func The function to be applied to the source tile.
213 * @param[in] src The source tile.
214 */
215 void op_unary_elementwise_function(const TileOperand &dst, UnaryFunction func, const TileOperand &src);
216
217 /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>);`.
218 *
219 * @param[out] dst The destination tile.
220 * @param[in] func The function to be applied to the source tiles.
221 * @param[in] first The first argument tile.
222 * @param[in] second The second argument tile.
223 */
224 void op_binary_elementwise_function(const TileOperand &dst, BinaryFunction func, const TileOperand &first, const TileOperand &second);
225
226 /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>, <third>);`.
227 *
228 * @param[out] dst The destination tile.
229 * @param[in] func The function to be applied to the source tiles.
230 * @param[in] first The first argument tile.
231 * @param[in] second The second argument tile.
232 * @param[in] third The third argument tile.
233 */
234 void op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction func, const TileOperand &first, const TileOperand &second, const TileOperand &third);
235
236 /** Write if-statement: `if(<lhs> <op> <rhs>) { <body> }`.
237 *
238 * @param[in] lhs The LHS tile of the condition.
239 * @param[in] op The relational binary operator.
240 * @param[in] rhs The RHS tile of the condition.
241 * @param[in] body The body of the if-statement.
242 */
243 void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
244
245 /** Write else-if-statement: `else if(<lhs> <op> <rhs>) { <body> }`.
246 *
247 * @param[in] lhs The LHS tile of the condition.
248 * @param[in] op The relational binary operator.
249 * @param[in] rhs The RHS tile of the condition.
250 * @param[in] body The body of the else-if-statement.
251 */
252 void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
253
254 /** Write an else-statement: `else { <body> }`.
255 *
256 * @param[in] body The body of the else-statement.
257 */
258 void op_else(const std::function<void()> &body);
259
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100260 /** Write for-loops: `for(; <var> <cond_op> <cond_value>; <var> <update_op> <update_value>) { body }`.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100261 *
262 * @param[in] var_name The name of the variable used in condition.
263 * @param[in] cond_op The relational binary operator used in condition.
264 * @param[in] cond_value_name The value which the variable is compared against.
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100265 * @param[in] update_var_name The name of the variable which is updated.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100266 * @param[in] update_op The assignment operator used for updating the update value.
267 * @param[in, out] update_value The value which is updated at every iteration.
268 * @param[in] body The body of the for-loop.
269 */
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100270 void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body);
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100271
272 /** Write the return statement: `return;`
273 */
274 void op_return();
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100275
276 // =============================================================================================
277 // Misc
278 // =============================================================================================
279
280 /** Set `dst` the global ID of dimension `dim`.
281 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100282 * @param[out] dst The tile to be written to.
283 * @param[in] dim The global ID dimension.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100284 */
Gunes Bayir91cb7332023-07-25 17:00:33 +0100285 void op_get_global_id(const TileOperand &dst, int32_t dim);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100286
287 // =============================================================================================
288 // Code generation
289 // =============================================================================================
290
291 /** Generate the source code of the kernel. */
292 ::std::string generate_code();
293
294private:
295 /** Generate the full variable name based on the original name and the ID space.
296 *
297 * @param[in] name The name of the variable.
298 *
299 * @return The full variable name.
300 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100301 ::std::string generate_variable_name(const std::string &name) const;
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100302
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100303 /** Declare the tile operand.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100304 *
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100305 * @param[in] operand The tile operand to be declared.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100306 */
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100307 TileOperand &declare_tile_operand(std::unique_ptr<TileOperand> operand);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100308
309private:
310 Kernel *_kernel;
311 ::std::unique_ptr<prototype::GpuKernelWriterAttribute> _impl_attr;
312 ::std::unique_ptr<prototype::IGpuKernelWriter> _impl;
313
314 int32_t _id_space{ 0 };
315 int32_t _max_id_space{ 0 };
316};
317
318} // namespace ckw
319
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100320#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H