blob: c116e626501a04b9fe96db754581921939a8fe28 [file] [log] [blame]
/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010025#ifndef CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
26#define CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010027
#include "ckw/Kernel.h"
#include "ckw/TensorInfo.h"
#include "ckw/TensorOperand.h"
#include "ckw/TileInfo.h"
#include "ckw/TileOperand.h"
#include "ckw/types/ConvertPolicy.h"
#include "ckw/types/Functions.h"
#include "ckw/types/Operators.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
38
39namespace ckw
40{
41
/** Forward declarations of the implementation types.
 *
 * In this header these types are only referenced as the pointee of `::std::unique_ptr` members
 * of the kernel writer, so their full definitions are not needed here.
 */
namespace prototype
{
struct GpuKernelWriterAttribute;

class IGpuKernelWriter;
} // namespace prototype
48
49/** Kernel writer. */
50class KernelWriter
51{
52public:
53 // =============================================================================================
54 // Constructors and destructor
55 // =============================================================================================
56
57 /** Initialize a new instance of kernel writer.
58 *
59 * @param[in] kernel The kernel to be written to.
60 */
61 explicit KernelWriter(Kernel &kernel);
62
63 /** Destructor */
64 ~KernelWriter();
65
66 /** No copy constructor. */
67 KernelWriter(const KernelWriter &) = delete;
68
69 /** No copy assignment. */
70 KernelWriter &operator=(const KernelWriter &) = delete;
71
72 // =============================================================================================
73 // Scope management
74 // =============================================================================================
75
76 /** Get the current ID space. */
77 int32_t id_space() const;
78
79 /** Set the current ID space. */
80 KernelWriter &id_space(int32_t id_space);
81
82 /** Switch to and return a new ID space. */
83 int32_t next_id_space();
84
85 // =============================================================================================
86 // Tensor and tile declaration
87 // =============================================================================================
88
Nikolaj Jensen5ff48022023-06-27 14:13:24 +010089 /** Declare a tensor argument.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010090 *
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010091 * @param[in] name The name of the tensor.
92 * @param[in] info The tensor info.
93 * @param[in] storage_type The tensor storage type.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010094 *
95 * @return The @ref TensorOperand object.
96 */
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010097 TensorOperand &declare_tensor_argument(const std::string &name, const TensorInfo &info, TensorStorageType storage_type = TensorStorageType::BufferUint8Ptr);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010098
Nikolaj Jensen5ff48022023-06-27 14:13:24 +010099 /** Declare a compile-time constant scalar argument.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100100 *
101 * @param[in] name The name of the tile.
102 * @param[in] value The value of the tile.
103 *
104 * @return The @ref TileOperand object.
105 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100106 TileOperand &declare_tile_argument(const std::string &name, int32_t value);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100107
108 /** Declare a new tile.
109 *
110 * The name of the tile must be unique in the current ID space.
111 *
112 * @param[in] name The name of the tile.
113 * @param[in] ... The necessary arguments to create a new @ref TileOperand.
114 *
115 * @return The @ref TileOperand object.
116 */
117 template <typename... TArgs>
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100118 TileOperand &declare_tile(const std::string &name, TArgs &&...args)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100119 {
120 const auto var_name = generate_variable_name(name);
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100121 auto operand = std::make_unique<TileOperand>(var_name, ::std::forward<TArgs>(args)...);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100122
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100123 return declare_tile_operand(std::move(operand));
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100124 }
125
126 // =============================================================================================
127 // Load and store
128 // =============================================================================================
129
130 /** Load the data from the tensor memory to the tile using the sampling information.
131 *
132 * @param[out] tile The tile to be loaded.
133 * @param[in] tensor The tensor to be read.
134 * @param[in] sampler The tensor sampling information.
135 */
136 void op_load(TileOperand &tile, TensorOperand &tensor, const TensorTileSampler &sampler);
137
138 /** Store the tile to the tensor using the specified sampling information.
139 *
140 * @param[out] dst The tensor that the tile is written to.
141 * @param[in] src The tile to be stored.
142 * @param[in] sampler The tensor sampling information.
143 */
144 void op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler);
145
146 // =============================================================================================
147 // Data processing
148 // =============================================================================================
149
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100150 /** Write assignment: `<dst> = <src>;`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100151 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100152 * @param[out] dst The destination tile.
153 * @param[in] src The source tile.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100154 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100155 void op_assign(const TileOperand &dst, const TileOperand &src);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100156
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100157 /** Write the cast: `<dst> = convert_<dst.type><_sat>(<src>);`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100158 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100159 * @param[out] dst The destination tile.
160 * @param[in] src The source tile.
161 * @param[in] policy The policy governing the behavior of the cast.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100162 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100163 void op_cast_expression(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100164
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100165 /** Write the unary expression: `<dst> = <op> <src>`.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100166 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100167 * @param[out] dst The destination tile.
168 * @param[in] op The unary operator.
169 * @param[in] src The source tile.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100170 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100171 void op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src);
172
173 /** Write binary expression: `<dst> = <lhs> <op> <rhs>;`.
174 *
175 * @param[out] dst The destination tile.
176 * @param[in] lhs The LHS tile.
177 * @param[in] op The binary operator.
178 * @param[in] rhs The RHS tile.
179 */
180 void op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs);
181
182 /** Write function applied to scalar value: `<dst> = <func>(<src>);`.
183 *
184 * @param[out] dst The destination tile.
185 * @param[in] func The function to be applied to the source tile.
186 * @param[in] src The source tile.
187 */
188 void op_unary_elementwise_function(const TileOperand &dst, UnaryFunction func, const TileOperand &src);
189
190 /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>);`.
191 *
192 * @param[out] dst The destination tile.
193 * @param[in] func The function to be applied to the source tiles.
194 * @param[in] first The first argument tile.
195 * @param[in] second The second argument tile.
196 */
197 void op_binary_elementwise_function(const TileOperand &dst, BinaryFunction func, const TileOperand &first, const TileOperand &second);
198
199 /** Write function applied to scalar value: `<dst> = <func>(<first>, <second>, <third>);`.
200 *
201 * @param[out] dst The destination tile.
202 * @param[in] func The function to be applied to the source tiles.
203 * @param[in] first The first argument tile.
204 * @param[in] second The second argument tile.
205 * @param[in] third The third argument tile.
206 */
207 void op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction func, const TileOperand &first, const TileOperand &second, const TileOperand &third);
208
209 /** Write if-statement: `if(<lhs> <op> <rhs>) { <body> }`.
210 *
211 * @param[in] lhs The LHS tile of the condition.
212 * @param[in] op The relational binary operator.
213 * @param[in] rhs The RHS tile of the condition.
214 * @param[in] body The body of the if-statement.
215 */
216 void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
217
218 /** Write else-if-statement: `else if(<lhs> <op> <rhs>) { <body> }`.
219 *
220 * @param[in] lhs The LHS tile of the condition.
221 * @param[in] op The relational binary operator.
222 * @param[in] rhs The RHS tile of the condition.
223 * @param[in] body The body of the else-if-statement.
224 */
225 void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body);
226
227 /** Write an else-statement: `else { <body> }`.
228 *
229 * @param[in] body The body of the else-statement.
230 */
231 void op_else(const std::function<void()> &body);
232
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100233 /** Write for-loops: `for(; <var> <cond_op> <cond_value>; <var> <update_op> <update_value>) { body }`.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100234 *
235 * @param[in] var_name The name of the variable used in condition.
236 * @param[in] cond_op The relational binary operator used in condition.
237 * @param[in] cond_value_name The value which the variable is compared against.
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100238 * @param[in] update_var_name The name of the variable which is updated.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100239 * @param[in] update_op The assignment operator used for updating the update value.
240 * @param[in, out] update_value The value which is updated at every iteration.
241 * @param[in] body The body of the for-loop.
242 */
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100243 void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body);
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100244
245 /** Write the return statement: `return;`
246 */
247 void op_return();
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100248
249 // =============================================================================================
250 // Misc
251 // =============================================================================================
252
253 /** Set `dst` the global ID of dimension `dim`.
254 *
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100255 * @param[out] dst The tile to be written to.
256 * @param[in] dim The global ID dimension.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100257 */
258 void op_get_global_id(TileOperand &dst, int32_t dim);
259
260 // =============================================================================================
261 // Code generation
262 // =============================================================================================
263
264 /** Generate the source code of the kernel. */
265 ::std::string generate_code();
266
267private:
268 /** Generate the full variable name based on the original name and the ID space.
269 *
270 * @param[in] name The name of the variable.
271 *
272 * @return The full variable name.
273 */
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100274 ::std::string generate_variable_name(const std::string &name) const;
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100275
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100276 /** Declare the tile operand.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100277 *
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100278 * @param[in] operand The tile operand to be declared.
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100279 */
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100280 TileOperand &declare_tile_operand(std::unique_ptr<TileOperand> operand);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100281
282private:
283 Kernel *_kernel;
284 ::std::unique_ptr<prototype::GpuKernelWriterAttribute> _impl_attr;
285 ::std::unique_ptr<prototype::IGpuKernelWriter> _impl;
286
287 int32_t _id_space{ 0 };
288 int32_t _max_id_space{ 0 };
289};
290
291} // namespace ckw
292
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100293#endif // CKW_PROTOTYPE_INCLUDE_CKW_KERNELWRITER_H