blob: 0d739e859ab20b8c84bf9a05274d4dd68f171a04 [file] [log] [blame]
Viet-Hoa Do3389f532023-07-05 17:36:40 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef CKW_INCLUDE_CKW_KERNELWRITER_H
26#define CKW_INCLUDE_CKW_KERNELWRITER_H
27
#include "ckw/TensorOperand.h"
#include "ckw/TileOperand.h"
#include "ckw/types/ConstantData.h"
#include "ckw/types/ConvertPolicy.h"
#include "ckw/types/Operators.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
Viet-Hoa Do3389f532023-07-05 17:36:40 +010038
39namespace ckw
40{
41
// Forward declarations of types that KernelWriter only refers to by reference, by value in
// declarations, or via smart pointer, so their full definitions are not needed here.
class ITensor;
class ITile;
class Kernel;
class TensorInfo;
class TensorSampler;
class TileArea;
class TileInfo;

enum class DataType;
enum class TargetArchitecture;
enum class TargetLanguage;
52
Viet-Hoa Do3389f532023-07-05 17:36:40 +010053/** A kernel writer.
54 *
55 * This class is used to construct a new kernel by defining arguments, declaring variable and writing code.
56 *
57 * Use @ref KernelWriter::create_instance method to create the kernel writer for the specific target architecture and language.
58 *
59 * After having finished constructing the kernel, call @ref KernelWriter::emit_kernel to get the kernel object.
60 */
61class KernelWriter
62{
63public:
64 // =============================================================================================
65 // Construtors and destructor
66 // =============================================================================================
67
68 /** Initialize a new instance of @ref KernelWriter class for the specific architecture and language.
69 *
70 * Supported target architectures and languages:
71 *
72 * Architecture | Languages |
73 * ------------------------------|------------------------------|
74 * GpuArmMaliValhall | OpenCL |
75 *
76 * @param[in] architecture The architecture on which the kernel is executed.
77 * @param[in] language The language to write the kernel.
78 */
79 static std::unique_ptr<KernelWriter> create_instance(TargetArchitecture architecture, TargetLanguage language);
80
81 /** Destructor */
82 virtual ~KernelWriter();
83
84 // =============================================================================================
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +010085 // Data processing
86 // =============================================================================================
87
88 /** Write assignment statement: `<dst> = <src>;`.
89 *
90 * @param[in] dst The destination tile.
91 * @param[in] src The source tile.
92 */
93 virtual void op_assign(const TileOperand &dst, const TileOperand &src) = 0;
94
95 /** Write the cast statement: `<dst> = convert_<dst.type><policy>(<src>);`.
96 *
97 * @param[in] dst The destination tile.
98 * @param[in] src The source tile.
99 * @param[in] policy The policy governing the behavior of the cast.
100 */
101 virtual void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) = 0;
102
103 /** Write the unary expression statement: `<dst> = <op> <src>;`.
104 *
105 * @param[in] dst The destination tile.
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100106 * @param[in] op The unary operator.
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100107 * @param[in] src The source tile.
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100108 */
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100109 virtual void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) = 0;
110
111 /** Write the binary expression statement: `<dst> = <op>(<first>, <second>);`.
112 *
113 * @param[in] dst The destination tile.
114 * @param[in] op The binary operator.
115 * @param[in] first The first source tile.
116 * @param[in] second The second source tile.
117 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100118 virtual void
119 op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) = 0;
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100120
121 /** Write ternary expression statement: `<dst> = <op>(<first>, <second>, <third>);`.
122 *
123 * @param[in] dst The destination tile.
124 * @param[in] op The ternary operator.
125 * @param[in] first The first source tile.
126 * @param[in] second The second source tile.
127 * @param[in] third The third source tile.
128 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100129 virtual void op_ternary(const TileOperand &dst,
130 TernaryOp op,
131 const TileOperand &first,
132 const TileOperand &second,
133 const TileOperand &third) = 0;
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100134
135 // =============================================================================================
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100136 // Flow control
137 // =============================================================================================
138
139 /** Write if block: `if(<lhs> <op> <rhs>) { <body> }`.
140 *
141 * @param[in] lhs The LHS tile of the condition.
142 * @param[in] op The relational binary operator.
143 * @param[in] rhs The RHS tile of the condition.
144 * @param[in] body The function that writes the body of the if block.
145 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100146 virtual void
147 op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0;
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100148
149 /** Write else-if block: `else if(<lhs> <op> <rhs>) { <body> }`.
150 *
151 * @param[in] lhs The LHS tile of the condition.
152 * @param[in] op The relational binary operator.
153 * @param[in] rhs The RHS tile of the condition.
154 * @param[in] body The function that writes the body of the else-if block.
155 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100156 virtual void
157 op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0;
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100158
159 /** Write an else block: `else { <body> }`.
160 *
161 * @param[in] body The function that writes the body of the else block.
162 */
163 virtual void op_else(const std::function<void()> &body) = 0;
164
165 /** Write for-loop block: `for(; <var> <cond_op> <cond_value>; <update_var> <update_op> <update_value>) { body }`.
166 *
167 * @param[in] var The scalar tile used in loop condition.
168 * @param[in] cond_op The relational binary operator used in loop condition.
169 * @param[in] cond_value The value which the variable is compared against.
170 * @param[in] update_var The scalar tile which is updated each iteration.
171 * @param[in] update_op The assignment operator used for updating the update value.
172 * @param[in] update_value The value which is updated at every iteration.
173 * @param[in] body The function that writes the body of the for-loop block.
174 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100175 virtual void op_for_loop(const TileOperand &var,
176 BinaryOp cond_op,
177 const TileOperand &cond_value,
178 const TileOperand &update_var,
179 AssignmentOp update_op,
180 const TileOperand &update_value,
181 const std::function<void()> &body) = 0;
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100182
183 /** Write the return statement. */
184 virtual void op_return() = 0;
185
186 // =============================================================================================
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100187 // Misc
188 // =============================================================================================
189
Viet-Hoa Dod0d8f2e2023-08-29 16:01:13 +0100190 /** Write the statement to get the global ID of the specified dimension.
191 *
192 * @param[in] dst The tile to write the global ID into.
193 * @param[in] dim The dimension.
194 */
195 virtual void op_get_global_id(const TileOperand &dst, int32_t dim) = 0;
196
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100197 /** Write the line comment in debug build.
Gunes Bayirab0b7502023-07-11 14:57:36 +0100198 *
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100199 * This function does not take effect on release build.
200 *
201 * The comment must only contain one line (i.e. no newline character is allowed).
202 *
203 * @param[in] text The comment to be written.
204 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100205 virtual void op_comment(const std::string &text) = 0;
206
Viet-Hoa Dod0d8f2e2023-08-29 16:01:13 +0100207 /** Write the statement to print out the value of all the specified tiles.
208 *
209 * The printing statement is constructed so that the prefix and each of the operand are printed in separate lines.
210 * The format for each operand varies depending on whether it is a 2D tile, a vector or a scalar value.
211 *
212 * Example output of the printing statement when it is executed:
213 *
214 * prefix
215 * scalar_name = scalar_value
216 * vector_name = [vector_value_0, vector_value_1, vector_value_2]
217 * tile_name = [[tile_value_00, tile_value_01], [tile_value_10, tile_value_11]]
218 *
219 * @param[in] prefix The first string to be printed out before the list of operands.
220 * @param[in] operands The list of tiles to be included in the printing statement.
221 */
222 virtual void op_print(const std::string &prefix, const std::vector<TileOperand> &operands) = 0;
223
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100224 /** Write the given raw code to kernel source code
225 * It's used to address the cases where the user needs to
226 * explicitly add a code where it's not (yet) supported by
227 * the kernel writer utility calls.
228 *
229 * @param[in] raw_code raw code to write as string
230 */
231 virtual void op_write_raw_code(const std::string &raw_code) = 0;
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100232
233 // =============================================================================================
234 // Code generation
235 // =============================================================================================
236
237 /** Emit the kernel object.
238 *
239 * @param[in] name The name of the kernel object to be generated.
240 */
241 virtual std::unique_ptr<Kernel> emit_kernel(const std::string &name) = 0;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100242
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100243 // =============================================================================================
244 // Tensor and tile declaration
245 // =============================================================================================
246
247 /** Declare a tensor argument.
248 *
249 * @param[in] name The name of the tensor.
250 * @param[in] info The tensor info.
251 *
252 * @return The @ref TensorOperand object.
253 */
254 virtual TensorOperand declare_tensor_argument(const std::string &name, const TensorInfo &info) = 0;
255
Gunes Bayirab0b7502023-07-11 14:57:36 +0100256 /** Declare a tile given its name and tile info
257 *
258 * @param[in] name Name of the tile
259 * @param[in] tile_info Shape and data type of the tile
260 *
Gunes Bayir806b8e82023-08-23 23:28:31 +0100261 * @return The created tile operand
Gunes Bayirab0b7502023-07-11 14:57:36 +0100262 */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100263 virtual TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) = 0;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100264
Gunes Bayir806b8e82023-08-23 23:28:31 +0100265 /** Declare a constant tile given a @ref:ConstantData object
266 *
267 * @param[in] data a @ref ckw::ConstantData object that has the values and the
268 * underlying data type of the constant tile
269 *
270 * @return The created constant tile operand
271 */
272 virtual TileOperand declare_constant_tile(const ConstantData &data) = 0;
273
Gunes Bayir47a396e2023-08-17 11:04:02 +0100274 /** Load the data from the tensor memory to the tile using the sampling information.
275 *
276 * @param[in] tile_op The tile to be loaded.
277 * @param[in] tensor_op The tensor to be read.
278 * @param[in] sampler The tensor sampling information.
279 * @param[in] x x-coordinate
280 * @param[in] y y-coordinate
281 * @param[in] z z-coordinate
Gunes Bayird5f9a1c2023-08-17 11:04:02 +0100282 * @param[in] batch batch
Gunes Bayir47a396e2023-08-17 11:04:02 +0100283 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100284 virtual void op_load(const TileOperand &tile_op,
285 const TensorOperand &tensor_op,
286 TensorSampler &sampler,
287 const TileOperand &x,
288 const TileOperand &y,
289 const TileOperand &z,
290 const TileOperand &batch) = 0;
Gunes Bayir47a396e2023-08-17 11:04:02 +0100291
292 /** Load the data from the tensor memory to the tile in a dilated way using the sampling information.
293 *
294 * Similar to @ref KernelWriter::op_load() and
295 *
296 * @param[in] dilation_x Dilation while reading in x-dimension
297 * @param[in] dilation_y Dilation while reading in y-dimension
298 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100299 virtual void op_load_dilated(const TileOperand &tile_op,
300 const TensorOperand &tensor_op,
301 TensorSampler &sampler,
302 const TileOperand &x,
303 const TileOperand &y,
304 const TileOperand &z,
305 const TileOperand &batch,
306 const TileOperand &dilation_x,
307 const TileOperand &dilation_y) = 0;
Gunes Bayir47a396e2023-08-17 11:04:02 +0100308
309 /** Store the data to the tensor memory from the tile using the sampling information.
310 *
311 * Similar to @ref KernelWriter::op_load()
312 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100313 virtual void op_store(const TensorOperand &tensor_op,
314 const TileOperand &tile_op,
315 TensorSampler &sampler,
316 const TileOperand &x,
317 const TileOperand &y,
318 const TileOperand &z,
319 const TileOperand &batch) = 0;
Gunes Bayir47a396e2023-08-17 11:04:02 +0100320
321 /** Store the data to the tensor memory from the tile in a dilated way using the sampling information.
322 *
323 * Similar to @ref KernelWriter::op_load_dilated()
324 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100325 virtual void op_store_dilated(const TensorOperand &tensor_op,
326 const TileOperand &tile_op,
327 TensorSampler &sampler,
328 const TileOperand &x,
329 const TileOperand &y,
330 const TileOperand &z,
331 const TileOperand &batch,
332 const TileOperand &dilation_x,
333 const TileOperand &dilation_y) = 0;
Gunes Bayir47a396e2023-08-17 11:04:02 +0100334
Gunes Bayird5f9a1c2023-08-17 11:04:02 +0100335 /** Load the data from the tensor memory to the tile using the indirect buffer approach and respecting the sampling information.
336 *
337 * @param[in] tile_op The tile to be loaded.
338 * @param[in] tensor_op The tensor to be read.
339 * @param[in] sampler The tensor sampling information.
340 * @param[in] x x-coordinate
341 * @param[in] y y-coordinate
342 * @param[in] z z-coordinate
343 * @param[in] batch batch
344 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100345 virtual void op_load_indirect(const TileOperand &tile_op,
346 const TensorOperand &tensor_op,
347 TensorSampler &sampler,
348 const TileOperand &x,
349 const TileOperand &y,
350 const TileOperand &z,
351 const TileOperand &batch_op) = 0;
Gunes Bayird5f9a1c2023-08-17 11:04:02 +0100352
Gunes Bayirab0b7502023-07-11 14:57:36 +0100353protected:
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100354 // =============================================================================================
355 // ID space management
356 // =============================================================================================
357
358 /** Create the new unique ID space and return the value.
359 *
360 * This function changes the ID space to a new number which hasn't been used since the creation
361 * of this kernel writer object.
362 *
363 * @return The new ID space value.
364 */
365 int32_t new_id_space();
366
367 /** Get the current ID space. */
Gunes Bayirab0b7502023-07-11 14:57:36 +0100368 int32_t id_space() const;
369
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100370 /** Set the current ID space.
371 *
372 * @param[in] value The ID space to be used.
373 */
374 KernelWriter &id_space(int32_t value);
375
376 /** Write the body code using the specified function.
377 *
378 * This function makes sure that a new ID space is created before and then is used solely
379 * by the specified body writing function.
380 * The ID space will not be reused after that.
381 *
382 * @param[in] body The function that writes the body code.
383 */
384 void write_body(const std::function<void()> &body);
385
386protected:
Gunes Bayirab0b7502023-07-11 14:57:36 +0100387 /** Generate full variable name by prefixing it with id space */
388 std::string generate_full_name(const std::string &name) const;
389
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100390 /** Create a new tile operand referring to the specified tile object. */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100391 static TileOperand create_tile_operand(ITile &tile);
392
Viet-Hoa Docd1f03e2023-09-19 16:41:34 +0100393 /** Get the reference to the tile object and the active area from the tile operand. */
394 static std::tuple<ITile &, TileArea> get_tile(const TileOperand &operand);
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100395
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100396 /** Create a new tensor operand from a tensor object. */
397 static TensorOperand create_tensor_operand(ITensor &tensor);
398
399 /** Get the reference to tensor object from the tensor operand. */
400 static ITensor &get_tensor(const TensorOperand &operand);
401
Gunes Bayir806b8e82023-08-23 23:28:31 +0100402 /** Get the values of a constant data object. */
403 static const std::vector<std::vector<std::string>> &get_values(const ConstantData &data);
404
405 /** Get the data type of a constant data object. */
406 static DataType get_data_type(const ConstantData &data);
407
Gunes Bayirab0b7502023-07-11 14:57:36 +0100408private:
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100409 int32_t _id_space{0};
410 int32_t _last_created_id_space{0};
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100411};
412
413} // namespace ckw
414
415#endif // CKW_INCLUDE_CKW_KERNELWRITER_H