blob: 23237ace28879977e236ae83f1dcfbec2db436e5 [file] [log] [blame]
Viet-Hoa Do3389f532023-07-05 17:36:40 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef CKW_INCLUDE_CKW_KERNELWRITER_H
26#define CKW_INCLUDE_CKW_KERNELWRITER_H
27
#include "ckw/TensorOperand.h"
#include "ckw/TileOperand.h"
#include "ckw/types/ConstantData.h"
#include "ckw/types/ConvertPolicy.h"
#include "ckw/types/Operators.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
37
38namespace ckw
39{
40
/** Forward declarations */
class Kernel;
class TensorInfo;
class TensorSampler;
class TileInfo;

enum class DataType;
enum class TargetArchitecture;
enum class TargetLanguage;
50
Viet-Hoa Do3389f532023-07-05 17:36:40 +010051/** A kernel writer.
52 *
53 * This class is used to construct a new kernel by defining arguments, declaring variable and writing code.
54 *
55 * Use @ref KernelWriter::create_instance method to create the kernel writer for the specific target architecture and language.
56 *
57 * After having finished constructing the kernel, call @ref KernelWriter::emit_kernel to get the kernel object.
58 */
59class KernelWriter
60{
61public:
62 // =============================================================================================
63 // Construtors and destructor
64 // =============================================================================================
65
66 /** Initialize a new instance of @ref KernelWriter class for the specific architecture and language.
67 *
68 * Supported target architectures and languages:
69 *
70 * Architecture | Languages |
71 * ------------------------------|------------------------------|
72 * GpuArmMaliValhall | OpenCL |
73 *
74 * @param[in] architecture The architecture on which the kernel is executed.
75 * @param[in] language The language to write the kernel.
76 */
77 static std::unique_ptr<KernelWriter> create_instance(TargetArchitecture architecture, TargetLanguage language);
78
79 /** Destructor */
80 virtual ~KernelWriter();
81
82 // =============================================================================================
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +010083 // Data processing
84 // =============================================================================================
85
86 /** Write assignment statement: `<dst> = <src>;`.
87 *
88 * @param[in] dst The destination tile.
89 * @param[in] src The source tile.
90 */
91 virtual void op_assign(const TileOperand &dst, const TileOperand &src) = 0;
92
93 /** Write the cast statement: `<dst> = convert_<dst.type><policy>(<src>);`.
94 *
95 * @param[in] dst The destination tile.
96 * @param[in] src The source tile.
97 * @param[in] policy The policy governing the behavior of the cast.
98 */
99 virtual void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) = 0;
100
101 /** Write the unary expression statement: `<dst> = <op> <src>;`.
102 *
103 * @param[in] dst The destination tile.
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100104 * @param[in] op The unary operator.
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100105 * @param[in] src The source tile.
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100106 */
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100107 virtual void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) = 0;
108
109 /** Write the binary expression statement: `<dst> = <op>(<first>, <second>);`.
110 *
111 * @param[in] dst The destination tile.
112 * @param[in] op The binary operator.
113 * @param[in] first The first source tile.
114 * @param[in] second The second source tile.
115 */
116 virtual void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) = 0;
117
118 /** Write ternary expression statement: `<dst> = <op>(<first>, <second>, <third>);`.
119 *
120 * @param[in] dst The destination tile.
121 * @param[in] op The ternary operator.
122 * @param[in] first The first source tile.
123 * @param[in] second The second source tile.
124 * @param[in] third The third source tile.
125 */
126 virtual void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) = 0;
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100127
128 // =============================================================================================
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100129 // Flow control
130 // =============================================================================================
131
132 /** Write if block: `if(<lhs> <op> <rhs>) { <body> }`.
133 *
134 * @param[in] lhs The LHS tile of the condition.
135 * @param[in] op The relational binary operator.
136 * @param[in] rhs The RHS tile of the condition.
137 * @param[in] body The function that writes the body of the if block.
138 */
139 virtual void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0;
140
141 /** Write else-if block: `else if(<lhs> <op> <rhs>) { <body> }`.
142 *
143 * @param[in] lhs The LHS tile of the condition.
144 * @param[in] op The relational binary operator.
145 * @param[in] rhs The RHS tile of the condition.
146 * @param[in] body The function that writes the body of the else-if block.
147 */
148 virtual void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) = 0;
149
150 /** Write an else block: `else { <body> }`.
151 *
152 * @param[in] body The function that writes the body of the else block.
153 */
154 virtual void op_else(const std::function<void()> &body) = 0;
155
156 /** Write for-loop block: `for(; <var> <cond_op> <cond_value>; <update_var> <update_op> <update_value>) { body }`.
157 *
158 * @param[in] var The scalar tile used in loop condition.
159 * @param[in] cond_op The relational binary operator used in loop condition.
160 * @param[in] cond_value The value which the variable is compared against.
161 * @param[in] update_var The scalar tile which is updated each iteration.
162 * @param[in] update_op The assignment operator used for updating the update value.
163 * @param[in] update_value The value which is updated at every iteration.
164 * @param[in] body The function that writes the body of the for-loop block.
165 */
166 virtual void op_for_loop(
167 const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value,
168 const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value,
169 const std::function<void()> &body) = 0;
170
171 /** Write the return statement. */
172 virtual void op_return() = 0;
173
174 // =============================================================================================
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100175 // Misc
176 // =============================================================================================
177
178 /** Write the line comment in debug build.
Gunes Bayirab0b7502023-07-11 14:57:36 +0100179 *
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100180 * This function does not take effect on release build.
181 *
182 * The comment must only contain one line (i.e. no newline character is allowed).
183 *
184 * @param[in] text The comment to be written.
185 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100186 virtual void op_comment(const std::string &text) = 0;
187
188 /** Write the given raw code to kernel source code
189 * It's used to address the cases where the user needs to
190 * explicitly add a code where it's not (yet) supported by
191 * the kernel writer utility calls.
192 *
193 * @param[in] raw_code raw code to write as string
194 */
195 virtual void op_write_raw_code(const std::string &raw_code) = 0;
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100196
197 // =============================================================================================
198 // Code generation
199 // =============================================================================================
200
201 /** Emit the kernel object.
202 *
203 * @param[in] name The name of the kernel object to be generated.
204 */
205 virtual std::unique_ptr<Kernel> emit_kernel(const std::string &name) = 0;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100206
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100207 // =============================================================================================
208 // Tensor and tile declaration
209 // =============================================================================================
210
211 /** Declare a tensor argument.
212 *
213 * @param[in] name The name of the tensor.
214 * @param[in] info The tensor info.
215 *
216 * @return The @ref TensorOperand object.
217 */
218 virtual TensorOperand declare_tensor_argument(const std::string &name, const TensorInfo &info) = 0;
219
Gunes Bayirab0b7502023-07-11 14:57:36 +0100220 /** Declare a tile given its name and tile info
221 *
222 * @param[in] name Name of the tile
223 * @param[in] tile_info Shape and data type of the tile
224 *
Gunes Bayir806b8e82023-08-23 23:28:31 +0100225 * @return The created tile operand
Gunes Bayirab0b7502023-07-11 14:57:36 +0100226 */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100227 virtual TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) = 0;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100228
Gunes Bayir806b8e82023-08-23 23:28:31 +0100229 /** Declare a constant tile given a @ref:ConstantData object
230 *
231 * @param[in] data a @ref ckw::ConstantData object that has the values and the
232 * underlying data type of the constant tile
233 *
234 * @return The created constant tile operand
235 */
236 virtual TileOperand declare_constant_tile(const ConstantData &data) = 0;
237
Gunes Bayir47a396e2023-08-17 11:04:02 +0100238 /** Load the data from the tensor memory to the tile using the sampling information.
239 *
240 * @param[in] tile_op The tile to be loaded.
241 * @param[in] tensor_op The tensor to be read.
242 * @param[in] sampler The tensor sampling information.
243 * @param[in] x x-coordinate
244 * @param[in] y y-coordinate
245 * @param[in] z z-coordinate
246 * @param[in] batch batch offset
247 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100248 virtual void op_load(
249 const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100250 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0;
251
252 /** Load the data from the tensor memory to the tile in a dilated way using the sampling information.
253 *
254 * Similar to @ref KernelWriter::op_load() and
255 *
256 * @param[in] dilation_x Dilation while reading in x-dimension
257 * @param[in] dilation_y Dilation while reading in y-dimension
258 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100259 virtual void op_load_dilated(
260 const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100261 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
262 const TileOperand &dilation_x, const TileOperand &dilation_y) = 0;
263
264 /** Store the data to the tensor memory from the tile using the sampling information.
265 *
266 * Similar to @ref KernelWriter::op_load()
267 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100268 virtual void op_store(
269 const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100270 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0;
271
272 /** Store the data to the tensor memory from the tile in a dilated way using the sampling information.
273 *
274 * Similar to @ref KernelWriter::op_load_dilated()
275 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100276 virtual void op_store_dilated(
277 const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100278 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
279 const TileOperand &dilation_x, const TileOperand &dilation_y) = 0;
280
Gunes Bayirab0b7502023-07-11 14:57:36 +0100281protected:
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100282 // =============================================================================================
283 // ID space management
284 // =============================================================================================
285
286 /** Create the new unique ID space and return the value.
287 *
288 * This function changes the ID space to a new number which hasn't been used since the creation
289 * of this kernel writer object.
290 *
291 * @return The new ID space value.
292 */
293 int32_t new_id_space();
294
295 /** Get the current ID space. */
Gunes Bayirab0b7502023-07-11 14:57:36 +0100296 int32_t id_space() const;
297
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100298 /** Set the current ID space.
299 *
300 * @param[in] value The ID space to be used.
301 */
302 KernelWriter &id_space(int32_t value);
303
304 /** Write the body code using the specified function.
305 *
306 * This function makes sure that a new ID space is created before and then is used solely
307 * by the specified body writing function.
308 * The ID space will not be reused after that.
309 *
310 * @param[in] body The function that writes the body code.
311 */
312 void write_body(const std::function<void()> &body);
313
314protected:
Gunes Bayirab0b7502023-07-11 14:57:36 +0100315 /** Generate full variable name by prefixing it with id space */
316 std::string generate_full_name(const std::string &name) const;
317
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100318 /** Create a new tile operand referring to the specified tile object. */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100319 static TileOperand create_tile_operand(ITile &tile);
320
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100321 /** Get the reference to tile object from the tile operand. */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100322 static ITile &get_tile(const TileOperand &operand);
323
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100324 /** Create a new tensor operand from a tensor object. */
325 static TensorOperand create_tensor_operand(ITensor &tensor);
326
327 /** Get the reference to tensor object from the tensor operand. */
328 static ITensor &get_tensor(const TensorOperand &operand);
329
Gunes Bayir806b8e82023-08-23 23:28:31 +0100330 /** Get the values of a constant data object. */
331 static const std::vector<std::vector<std::string>> &get_values(const ConstantData &data);
332
333 /** Get the data type of a constant data object. */
334 static DataType get_data_type(const ConstantData &data);
335
Gunes Bayirab0b7502023-07-11 14:57:36 +0100336private:
337 int32_t _id_space{ 0 };
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100338 int32_t _last_created_id_space{ 0 };
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100339};
340
341} // namespace ckw
342
343#endif // CKW_INCLUDE_CKW_KERNELWRITER_H